Spaces:
Running
Running
import warnings | |
from abc import ABC, abstractmethod | |
from types import TracebackType | |
from typing import Any, List, NamedTuple, Optional, Type | |
import torch | |
import torch.distributed as dist | |
__all__ = ['JoinHook', 'Joinable', 'Join'] | |
class JoinHook:
    r"""
    Pair of callbacks that the join context manager invokes on a joined process.

    The join context manager exposes two entry points: the main hook, which
    is run over and over for as long as some process has not joined yet, and
    the post-hook, which is run a single time after every process has joined.

    To provide a hook for the generic join context manager, subclass
    :class:`JoinHook` and override ``main_hook()`` and ``post_hook()`` as
    needed; the implementations here do nothing.
    """

    def main_hook(self) -> None:
        r"""Shadow the collective communications of one training iteration.

        The join context manager calls this repeatedly while at least one
        process has not joined. One training iteration means one forward
        pass, backward pass, and optimizer step.
        """
        return None

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Run once after all processes have joined.

        Arguments:
            is_last_joiner (bool): ``True`` if this rank is one of the last
                to join; ``False`` otherwise.
        """
        return None
class Joinable(ABC):
    r"""
    This defines an abstract base class for joinable classes.

    A joinable class (inheriting from :class:`Joinable`) should implement
    :meth:`join_hook`, which returns a :class:`JoinHook` instance, in addition
    to the :attr:`join_device` and :attr:`join_process_group` properties,
    which return device and process group information, respectively.

    Subclasses must call this constructor so that ``_join_config`` exists;
    it starts out as a disabled configuration and is overwritten by the join
    context manager when the instance is passed to it.
    """

    def __init__(self):
        super().__init__()
        # Join-related logic stays disabled until a join context manager
        # installs its own configuration on this instance.
        self._join_config = _JoinConfig.construct_disabled_join_config()

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a :class:`JoinHook` instance for the given :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...

    # ``join_device`` and ``join_process_group`` are read as plain attributes
    # by the join context manager, so they are exposed as properties.
    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""Return the device from which to perform collective communications needed by the join context manager."""
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""Return the process group for the collective communications needed by the join context manager itself."""
        ...
class _JoinConfig(NamedTuple):
    r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""

    # Whether join-related logic is active for the owning joinable.
    enable: bool
    # Whether detecting uneven inputs should raise instead of being shadowed.
    throw_on_early_termination: bool
    # Whether the owning joinable is the first one given to the context
    # manager (the one responsible for the notification collectives).
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config() -> "_JoinConfig":
        r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.

        e.g. if the caller is not in a join context manager.

        Declared as a static method so that it can be called through either
        the class or an instance.
        """
        return _JoinConfig(
            enable=False,
            throw_on_early_termination=False,
            is_first_joinable=False,
        )
class Join:
    r"""
    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

    These hooks should shadow the collective communications of non-joined
    processes to prevent hanging and erroring and to ensure algorithmic
    correctness. Refer to :class:`JoinHook` for details about the hook
    definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own per-
        iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that all ``join_process_group``
        attributes of the :class:`Joinable` objects are the same. If there
        are multiple :class:`Joinable` objects, then the ``join_device`` of
        the first is used. The process group and device information is used
        for checking for non-joined processes and for notifying processes to
        throw an exception if ``throw_on_early_termination`` is enabled, both
        of which use an all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.
        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).
        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).
        kwargs: keyword arguments forwarded unchanged to the ``join_hook()``
            method of every joinable.

    Raises:
        ValueError
            If ``joinables`` is empty, or if the joinables report conflicting
            process groups.

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>>     # All ranks reach here without hanging/erroring
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if not joinables:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        # Every joinable receives the same keyword arguments.
        self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        for index, joinable in enumerate(self._joinables):
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                # Only the first joinable issues the notification collectives.
                is_first_joinable=(index == 0),
            )

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError("Using join context manager with multiple process groups")
            # Keep the first device that is not ``None``.
            if device is None:
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        r"""Enter the context; no setup is performed and ``None`` is returned."""
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType]
    ):
        r"""
        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.

        Nothing is run if the context manager is disabled or if the body of
        the ``with`` statement raised; in the latter case the exception
        propagates unchanged.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and at least one other
                process has not joined yet.
        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised

        all_procs_joined = False
        is_last_joiner = True

        i = 0
        WARN_THRESHOLD = 1000
        # NOTE: this alters the process-wide warnings filter, not only the
        # warning emitted below.
        warnings.simplefilter("once")

        while not all_procs_joined:
            if i > WARN_THRESHOLD:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    f"fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    # Raises ``RuntimeError`` on this rank after notifying
                    # the others.
                    self._notify_procs_to_terminate()

                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()

                # Another process was still active, so this rank is not
                # among the last to join.
                is_last_joiner = False
                i += 1

        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
        # A joined process contributes zero; non-joined processes contribute
        # one each from ``notify_join_context()``.
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
        """
        # Pairs with the all-reduce of zeros in ``notify_join_context()``: a
        # nonzero result tells the non-joined processes to throw.
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notifies the join context manager that the calling process has not yet joined.

        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
        (i.e. if one process has already joined) and throws an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and some rank has
                already exhausted its inputs.
        """
        assert hasattr(joinable, "_join_config"), \
            f"Check that the {type(joinable)} constructor calls the " \
            "``Joinable`` constructor"

        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            # Nonzero only if a joined rank contributed ones to this
            # all-reduce.
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work