import os
import sys
from enum import Enum
import pdb
import io

import torch


def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise, ``torch.distributed`` does not expose any other APIs.
    Currently, ``torch.distributed`` is available on Linux, MacOS and Windows.
    Set ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    and ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")
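

# Illustrative usage sketch (not part of this module's code path): downstream code
# is expected to guard collective setup on availability, e.g.
#
#     import torch.distributed as dist
#
#     if dist.is_available():
#         dist.init_process_group(backend="gloo")
#
# where ``init_process_group`` is re-exported below via ``from .distributed_c10d import *``.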


if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError
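
# Illustrative sketch (the exception names are the real re-exports above; the handler
# bodies are placeholders): callers can catch these runtime errors ahead of a plain
# RuntimeError, e.g.
#
#     import torch.distributed as dist
#
#     try:
#         dist.init_process_group(backend="nccl")
#     except dist.DistNetworkError:
#         ...  # connection / rendezvous failure, possibly retryable
#     except dist.DistBackendError:
#         ...  # backend-specific (e.g. NCCL) failure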

if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup as ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
    )

    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
        _DistributedPdb().set_trace()
        """

        def interaction(self, *args, **kwargs):
            # pdb needs a readable stdin; swap it in for the duration of the
            # interaction and restore the original afterwards.
            _stdin = sys.stdin
            try:
                sys.stdin = open('/dev/stdin')
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin

    def breakpoint(rank: int = 0):
        """
        Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
        done with the breakpoint before continuing.

        Args:
            rank (int): Which rank to break on. Default: ``0``
        """
        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()
        barrier()
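
    # Illustrative sketch (not executed here): in a script launched with torchrun,
    # this is typically called as
    #
    #     import torch.distributed as dist
    #
    #     dist.init_process_group(backend="gloo")
    #     ...
    #     dist.breakpoint(rank=0)  # rank 0 drops into pdb; other ranks block in barrier()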

    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )

    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import (
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _coalescing_manager,
        _CoalescingManager,
        _get_process_group_name,
    )

    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )
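
    # Illustrative sketch of the rendezvous API re-exported above (the URL and values
    # are placeholders): it yields ``(store, rank, world_size)`` triples for an init
    # URL, which is how ``init_process_group`` obtains its store internally, e.g.
    #
    #     store, rank, world_size = next(rendezvous(
    #         "tcp://localhost:29500", rank=0, world_size=1
    #     ))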

    from .remote_device import _remote_device

    # Pick up the c10d debug level (OFF / INFO / DETAIL) from the
    # TORCH_DISTRIBUTED_DEBUG environment variable, if it is set.
    set_debug_level_from_env()

else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre

    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]