Source code for deisa.ray.window_handler

import gc
from typing import Any, Callable, Hashable, Optional

import dask
import dask.array as da
from ray.util.dask import ray_dask_get
import ray

from deisa.ray._scheduler import deisa_ray_get
from deisa.ray.head_node import HeadNodeActor
from deisa.ray.utils import get_head_actor_options
from deisa.ray.types import ActorID, RayActorHandle, WindowArrayDefinition, _CallbackConfig
from deisa.ray.config import config
import logging


@ray.remote(num_cpus=0, max_retries=0)
def _call_prepare_iteration(prepare_iteration: Callable, array: da.Array, timestep: int):
    """
    Call the prepare_iteration function with the given array and timestep.

    This is a Ray remote function that executes the prepare_iteration
    callback with a Dask array. It configures Dask to use the Deisa-Ray
    scheduler before calling the function.

    Parameters
    ----------
    prepare_iteration : Callable
        The function to call. Should accept a Dask array and a timestep
        keyword argument.
    array : da.Array
        The Dask array to pass to the prepare_iteration function.
    timestep : int
        The current timestep to pass to the prepare_iteration function.

    Returns
    -------
    Any
        The result of calling `prepare_iteration(array, timestep=timestep)`.

    Notes
    -----
    This function is executed as a Ray remote task with no CPU requirements
    and no retries. It configures Dask to use the Deisa-Ray scheduler before
    executing the prepare_iteration callback.
    """
    dask.config.set(scheduler=deisa_ray_get, shuffle="tasks")
    return prepare_iteration(array, timestep=timestep)
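

# A minimal sketch of a prepare_iteration callable (illustrative only; the name
# ``estimate_mean`` is an assumption, not part of this module). It receives the
# Dask array and the ``timestep`` keyword and may trigger computation under the
# Deisa-Ray scheduler configured above:
#
#     def estimate_mean(array, timestep):
#         return float(array.mean().compute())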


class Deisa:
    def __init__(
        self,
        *,
        ray_start: Optional[Callable[[], None]] = None,
        handshake: Optional[Callable[["Deisa"], None]] = None,
    ) -> None:
        # cheap constructor: no Ray side effects
        config.lock()
        self._experimental_distributed_scheduling_enabled = config.experimental_distributed_scheduling_enabled

        # Do NOT mutate global config here if you want cheap unit tests;
        # do it when connecting, or inject it similarly.
        self._ray_start = ray_start or self._ray_start_impl
        self._handshake = handshake or self._handshake_impl
        self._connected = False

        self.node_actors: dict[ActorID, RayActorHandle] = {}
        self.registered_callbacks: list[_CallbackConfig] = []

    def _handshake_impl(self, _: "Deisa") -> None:
        """
        Default handshake between the window handler (Deisa) and the
        simulation-side bridges. The handshake completes once all expected
        Ray node actors are connected.
        """
        # TODO finish and add this config option to Deisa
        self.total_nodes = 0
        from ray.util.state import list_actors

        expected_ray_actors = self.total_nodes
        connected_actors = 0
        while connected_actors < expected_ray_actors:
            connected_actors = 0
            for a in list_actors(filters=[("state", "=", "ALIVE")]):
                if a.get("ray_namespace") == "deisa_ray":
                    connected_actors += 1

    def _ensure_connected(self) -> None:
        """
        Ensure that the window handler is connected to the Ray cluster.

        Connects to Ray, creates the head actor, and waits until the handshake
        completes, which happens once all node actors have joined the cluster.
        Also selects the Dask-on-Ray scheduler depending on whether
        experimental distributed scheduling is enabled.
        """
        if self._connected:
            return

        # Side effects begin here (only once)
        self._ray_start()

        # configure dask here if it must reflect actual cluster runtime
        if self._experimental_distributed_scheduling_enabled:
            dask.config.set(scheduler=deisa_ray_get, shuffle="tasks")
        else:
            dask.config.set(scheduler=ray_dask_get, shuffle="tasks")

        # head is created
        self._create_head_actor()

        # readiness gate for the head actor - only return once it is alive
        ray.get(
            self.head.exchange_config.remote(
                {"experimental_distributed_scheduling_enabled": self._experimental_distributed_scheduling_enabled}
            )
        )

        self._handshake(self)
        self._connected = True

    def _create_head_actor(self) -> None:
        self.head = HeadNodeActor.options(**get_head_actor_options()).remote()

    def _ray_start_impl(self) -> None:
        if not ray.is_initialized():
            ray.init(address="auto", log_to_driver=False, logging_level=logging.ERROR)
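
    # A minimal construction sketch (illustrative only; the lambdas below are
    # assumptions, not part of this module). Injecting ``ray_start`` and
    # ``handshake`` keeps unit tests free of Ray side effects, since the
    # constructor itself never touches the cluster:
    #
    #     deisa = Deisa(ray_start=lambda: None, handshake=lambda _d: None)
    #     # No Ray calls happen until a callback is registered or executed.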

    def register_callback(
        self,
        simulation_callback: Callable,
        arrays_description: list[WindowArrayDefinition],
        *,
        max_iterations=1_000_000_000,
        prepare_iteration: Callable | None = None,
        preparation_advance: int = 3,
    ) -> None:
        """
        Register the analytics callback and array descriptions.

        Parameters
        ----------
        simulation_callback : Callable
            Function to run for each iteration; receives arrays as kwargs and
            ``timestep``.
        arrays_description : list[WindowArrayDefinition]
            Descriptions of arrays to stream to the callback (with optional
            sliding windows).
        max_iterations : int, optional
            Maximum iterations to execute. Default is a large sentinel.
        prepare_iteration : Callable or None, optional
            Optional preparatory callback run ``preparation_advance`` steps
            ahead. Receives the array and ``timestep``.
        preparation_advance : int, optional
            How many iterations ahead to prepare when ``prepare_iteration``
            is provided. Default is 3.
        """
        self._ensure_connected()  # connect + handshake before accepting callbacks
        cfg = _CallbackConfig(
            simulation_callback=simulation_callback,
            arrays_description=arrays_description,
            max_iterations=max_iterations,
            prepare_iteration=prepare_iteration,
            preparation_advance=preparation_advance,
        )
        self.registered_callbacks.append(cfg)
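
    # A registration sketch (illustrative; the array name "temperature", the
    # callback body, and the keyword arguments of WindowArrayDefinition are
    # assumptions, not guaranteed by this module). The callback receives each
    # registered array as a keyword argument named after the array, plus
    # ``timestep``:
    #
    #     def my_analytics(temperature, timestep):
    #         print(timestep, float(temperature.mean().compute()))
    #
    #     deisa.register_callback(
    #         my_analytics,
    #         [WindowArrayDefinition(name="temperature")],
    #     )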

    def unregister_callback(
        self,
        simulation_callback: Callable,
    ) -> None:
        """
        Unregister a previously registered simulation callback.

        Parameters
        ----------
        simulation_callback : Callable
            Callback to remove from the registry.

        Raises
        ------
        NotImplementedError
            Always, as the feature has not been implemented yet.
        """
        raise NotImplementedError("method not yet implemented.")

    # TODO: introduce a method that will generate the final array spec for each registered array

    def execute_callbacks(
        self,
    ) -> None:
        """
        Execute the registered simulation callback loop.

        Notes
        -----
        Supports a single registered callback at present. Manages array
        retrieval from the head actor, optional preparation tasks, windowed
        array delivery, and garbage collection between iterations.
        """
        self._ensure_connected()

        if not self.registered_callbacks:
            raise RuntimeError("Please register at least one callback before calling execute_callbacks()")
        if len(self.registered_callbacks) > 1:
            raise RuntimeError(
                "execute_callbacks currently supports exactly one registered "
                "callback. Multiple-callback execution will be implemented later."
            )

        cfg = self.registered_callbacks[0]
        simulation_callback = cfg.simulation_callback
        arrays_description = cfg.arrays_description
        max_iterations = cfg.max_iterations
        prepare_iteration = cfg.prepare_iteration
        preparation_advance = cfg.preparation_advance

        max_pending_arrays = 2 * len(arrays_description)

        # Convert the definitions to the type expected by the head node
        head_arrays_description = [(definition.name, definition.preprocess) for definition in arrays_description]
        # TODO maybe this goes in the register callbacks
        ray.get(self.head.register_arrays.remote(head_arrays_description, max_pending_arrays))

        arrays_by_iteration: dict[int, dict[str, da.Array]] = {}

        if prepare_iteration is not None:
            preparation_results: dict[int, ray.ObjectRef] = {}
            for timestep in range(min(preparation_advance, max_iterations)):
                # Get the next array from the head node
                array: da.Array = ray.get(self.head.get_preparation_array.remote(arrays_description[0].name, timestep))
                preparation_results[timestep] = _call_prepare_iteration.remote(prepare_iteration, array, timestep)

        for iteration in range(max_iterations):
            # Start preparing in advance
            if iteration + preparation_advance < max_iterations and prepare_iteration is not None:
                array = self.head.get_preparation_array.remote(
                    arrays_description[0].name, iteration + preparation_advance
                )
                preparation_results[iteration + preparation_advance] = _call_prepare_iteration.remote(
                    prepare_iteration, array, iteration + preparation_advance
                )

            # Get new arrays
            while len(arrays_by_iteration.get(iteration, {})) < len(arrays_description):
                name: str
                timestep: int
                array: da.Array
                name, timestep, array = ray.get(self.head.get_next_array.remote())
                if timestep not in arrays_by_iteration:
                    arrays_by_iteration[timestep] = {}
                assert name not in arrays_by_iteration[timestep]
                arrays_by_iteration[timestep][name] = array

            # Compute the arrays to pass to the callback
            all_arrays: dict[str, da.Array | list[da.Array]] = {}
            for description in arrays_description:
                if description.window_size is None:
                    all_arrays[description.name] = arrays_by_iteration[iteration][description.name]
                else:
                    all_arrays[description.name] = [
                        arrays_by_iteration[timestep][description.name]
                        for timestep in range(max(iteration - description.window_size + 1, 0), iteration + 1)
                    ]

            if prepare_iteration is not None:
                preparation_result = ray.get(preparation_results[iteration])
                simulation_callback(**all_arrays, timestep=timestep, preparation_result=preparation_result)
            else:
                simulation_callback(**all_arrays, timestep=timestep)

            del all_arrays

            # Remove the oldest arrays
            for description in arrays_description:
                older_timestep = iteration - (description.window_size or 1) + 1
                if older_timestep >= 0:
                    del arrays_by_iteration[older_timestep][description.name]
                    if not arrays_by_iteration[older_timestep]:
                        del arrays_by_iteration[older_timestep]

            # Free the memory used by the arrays now. Since an ObjectRef is a small object,
            # Python may otherwise choose to keep it in memory for some time, preventing the
            # actual data from being freed.
            gc.collect()
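
    # Callback delivery sketch (illustrative; the names ``temperature`` and
    # ``my_windowed_analytics`` are assumptions). For an array registered with a
    # ``window_size``, the callback receives a list of Dask arrays ordered
    # oldest to newest; when ``prepare_iteration`` is set it also receives
    # ``preparation_result``:
    #
    #     def my_windowed_analytics(temperature, timestep, preparation_result=None):
    #         # ``temperature`` is a list of da.Array when window_size is set,
    #         # otherwise a single da.Array for the current timestep.
    #         ...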

    # TODO add persist
    def set(self, *args, key: Hashable, value: Any, chunked: bool = False, **kwargs) -> None:
        """
        Broadcast a feedback value to all node actors.

        Parameters
        ----------
        key : Hashable
            Identifier for the shared value.
        value : Any
            Value to distribute.
        chunked : bool, optional
            Placeholder for future distributed-array feedback. Only ``False``
            is supported today. Default is ``False``.

        Notes
        -----
        The method lazily fetches node actors once and uses fire-and-forget
        remote calls; callers should not assume synchronous delivery.
        """
        # TODO test
        if not self.node_actors:
            # retrieve node actors at least once
            self.node_actors = ray.get(self.head.list_scheduling_actors.remote())

        if not chunked:
            for _, handle in self.node_actors.items():
                # set the value inside each node actor
                # TODO: does it need to be blocking?
                handle.set.remote(key, value, chunked)
        else:
            # TODO: implement chunked version
            raise NotImplementedError()
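
    # Feedback sketch (illustrative; the key name "target_temperature" is an
    # assumption). ``set`` is fire-and-forget, so the value may reach the node
    # actors slightly after this call returns:
    #
    #     deisa.set(key="target_temperature", value=300.0)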