from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase

from . import constants as rpc_constants

DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]

def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device
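
# A small illustration (comments only, nothing executed at import time):
# _to_device("cuda:0") and _to_device(0) both return
# torch.device(type='cuda', index=0), while a non-CUDA device such as
# _to_device("cpu") raises ValueError, since the RPC agent only tracks
# CUDA devices here.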
def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map
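
# A small illustration (comments only): _to_device_map({0: 1, "cuda:1": "cuda:2"})
# returns {cuda:0: cuda:1, cuda:1: cuda:2} with all keys and values normalized
# to torch.device, whereas _to_device_map({0: 1, 1: 1}) raises ValueError
# because two source devices may not map to the same target device.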
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))

class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and
            the value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List of int, str, or ``torch.device``, optional): all local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            the corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
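
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # A minimal sketch; the worker names, world size, and device
        >>> # indices below are illustrative, not prescribed values.
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=8,
        >>>     rpc_timeout=20,  # seconds
        >>>     device_maps={"worker1": {0: 1}},
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     backend=rpc.BackendType.TENSORPIPE,
        >>>     rpc_backend_options=options,
        >>> )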
"""
    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set the device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     device_maps={"worker1": {0: 1}}
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # sending the return value back, it will follow the inverse of
            >>> # the device map, and hence will be moved back to cuda:0 and
            >>> # cuda:1 on worker0.
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps
        if to in curr_device_maps:
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
                    )
        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.
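
        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> # A minimal illustration; the device indices are hypothetical.
            >>> options = TensorPipeRpcBackendOptions(num_worker_threads=8)
            >>> options.set_devices([0, "cuda:1", torch.device("cuda:2")])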
"""
self.devices = _to_device_list(devices)