from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_constants


DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]

def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device


def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
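    # Build the forward map while tracking its inverse; a repeated target
    # device means the user-provided map is not 1-to-1.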
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map


def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))


class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and
            the value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List[int, str, or ``torch.device``], optional): All local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
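
    Example (a minimal sketch; the worker names, ranks, and single-GPU device
    map below are illustrative assumptions, not prescribed values):

        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=16,
        >>>     rpc_timeout=120,  # seconds
        >>>     # send tensors on this worker's cuda:0 to worker1's cuda:0
        >>>     device_maps={"worker1": {0: 0}},
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     rpc_backend_options=options,
        >>> )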

    """

    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
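        # Normalize user-provided device maps and device lists to
        # ``torch.device`` before passing them to the pybind base class.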
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Example:
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>>     device_maps={"worker1": {0: 1}}
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # sending the return value back, it will follow the inverse of
            >>> # the device map, and hence will be moved back to cuda:0 and
            >>> # cuda:1 on worker0.
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps

        if to in curr_device_maps:
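            # Only accept entries that agree with the existing map for this
            # callee: the merged result must remain a 1-to-1 mapping.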
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
                    )

        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.
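
        Example (a minimal sketch; assumes a host with at least two CUDA
        devices):

            >>> # xdoctest: +SKIP("distributed")
            >>> options = TensorPipeRpcBackendOptions(num_worker_threads=8)
            >>> # ints, strings, and torch.device values are all accepted
            >>> options.set_devices([0, "cuda:1"])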

        """
        self.devices = _to_device_list(devices)