# File size: 4,009 Bytes
# c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import sys
from enum import Enum
import pdb
import io

import torch

def is_available() -> bool:
    """
    Report whether the ``torch.distributed`` package can be used.

    Returns ``True`` if the distributed package is available. When it is not,
    ``torch.distributed`` exposes no other APIs. The package is currently
    available on Linux, MacOS and Windows; to enable it when building PyTorch
    from source, set ``USE_DISTRIBUTED=1``. By default ``USE_DISTRIBUTED=1``
    on Linux and Windows and ``USE_DISTRIBUTED=0`` on MacOS.
    """
    c_extension = torch._C
    # Availability is detected by whether the compiled extension exposes
    # the ``_c10d_init`` entry point.
    return hasattr(c_extension, "_c10d_init")


if is_available():
    # ``_c10d_init`` reports success with a truthy value; a falsy result means
    # the distributed backend could not be initialized, which is fatal here.
    if not torch._C._c10d_init():
        raise RuntimeError("Failed to initialize torch.distributed")

# Custom runtime errors thrown from the distributed package, looked up on
# ``torch._C`` and re-exported here under public names. These aliases sit
# outside the ``if is_available():`` branch below, so ``torch._C`` is expected
# to provide them even in builds without distributed support.
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup as ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
    )

    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
            _DistributedPdb().set_trace()
        """

        def interaction(self, *args, **kwargs):
            """
            Run the standard pdb prompt with ``sys.stdin`` reopened from ``/dev/stdin``.

            A multiprocessing child does not necessarily have a usable
            ``sys.stdin`` (the Python multiprocessing docs describe it being
            replaced with ``os.devnull``), so the prompt reads from
            ``/dev/stdin`` instead. The reopened handle is closed when the
            prompt returns, and the previous ``sys.stdin`` is always restored,
            even if the prompt raises.

            Args:
                *args, **kwargs: forwarded unchanged to ``pdb.Pdb.interaction``.

            Raises:
                OSError: if ``/dev/stdin`` cannot be opened; ``sys.stdin`` is
                    left unchanged in that case.
            """
            saved_stdin = sys.stdin
            # pdb calls interaction() at every stop, so the handle must be
            # closed each time or one file descriptor leaks per prompt.
            with open("/dev/stdin") as tty_stdin:
                sys.stdin = tty_stdin
                try:
                    super().interaction(*args, **kwargs)
                finally:
                    # Restore before ``with`` closes tty_stdin so sys.stdin is
                    # never left pointing at a closed file.
                    sys.stdin = saved_stdin

    def breakpoint(rank: int = 0):
        """
        Pause in the debugger on exactly one rank.

        Only the process whose rank equals ``rank`` enters the debugger. Every
        rank, including that one, then calls ``barrier()``, so the other ranks
        wait until the debugging session on ``rank`` is finished before
        continuing.

        Args:
            rank (int): Which rank to break on.  Default: ``0``
        """
        is_target_rank = get_rank() == rank
        if is_target_rank:
            debugger = _DistributedPdb()
            debugger.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            # set_trace() with no frame starts tracing in this function's own
            # frame, hence the hint above to step 'up' into the caller.
            debugger.set_trace()
        # All ranks meet here; non-target ranks block until the target rank
        # leaves the debugger.
        barrier()

    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )

    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.

    from .distributed_c10d import (
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _coalescing_manager,
        _CoalescingManager,
        _get_process_group_name,
    )

    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )

    from .remote_device import _remote_device

    # Runs once at import time. Judging by its name, this sets the c10d debug
    # level from environment variables; the exact variable it reads is
    # defined by the C++ binding (NOTE(review): confirm there).
    set_debug_level_from_env()

else:
    # Minimal fallback used when the distributed package is unavailable
    # (USE_DISTRIBUTED=0). It is enough to make
    #   python test/test_public_bindings.py -k test_correct_module_names
    # pass in that configuration; add further stubs here as they are needed.
    # The stub is attached to the module object at runtime rather than being
    # defined under its public name directly, because direct stub
    # definitions confuse pyre.
    class _ProcessGroupStub:
        pass

    setattr(sys.modules["torch.distributed"], "ProcessGroup", _ProcessGroupStub)