Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
from abc import ABC, abstractmethod | |
from typing import Any, Callable, Dict, Optional, Tuple | |
from torch.distributed import Store | |
class RendezvousError(Exception):
    """Base class from which every rendezvous error derives."""
class RendezvousClosedError(RendezvousError):
    """Signal that the rendezvous is closed."""
class RendezvousTimeoutError(RendezvousError):
    """Signal that the rendezvous did not complete on time."""
class RendezvousConnectionError(RendezvousError):
    """Signal that the connection to a rendezvous backend has failed."""
class RendezvousStateError(RendezvousError):
    """Signal that the state of a rendezvous is corrupt."""
class RendezvousGracefulExitError(RendezvousError):
    """Raised when a node was not included in the rendezvous and exits gracefully.

    The exception is only a mechanism to exit the stack; it does not mean
    that a failure occurred.
    """
class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Every method is abstract: the class cannot be instantiated directly and a
    concrete handler must implement all of them.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    @abstractmethod
    def get_backend(self) -> str:
        """Return the name of the rendezvous backend."""

    @abstractmethod
    def next_rendezvous(
        self,
    ) -> Tuple[Store, int, int]:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            A tuple of :py:class:`torch.distributed.Store`, ``rank``, and
            ``world size``.

        Raises:
            RendezvousClosedError:
                The rendezvous is closed.
            RendezvousConnectionError:
                The connection to the rendezvous backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
            RendezvousTimeoutError:
                The rendezvous did not complete on time.
        """

    @abstractmethod
    def is_closed(self) -> bool:
        """Check whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.

        ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
        propagation and should not be used for synchronization. The intention is
        that if at least one node decides the job is finished, it will close the
        rendezvous, and other nodes will soon observe this and stop running as
        well.
        """

    @abstractmethod
    def set_closed(self):
        """Mark the rendezvous as closed."""

    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Return the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Return the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    @abstractmethod
    def shutdown(self) -> bool:
        """Close all resources that were open for the rendezvous.

        Example::

            rdzv_handler = ...
            try:
                store, rank, world_size = rdzv_handler.next_rendezvous()
            finally:
                rdzv_handler.shutdown()
        """
class RendezvousParameters:
    """Hold the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        local_addr:
            The address of the local node.
        **kwargs:
            Additional parameters for the specified backend.

    Raises:
        ValueError:
            ``backend`` is empty, ``min_nodes`` is less than 1, or
            ``max_nodes`` is less than ``min_nodes``.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        local_addr: Optional[str] = None,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        # Backend-specific options are kept as-is and read through the
        # ``get*`` accessors below.
        self.config = kwargs
        self.local_addr = local_addr

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key`` if ``key`` exists, else ``default``."""
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """Return the value for ``key`` as a ``bool``.

        ``None`` and ``bool`` values are returned unchanged. The integers
        ``1``/``0`` and the case-insensitive strings ``1``/``true``/``t``/
        ``yes``/``y`` and ``0``/``false``/``f``/``no``/``n`` are converted.

        Raises:
            ValueError:
                The value is none of the accepted representations.
        """
        value = self.get(key, default)
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        elif isinstance(value, str):
            if value.lower() in ["1", "true", "t", "yes", "y"]:
                return True
            if value.lower() in ["0", "false", "f", "no", "n"]:
                return False
        raise ValueError(
            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """Return the value for ``key`` as an ``int``.

        ``None`` is returned unchanged; any other value is passed to ``int()``.

        Raises:
            ValueError:
                The value cannot be converted to an integer.
        """
        value = self.get(key, default)
        if value is None:
            return value
        try:
            return int(value)
        # ``int()`` raises TypeError (not ValueError) for unsupported types
        # such as lists; report both the same way ``get_as_bool`` reports an
        # invalid value.
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"The rendezvous configuration option '{key}' does not represent a valid integer "
                "value."
            ) from e
# Type of a backend factory: a callable that receives the
# :py:class:`RendezvousParameters` of a rendezvous and returns a
# :py:class:`RendezvousHandler`.
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
    """Keep track of the :py:class:`RendezvousHandler` backends by name."""

    # Maps a backend name to the callable that builds its handler.
    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Associate ``creator`` with the backend called ``backend``.

        Registering the same creator again for a backend is a no-op.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.

        Raises:
            ValueError:
                ``backend`` is empty, or the backend is already registered
                with a different creator.
        """
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        # ``None`` means that nothing has been registered under this name yet.
        existing: Optional[RendezvousHandlerCreator] = self._registry.get(backend)
        conflicting = existing is not None and existing != creator
        if conflicting:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
                f"is already registered with '{existing}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Build the handler of the backend named by ``params.backend``.

        Raises:
            ValueError:
                No creator is registered for ``params.backend``.
            RuntimeError:
                The created handler reports a backend name other than
                ``params.backend``.
        """
        try:
            factory = self._registry[params.backend]
        except KeyError as err:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                f"to call `{self.register.__name__}`?"
            ) from err

        handler = factory(params)

        # Sanity check: the handler must belong to the backend that was asked for.
        if handler.get_backend() == params.backend:
            return handler

        raise RuntimeError(
            f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
            f"backend '{params.backend}'."
        )
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers. It is created empty; backends are added to it through
# its `register` method.
rendezvous_handler_registry = RendezvousHandlerRegistry()