from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase

from . import constants as rpc_constants

DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]


def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device


def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    # Normalize keys and values to torch.device and verify that the mapping
    # is one-to-one (no two source devices may map to the same target device).
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map


def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))


class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and
            the value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List of int, str, or ``torch.device``, optional): all local
            CUDA devices used by the RPC agent. By default, it is initialized
            to all local devices from its own ``device_maps`` and the
            corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
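
    Example:
        A minimal sketch, assuming two workers named ``worker0`` and
        ``worker1`` with at least two CUDA devices each; the worker names
        and device ids are illustrative only.

        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=16,
        >>>     rpc_timeout=120,  # seconds
        >>>     device_maps={"worker1": {0: 1}},  # caller cuda:0 -> callee cuda:1
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     rpc_backend_options=options,
        >>> )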
""" | |

    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Example:
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     device_maps={"worker1": {0: 1}}
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # sending the return value back, it will follow the inverse of
            >>> # the device map, and hence will be moved back to cuda:0 and
            >>> # cuda:1 on worker0
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps

        if to in curr_device_maps:
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
                    )

        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.
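
        Example:
            A minimal sketch, assuming this process has at least two local
            CUDA devices; the worker name ``worker1`` and the device ids are
            illustrative only.

            >>> # xdoctest: +SKIP("distributed")
            >>> options = TensorPipeRpcBackendOptions(num_worker_threads=8)
            >>> options.set_device_map("worker1", {0: 1})
            >>> # declare the local devices the agent should manage
            >>> options.set_devices(["cuda:0", torch.device("cuda:1")])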
""" | |
self.devices = _to_device_list(devices) | |