from collections import deque, defaultdict
import gc
import logging
from typing import Any, Callable, Hashable, List, Literal, Optional
import dask
import ray
from ray.util.dask import ray_dask_get
from deisa.ray._scheduler import deisa_ray_get
from deisa.ray.config import config
from deisa.ray.errors import _default_exception_handler
from deisa.ray.head_node import HeadNodeActor
from deisa.ray.types import (
ActorID,
DeisaArray,
RayActorHandle,
Timestep,
WindowSpec,
_CallbackConfig,
)
from deisa.ray.utils import get_head_actor_options
def _ray_start_impl() -> None:
"""
Default Ray startup procedure used by :class:`Deisa`.
Notes
-----
Initializes Ray only once with minimal logging. Used when the caller
does not provide a custom ``ray_start`` hook.
"""
if not ray.is_initialized():
ray.init(address="auto", log_to_driver=False, logging_level=logging.ERROR)
class Deisa:
"""
Entry point that orchestrates analytics callbacks on Ray.
    Provides an API for registering sliding-window analytics callbacks and
    executing them as arrays arrive from simulation ranks.
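
    Examples
    --------
    A minimal driver sketch (assuming ``WindowSpec`` is constructed with
    ``name`` and ``window_size``; names and values are illustrative)::

        deisa = Deisa()

        @deisa.callback(WindowSpec(name="temperature", window_size=2))
        def analyze(temperature):
            # ``temperature`` is a list of DeisaArray; the newest entry is last
            latest = temperature[-1].dask

        deisa.execute_callbacks()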
"""
def __init__(
self,
*,
ray_start: Optional[Callable[[], None]] = None,
max_simulation_ahead: int = 1,
feedback_queue_size: int = 1024,
) -> None:
"""
Initialize handler state without touching Ray.
Parameters
----------
ray_start : Callable[[], None], optional
Custom callable used to start Ray. Defaults to a built-in helper.
        max_simulation_ahead : int, optional
            Maximum number of timesteps the simulation may run ahead of the
            analytics. Defaults to 1.
        feedback_queue_size : int, optional
            Capacity of the feedback queue held on the head actor
            (see :meth:`set`). Defaults to 1024.
        """
# cheap constructor: no Ray side effects
config.lock()
self._experimental_distributed_scheduling_enabled = config.experimental_distributed_scheduling_enabled
# Do NOT mutate global config here if you want cheap unit tests;
# do it when connecting, or inject it similarly.
self._ray_start = ray_start or _ray_start_impl
self._connected = False
self.node_actors: dict[ActorID, RayActorHandle] = {}
self.registered_callbacks: list[_CallbackConfig] = []
        self.queue_per_array: dict[str, deque] = {}
self.max_simulation_ahead: int = max_simulation_ahead
self.feedback_queue_size: int = feedback_queue_size
self.has_new_timestep: dict[str, bool] = defaultdict(bool)
self.has_seen_array: dict[str, bool] = defaultdict(bool)
def _ensure_connected(self) -> None:
"""
Ensure the handler is connected to Ray and has a head actor ready.
Notes
-----
Starts Ray (if needed), creates the head actor, and exchanges
configuration so that scheduling actors can register themselves.
"""
if self._connected:
return
# Side effects begin here (only once)
self._ray_start()
# head is created
self._create_head_actor()
        # readiness gate for the head actor: only return once it is alive
ray.get(
self.head.exchange_config.remote(
{"experimental_distributed_scheduling_enabled": self._experimental_distributed_scheduling_enabled}
)
)
self._connected = True
def _dask_config(self):
"""
Return the Dask scheduler config needed while executing analytics callbacks.
The scheduler is intentionally scoped by callers with Dask's config
context manager. Leaving it set globally in the driver process leaks
into unrelated Dask computations in later tests or user code.
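
        Examples
        --------
        Intended usage, scoped so the scheduler reverts on exit (the
        collection name is illustrative)::

            with self._dask_config():
                result = some_dask_collection.compute()  # hypothetical collection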
"""
if self._experimental_distributed_scheduling_enabled:
return dask.config.set(scheduler=deisa_ray_get, shuffle="tasks")
return dask.config.set(scheduler=ray_dask_get, shuffle="tasks")
def _create_head_actor(self) -> None:
"""
Instantiate the head actor that coordinates array delivery.
Notes
-----
Uses :func:`get_head_actor_options` to pin the actor to the Ray head node
with a detached lifetime so that analytics can connect later.
"""
self.head = HeadNodeActor.options(**get_head_actor_options()).remote(
max_simulation_ahead=self.max_simulation_ahead,
feedback_queue_size=self.feedback_queue_size,
)
def callback(
self,
*window_specs,
exception_handler: Optional[Callable] = None,
when: Literal["AND", "OR"] = "AND",
):
"""
Decorator that registers a sliding-window analytics callback.
Parameters
----------
*window_specs : WindowSpec
Array descriptions the callback should receive.
exception_handler : Optional[Callable], optional
Handler invoked when the user callback raises. Defaults to
:func:`deisa.ray.errors._default_exception_handler`.
        when : Literal["AND", "OR"], optional
            Governs when the callback runs. ``"AND"`` requires new data for
            every requested array at the current timestep; ``"OR"`` requires
            every array to have been seen at least once and new data for at
            least one of them. Defaults to ``"AND"``.
Returns
-------
        Callable
            Decorator that registers the wrapped function via
            :meth:`register_callback` and returns it unchanged.
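
        Examples
        --------
        A minimal sketch, assuming ``WindowSpec`` is constructed with ``name``
        and ``window_size`` and that ``deisa`` is a :class:`Deisa` instance::

            @deisa.callback(WindowSpec(name="temperature", window_size=2))
            def analyze(temperature):
                # receives the last two DeisaArray instances for "temperature"
                ...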
"""
def deco(fn):
return self.register_callback(fn, list(window_specs), exception_handler, when)
return deco
def register_callback(
self,
simulation_callback: Callable,
arrays_spec: list[WindowSpec],
exception_handler: Optional[Callable] = None,
when: Literal["AND", "OR"] = "AND",
) -> Callable:
"""
Register the analytics callback and array descriptions.
Parameters
----------
        simulation_callback : Callable
            Function to run for each iteration; receives one keyword argument
            per requested array, mapping the array name to a list of
            ``DeisaArray`` instances (each carrying its timestep).
        arrays_spec : list[WindowSpec]
            Descriptions of arrays to stream to the callback (with optional
            sliding windows).
        exception_handler : Optional[Callable]
            Handler invoked with any exception raised by the callback (such as
            a division by zero). Defaults to printing the error and moving on.
        when : Literal['AND', 'OR']
            Governs when a callback with multiple arrays is called.
            ``AND``: call only if ALL required arrays have new data for the
            current timestep.
            ``OR``: call if every array has been seen at least once and ANY
            array has new data for the current timestep.
Returns
-------
Callable
The original callback, allowing decorator-style usage.
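
        Examples
        --------
        Direct registration without the decorator, under the same assumptions
        as :meth:`callback`::

            def analyze(temperature):
                ...

            deisa.register_callback(analyze, [WindowSpec(name="temperature")])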
"""
self._ensure_connected() # connect + handshake before accepting callbacks
cfg = _CallbackConfig(
simulation_callback=simulation_callback,
arrays_description=arrays_spec,
exception_handler=exception_handler or _default_exception_handler,
when=when,
)
self.registered_callbacks.append(cfg)
return simulation_callback
def unregister_callback(
self,
simulation_callback: Callable,
) -> None:
"""
Unregister a previously registered simulation callback.
Parameters
----------
simulation_callback : Callable
Callback to remove from the registry.
Raises
------
NotImplementedError
Always, as the feature has not been implemented yet.
"""
raise NotImplementedError("method not yet implemented.")
def generate_queue_per_array(self):
"""
Prepare per-array queues that respect declared window sizes.
Notes
-----
Each queue is a :class:`collections.deque` with ``maxlen`` matching the
largest window requested for that array.
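
        Examples
        --------
        The windowing relies on ``deque(maxlen=...)`` evicting the oldest
        entry:

        >>> from collections import deque
        >>> q = deque(maxlen=2)
        >>> for x in (1, 2, 3):
        ...     q.append(x)
        >>> list(q)
        [2, 3]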
"""
for cb_cfg in self.registered_callbacks:
description = cb_cfg.arrays_description
for array_def in description:
name = array_def.name
window_size: int = array_def.window_size if array_def.window_size is not None else 1
                if name in self.queue_per_array:
                    # grow the queue if a later callback requests a larger window
                    if self.queue_per_array[name].maxlen < window_size:
                        self.queue_per_array[name] = deque(maxlen=window_size)
                else:
                    self.queue_per_array[name] = deque(maxlen=window_size)
def execute_callbacks(
self,
) -> None:
"""
Execute the registered simulation callback loop.
Notes
-----
        Runs every registered callback whose window criteria are met at each
        timestep. Manages array retrieval from the head actor, windowed array
        delivery, and garbage collection between iterations.
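
        Examples
        --------
        Typical driver flow, assuming callbacks were registered as in
        :meth:`register_callback`::

            deisa.execute_callbacks()  # blocks until the last-iteration array arrives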
"""
# ensure connected to ray cluster
self._ensure_connected()
with self._dask_config():
# signal analytics ready to start
ray.get(self.head.set_analytics_ready_for_execution.remote())
# ray.get(self.head.wait_for_bridges_ready.remote())
# TODO: test
            # raise an error and stop analytics if no callback is registered
if not self.registered_callbacks:
raise RuntimeError("Please register at least one callback before calling execute_callbacks()")
# generate one queue per array which cleanly handles the window size
self.generate_queue_per_array()
# get first array to kickstart the process
# - Add to queue, mark as new timestep arrived
name, arr_timestep, array = ray.get(self.head.get_next_array.remote())
if name == "__deisa_last_iteration_array":
return
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
end_reached = False
while not end_reached:
                # The inner loop stops once a larger timestep is pushed to the queue.
                # WARNING: we assume no array from timestep i+1 can arrive BEFORE
                # timestep i. This assumption is violated in embarrassingly parallel
                # workflows where each rank advances independently. Without it, it
                # would be much harder to decide when to determine which callbacks
                # should be called, making memory handling and flow execution
                # difficult to guarantee.
current_timestep = arr_timestep
while True:
name, arr_timestep, array = ray.get(self.head.get_next_array.remote())
# guarantee sequential flow of data.
# TODO add test
if arr_timestep < current_timestep:
raise RuntimeError(
f"Logical flow of data was violated. Timestep {arr_timestep} sent after timestep {current_timestep}. Exiting..."
)
if name == "__deisa_last_iteration_array":
end_reached = True
break
# simulation has produced a higher timestep -> process all arrays for current_timestep
if arr_timestep > current_timestep:
break
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
# inspect what callbacks can be called
for cb_cfg in self.registered_callbacks:
simulation_callback = cb_cfg.simulation_callback
description_arrays_needed = cb_cfg.arrays_description
exception_handler = cb_cfg.exception_handler
when = cb_cfg.when
should_call = self.should_call(description_arrays_needed, when)
if should_call:
# Compute the arrays to pass to the callback
callback_args: dict[str, List[DeisaArray]] = self.determine_callback_args(
description_arrays_needed
)
try:
simulation_callback(**callback_args)
                        except (TimeoutError, AssertionError):
                            # let timeout and test-assertion failures propagate
                            raise
                        except BaseException as e:
                            try:
                                exception_handler(e)
                            except BaseException as handler_error:
                                # the handler itself raised; fall back to the default
                                _default_exception_handler(handler_error)
del callback_args
gc.collect()
                # reset all new-timestep flags
                for array_name in self.has_new_timestep:
                    self.has_new_timestep[array_name] = False
# add the first "bigger" timestep back into queue and set new_timestep flag
if not end_reached:
queue = self.queue_per_array.get(name)
if queue is not None:
queue.append(DeisaArray(dask=array, t=arr_timestep))
self.has_new_timestep[name] = True
self.has_seen_array[name] = True
def determine_callback_args(self, description_of_arrays_needed) -> dict[str, List[DeisaArray]]:
"""
Build the kwargs passed to a simulation callback.
Parameters
----------
description_of_arrays_needed : Sequence[WindowSpec]
Array descriptions requested by the callback.
Returns
-------
dict[str, List[DeisaArray]]
Mapping from array name to the latest (windowed) list of ``DeisaArray`` instances.
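
        Examples
        --------
        The window slice keeps at most the ``window_size`` newest entries:

        >>> from collections import deque
        >>> list(deque([1, 2, 3, 4], maxlen=5))[-2:]
        [3, 4]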
"""
callback_args = {}
for description in description_of_arrays_needed:
name = description.name
window_size = description.window_size
queue = self.queue_per_array[name]
if window_size is None:
callback_args[name] = [queue[-1]]
else:
callback_args[name] = list(queue)[-window_size:]
return callback_args
def should_call(self, description_of_arrays_needed, when: Literal["AND", "OR"]) -> bool:
"""
Determine whether a callback should execute for the current state.
Parameters
----------
description_of_arrays_needed : Sequence[WindowSpec]
Array descriptions governing the callback.
        when : Literal["AND", "OR"]
            Execution mode: ``"AND"`` requires new data for every array;
            ``"OR"`` requires every array to have been seen at least once and
            new data for at least one.
Returns
-------
bool
``True`` when the callback criteria are met.
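
        Examples
        --------
        The underlying predicate, illustrated on plain flag dictionaries that
        mirror ``has_new_timestep``:

        >>> has_new = {"a": True, "b": False}
        >>> all(has_new[n] for n in ("a", "b"))  # "AND": every array has new data
        False
        >>> any(has_new[n] for n in ("a", "b"))  # "OR" also requires all arrays seen once
        True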
"""
names = [d.name for d in description_of_arrays_needed]
if when == "AND":
return all(self.has_new_timestep[n] for n in names)
else: # when == 'OR'
return all(self.has_seen_array[n] for n in names) and any(self.has_new_timestep[n] for n in names)
def set(
self,
key: Hashable,
*,
value: Any,
timestep: Timestep,
) -> None:
"""
Publish a feedback value for bridges.
Parameters
----------
key : Hashable
Identifier for the shared value.
value : Any
Value to store.
        timestep : Timestep
            Timestep associated with ``value``.
Notes
-----
Timestamped values are stored in a fixed-size queue on the head actor.
For a given key, timesteps must be strictly increasing; publishing the
same timestep twice or publishing an older timestep raises
:class:`ValueError`.
Bridges retrieve them collectively with
``bridge.get("foo", timestep=t)``.
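
        Examples
        --------
        Publish a scalar that bridges can later fetch for the same timestep
        (``"dt"``, ``0.01`` and ``t`` are illustrative)::

            deisa.set("dt", value=0.01, timestep=t)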
"""
self._ensure_connected()
ray.get(self.head.set_feedback.remote(key, timestep, value))