from collections import deque, defaultdict
import gc
import logging
from typing import Any, Callable, Hashable, List, Optional, Literal
import dask
import ray
from ray.util.dask import ray_dask_get
from deisa.ray._scheduler import deisa_ray_get
from deisa.ray.config import config
from deisa.ray.errors import _default_exception_handler
from deisa.ray.head_node import HeadNodeActor
from deisa.ray.types import (
ActorID,
DeisaArray,
RayActorHandle,
WindowSpec,
_CallbackConfig,
)
from deisa.ray.utils import get_head_actor_options
def _ray_start_impl() -> None:
"""
Default Ray startup procedure used by :class:`Deisa`.
Notes
-----
Initializes Ray only once with minimal logging. Used when the caller
does not provide a custom ``ray_start`` hook.
"""
if not ray.is_initialized():
ray.init(address="auto", log_to_driver=False, logging_level=logging.ERROR)
class Deisa:
"""
Entry point that orchestrates analytics callbacks on Ray.
Provides an API for registering sliding window callbacks and executing
them as arrays arrive from simulation ranks.
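
    Examples
    --------
    A minimal driver sketch. The array name ``"temperature"`` and the
    ``WindowSpec`` keyword arguments shown here are illustrative
    assumptions, not guaranteed parts of the public API::

        from deisa.ray.types import WindowSpec

        deisa = Deisa(max_simulation_ahead=2)

        # hypothetical array name and window declaration
        @deisa.callback(WindowSpec(name="temperature", window_size=3))
        def rolling_view(temperature):
            # ``temperature`` is a list of DeisaArray objects, oldest first
            latest = temperature[-1]
            print(latest.t)

        deisa.execute_callbacks()  # blocks until the simulation ends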
"""
def __init__(
self,
*,
ray_start: Optional[Callable[[], None]] = None,
max_simulation_ahead: int = 1,
) -> None:
"""
Initialize handler state without touching Ray.
Parameters
----------
ray_start : Callable[[], None], optional
Custom callable used to start Ray. Defaults to a built-in helper.
max_simulation_ahead : int, optional
            Maximum number of timesteps the simulation may run ahead of the
            analytics (i.e., how far the analytics may lag behind). Defaults to 1.
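
        Examples
        --------
        Injecting a custom startup hook instead of the built-in one (a
        sketch; the Ray options shown are illustrative)::

            def my_ray_start():
                # connect to an existing cluster, quietly
                if not ray.is_initialized():
                    ray.init(address="auto", log_to_driver=False)

            deisa = Deisa(ray_start=my_ray_start, max_simulation_ahead=4)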
"""
# cheap constructor: no Ray side effects
config.lock()
self._experimental_distributed_scheduling_enabled = config.experimental_distributed_scheduling_enabled
# Do NOT mutate global config here if you want cheap unit tests;
# do it when connecting, or inject it similarly.
self._ray_start = ray_start or _ray_start_impl
self._connected = False
self.node_actors: dict[ActorID, RayActorHandle] = {}
self.registered_callbacks: list[_CallbackConfig] = []
        self.queue_per_array: dict[str, deque] = {}
        self.max_simulation_ahead: int = max_simulation_ahead
        self.has_new_timestep: dict[str, bool] = defaultdict(bool)
        self.has_seen_array: dict[str, bool] = defaultdict(bool)
def _ensure_connected(self) -> None:
"""
Ensure the handler is connected to Ray and has a head actor ready.
Notes
-----
Starts Ray (if needed), configures Dask to use the correct scheduler,
creates the head actor, and exchanges configuration so that
scheduling actors can register themselves.
"""
if self._connected:
return
# Side effects begin here (only once)
self._ray_start()
# configure dask here if it must reflect actual cluster runtime
if self._experimental_distributed_scheduling_enabled:
dask.config.set(scheduler=deisa_ray_get, shuffle="tasks")
else:
dask.config.set(scheduler=ray_dask_get, shuffle="tasks")
# head is created
self._create_head_actor()
        # readiness gate for the head actor: only return once it is alive
ray.get(
self.head.exchange_config.remote(
{"experimental_distributed_scheduling_enabled": self._experimental_distributed_scheduling_enabled}
)
)
self._connected = True
def _create_head_actor(self) -> None:
"""
Instantiate the head actor that coordinates array delivery.
Notes
-----
Uses :func:`get_head_actor_options` to pin the actor to the Ray head node
with a detached lifetime so that analytics can connect later.
"""
self.head = HeadNodeActor.options(**get_head_actor_options()).remote(
max_simulation_ahead=self.max_simulation_ahead
)
def callback(
self,
        *window_specs: WindowSpec,
exception_handler: Optional[Callable] = None,
when: Literal["AND", "OR"] = "AND",
):
"""
Decorator that registers a sliding-window analytics callback.
Parameters
----------
*window_specs : WindowSpec
Array descriptions the callback should receive.
exception_handler : Optional[Callable], optional
Handler invoked when the user callback raises. Defaults to
:func:`deisa.ray.errors._default_exception_handler`.
when : Literal["AND", "OR"], optional
            Governs whether all declared arrays (``"AND"``) or any of them
            (``"OR"``) must have fresh data before the callback runs.
            Defaults to ``"AND"``.
Returns
-------
Callable
            Decorator that registers the wrapped function via
            :meth:`register_callback`.
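
        Examples
        --------
        Registering one callback on two arrays; the array names and the
        ``WindowSpec`` arguments are illustrative assumptions::

            from deisa.ray.types import WindowSpec

            @deisa.callback(
                WindowSpec(name="pressure"),   # hypothetical array names
                WindowSpec(name="velocity"),
                when="OR",
            )
            def monitor(pressure, velocity):
                # each argument is a list of DeisaArray objects
                ...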
"""
def deco(fn):
return self.register_callback(fn, list(window_specs), exception_handler, when)
return deco
def register_callback(
self,
simulation_callback: Callable,
arrays_spec: list[WindowSpec],
exception_handler: Optional[Callable] = None,
when: Literal["AND", "OR"] = "AND",
) -> Callable:
"""
Register the analytics callback and array descriptions.
Parameters
----------
simulation_callback : Callable
            Function to run for each iteration; receives one keyword argument
            per declared array, holding a list of ``DeisaArray`` objects (each
            carrying its timestep ``t``).
        arrays_spec : list[WindowSpec]
            Descriptions of the arrays to stream to the callback (with
            optional sliding windows).
        exception_handler : Optional[Callable]
            Handler for any exception raised by the callback (for example a
            division by zero). Defaults to printing the error and moving on.
        when : Literal["AND", "OR"]
            When the callback declares multiple arrays, governs when it is
            invoked. ``"AND"``: only call the callback once ALL required
            arrays have been shared for a given timestep. ``"OR"``: call the
            callback as soon as ANY array has been shared for a given
            timestep (once every array has been seen at least once).
Returns
-------
Callable
The original callback, allowing decorator-style usage.
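
        Examples
        --------
        The non-decorator form, equivalent to :meth:`callback` (the
        ``WindowSpec`` arguments and array name are illustrative)::

            from deisa.ray.types import WindowSpec

            def energy_budget(energy):
                # ``energy`` holds the last two DeisaArray instances
                print([a.t for a in energy])

            deisa.register_callback(
                energy_budget,
                [WindowSpec(name="energy", window_size=2)],  # hypothetical name
            )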
"""
self._ensure_connected() # connect + handshake before accepting callbacks
cfg = _CallbackConfig(
simulation_callback=simulation_callback,
arrays_description=arrays_spec,
exception_handler=exception_handler or _default_exception_handler,
when=when,
)
self.registered_callbacks.append(cfg)
return simulation_callback
def unregister_callback(
self,
simulation_callback: Callable,
) -> None:
"""
Unregister a previously registered simulation callback.
Parameters
----------
simulation_callback : Callable
Callback to remove from the registry.
Raises
------
NotImplementedError
Always, as the feature has not been implemented yet.
"""
raise NotImplementedError("method not yet implemented.")
def generate_queue_per_array(self):
"""
Prepare per-array queues that respect declared window sizes.
Notes
-----
Each queue is a :class:`collections.deque` with ``maxlen`` matching the
largest window requested for that array.
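
        Examples
        --------
        If one callback declares ``window_size=2`` and another declares
        ``window_size=5`` for the same (illustrative) array name, the shared
        queue keeps the larger bound::

            deisa.generate_queue_per_array()
            deisa.queue_per_array["temperature"].maxlen  # -> 5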
"""
        for cb_cfg in self.registered_callbacks:
            for array_def in cb_cfg.arrays_description:
                name = array_def.name
                window_size: int = array_def.window_size if array_def.window_size is not None else 1
                existing = self.queue_per_array.get(name)
                # keep one queue per array, sized to the largest requested window
                if existing is None or existing.maxlen < window_size:
                    self.queue_per_array[name] = deque(maxlen=window_size)
def execute_callbacks(
self,
) -> None:
"""
Execute the registered simulation callback loop.
Notes
-----
        Runs every registered callback whose trigger condition is met. Manages
        array retrieval from the head actor, windowed array delivery, and
        garbage collection between iterations.
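
        Examples
        --------
        Typically the last call in an analytics driver; at least one callback
        must have been registered first, otherwise a :class:`RuntimeError` is
        raised (``analyze`` and the spec are illustrative)::

            deisa.register_callback(analyze, [WindowSpec(name="temperature")])
            deisa.execute_callbacks()  # returns once the end sentinel arrives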
"""
# ensure connected to ray cluster
self._ensure_connected()
# signal analytics ready to start
ray.get(self.head.set_analytics_ready_for_execution.remote())
# ray.get(self.head.wait_for_bridges_ready.remote())
# TODO: test
# raise error and kill analytics
if not self.registered_callbacks:
raise RuntimeError("Please register at least one callback before calling execute_callbacks()")
# generate one queue per array which cleanly handles the window size
self.generate_queue_per_array()
# get first array to kickstart the process
# - Add to queue, mark as new timestep arrived
name, arr_timestep, array = ray.get(self.head.get_next_array.remote())
if name == "__deisa_last_iteration_array":
return
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
end_reached = False
while not end_reached:
            # the inner loop stops once an array with a larger timestep has been pushed to the queue.
            # WARNING: the big assumption is that no array for timestep i+1 can arrive BEFORE
            # timestep i. This is violated in embarrassingly parallel workflows where each rank can
            # run ahead independently. Without this assumption it would be much harder to decide
            # when to evaluate which callbacks should fire, and memory handling and flow execution
            # would become difficult to guarantee.
current_timestep = arr_timestep
while True:
name, arr_timestep, array = ray.get(self.head.get_next_array.remote())
# guarantee sequential flow of data.
# TODO add test
if arr_timestep < current_timestep:
raise RuntimeError(
f"Logical flow of data was violated. Timestep {arr_timestep} sent after timestep {current_timestep}. Exiting..."
)
if name == "__deisa_last_iteration_array":
end_reached = True
break
# simulation has produced a higher timestep -> process all arrays for current_timestep
if arr_timestep > current_timestep:
break
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
# inspect what callbacks can be called
for cb_cfg in self.registered_callbacks:
simulation_callback = cb_cfg.simulation_callback
description_arrays_needed = cb_cfg.arrays_description
exception_handler = cb_cfg.exception_handler
when = cb_cfg.when
should_call = self.should_call(description_arrays_needed, when)
if should_call:
# Compute the arrays to pass to the callback
callback_args: dict[str, List[DeisaArray]] = self.determine_callback_args(description_arrays_needed)
try:
simulation_callback(**callback_args)
                except (TimeoutError, AssertionError):
                    # propagate timeout and assertion failures unchanged
                    raise
except BaseException as e:
try:
exception_handler(e)
except BaseException as e:
_default_exception_handler(e)
del callback_args
gc.collect()
            # reset all new-timestep flags before fetching the next batch
            for array_name in self.has_new_timestep:
                self.has_new_timestep[array_name] = False
# add the first "bigger" timestep back into queue and set new_timestep flag
if not end_reached:
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
def determine_callback_args(self, description_of_arrays_needed) -> dict[str, List[DeisaArray]]:
"""
Build the kwargs passed to a simulation callback.
Parameters
----------
description_of_arrays_needed : Sequence[WindowSpec]
Array descriptions requested by the callback.
Returns
-------
dict[str, List[DeisaArray]]
Mapping from array name to the latest (windowed) list of ``DeisaArray`` instances.
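
        Examples
        --------
        Sketch of the windowing behaviour, assuming three timesteps are
        buffered for an illustrative array name and a window of two is
        requested::

            # queue_per_array["temperature"] holds entries for t=0, t=1, t=2
            spec = WindowSpec(name="temperature", window_size=2)  # hypothetical
            deisa.determine_callback_args([spec])
            # -> {"temperature": [<DeisaArray t=1>, <DeisaArray t=2>]}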
"""
callback_args = {}
for description in description_of_arrays_needed:
name = description.name
window_size = description.window_size
queue = self.queue_per_array[name]
if window_size is None:
callback_args[name] = [queue[-1]]
else:
callback_args[name] = list(queue)[-window_size:]
return callback_args
def should_call(self, description_of_arrays_needed, when: Literal["AND", "OR"]) -> bool:
"""
Determine whether a callback should execute for the current state.
Parameters
----------
description_of_arrays_needed : Sequence[WindowSpec]
Array descriptions governing the callback.
when : Literal["AND", "OR"]
Execution mode specifying whether all arrays or any array must have
new data.
Returns
-------
bool
``True`` when the callback criteria are met.
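
        Examples
        --------
        With two declared arrays (names are illustrative), ``"AND"`` requires
        fresh data for both, while ``"OR"`` fires once both have been seen at
        least once and either has fresh data::

            # specs = [WindowSpec(name="pressure"), WindowSpec(name="velocity")]
            deisa.has_seen_array.update({"pressure": True, "velocity": True})
            deisa.has_new_timestep.update({"pressure": True, "velocity": False})
            deisa.should_call(specs, when="AND")  # False
            deisa.should_call(specs, when="OR")   # True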
"""
names = [d.name for d in description_of_arrays_needed]
if when == "AND":
return all(self.has_new_timestep[n] for n in names)
else: # when == 'OR'
return all(self.has_seen_array[n] for n in names) and any(self.has_new_timestep[n] for n in names)
# TODO add persist
def set(self, *, key: Hashable, value: Any, chunked: bool = False, persist: bool = False) -> None:
"""
Broadcast a feedback value to all scheduling actors.
Parameters
----------
key : Hashable
Identifier for the shared value.
value : Any
Value to distribute.
chunked : bool, optional
Placeholder for future distributed-array feedback. Only ``False``
is supported today. Default is ``False``.
persist : bool, optional
Whether the value should survive the next retrieval.
Defaults to ``False``.
Notes
-----
The method lazily fetches node actors once and uses fire-and-forget
remote calls; callers should not assume synchronous delivery.
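
        Examples
        --------
        Broadcasting a scalar back to the simulation side (key and value are
        illustrative)::

            deisa.set(key="temperature_threshold", value=1.5e3)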
"""
# TODO test
if not self.node_actors:
# retrieve node actors at least once
self.node_actors = ray.get(self.head.list_scheduling_actors.remote())
if not chunked:
            for handle in self.node_actors.values():
# set the value inside each node actor
# TODO: does it need to be blocking?
handle.set.remote(key, value, chunked, persist)
else:
# TODO: implement chunked version
raise NotImplementedError()