Source code for deisa.ray.comm
import datetime
from typing import Protocol

import torch.distributed as dist

# TODO: Add a test for comm size > declared world size
def init_gloo_comm(
    world_size: int, rank: int, master_addr: str = "127.0.0.1", master_port: int = 29500, timeout_s: int = 120
) -> "TorchDistComm":
    """
    Set up a Gloo communicator backed by a TCP store.

    Parameters
    ----------
    world_size : int
        Number of ranks participating in the communicator.
    rank : int
        Rank ID of the current process.
    master_addr : str, optional
        Hostname or IP address of the master rendezvous node. Defaults to
        ``"127.0.0.1"``.
    master_port : int, optional
        Port of the master rendezvous node. Defaults to 29500.
    timeout_s : int, optional
        Timeout (seconds) for rendezvous setup. Defaults to 120.

    Returns
    -------
    TorchDistComm
        Wrapper around the initialized PyTorch process group.
    """
    timeout = datetime.timedelta(seconds=timeout_s)
    # Rank 0 hosts the rendezvous store; everyone else connects to it.
    store = dist.TCPStore(
        host_name=master_addr,
        port=master_port,
        world_size=world_size,
        is_master=(rank == 0),
        timeout=timeout,
        wait_for_workers=True,  # optional; OK to leave at its default (True)
    )
    dist.init_process_group(
        backend="gloo",
        store=store,
        world_size=world_size,
        rank=rank,
        timeout=timeout,
    )
    return TorchDistComm(rank=rank, world_size=world_size)
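

# Illustrative usage sketch (assumption, not part of the library): a launcher starts
# one process per rank and exports RANK, WORLD_SIZE, and MASTER_ADDR; rank 0 must be
# reachable at master_addr:master_port before the timeout expires.
def _demo_gloo_rendezvous() -> None:  # hypothetical helper, illustrative only
    import os

    comm = init_gloo_comm(
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        rank=int(os.environ.get("RANK", "0")),
        master_addr=os.environ.get("MASTER_ADDR", "127.0.0.1"),
    )
    comm.barrier()  # every rank blocks here until all ranks have arrived
    dist.destroy_process_group()  # tear down the Gloo process group when done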
class Comm(Protocol):
    """Structural protocol that all communicator implementations satisfy."""

    rank: int
    world_size: int

    def barrier(self) -> None:
        """Block until all ranks reach this barrier."""
        ...
class MPICommAdapter:
    """Adapter exposing an MPI communicator via the shared Comm protocol."""

    def __init__(self, comm):
        """
        Wrap an MPI communicator.

        Parameters
        ----------
        comm : mpi4py.MPI.Comm
            MPI communicator to wrap.
        """
        self._comm = comm
        self.rank = comm.Get_rank()
        self.world_size = comm.Get_size()

    def barrier(self) -> None:
        """Block until all MPI ranks reach this barrier."""
        self._comm.Barrier()
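

# Illustrative sketch (assumes mpi4py is installed; not part of the module):
# wrapping MPI.COMM_WORLD gives MPI-launched code the same Comm interface.
def _demo_mpi_adapter() -> None:  # hypothetical helper, illustrative only
    from mpi4py import MPI  # local import so the module itself does not require mpi4py

    comm = MPICommAdapter(MPI.COMM_WORLD)
    print(f"rank {comm.rank} of {comm.world_size}")
    comm.barrier()  # collective call: every MPI rank must reach this point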
class TorchDistComm:
    """Torch distributed communicator implementing the Comm protocol."""

    def __init__(self, *, rank: int, world_size: int):
        """
        Initialize metadata for a torch distributed communicator.

        Parameters
        ----------
        rank : int
            Rank of the current process.
        world_size : int
            Total number of ranks in the communicator.
        """
        self.rank = rank
        self.world_size = world_size

    def barrier(self) -> None:
        """Block until all Torch distributed ranks reach this barrier."""
        dist.barrier()
class NoOpComm:
    """Fallback communicator that no-ops synchronization calls."""

    def __init__(self, rank: int = 0, world_size: int = 1):
        """
        Create a dummy communicator for single-process environments.

        Parameters
        ----------
        rank : int, optional
            Rank to report. Defaults to 0.
        world_size : int, optional
            World size to report. Defaults to 1.
        """
        self.rank = rank
        self.world_size = world_size

    def barrier(self) -> None:
        """No-op barrier for single-process setups."""
        return
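

# Illustrative sketch (assumption, not part of the module): NoOpComm lets code that
# expects a Comm run unchanged in a single process, e.g. in unit tests or when no
# distributed backend is configured.
def _demo_noop_fallback() -> None:  # hypothetical helper, illustrative only
    comm: Comm = NoOpComm()
    assert (comm.rank, comm.world_size) == (0, 1)
    comm.barrier()  # returns immediately; there is nothing to wait for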