Spaces:
Running
Running
__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", | |
"rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] | |
import collections | |
import contextlib | |
import functools | |
import inspect | |
import logging | |
import threading | |
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING | |
import torch | |
from torch.futures import Future | |
from torch._C._distributed_rpc import ( | |
PyRRef, | |
RemoteProfilerManager, | |
WorkerInfo, | |
TensorPipeAgent, | |
get_rpc_timeout, | |
_cleanup_python_rpc_handler, | |
_delete_all_user_and_unforked_owner_rrefs, | |
_destroy_rref_context, | |
_get_current_rpc_agent, | |
_invoke_remote_builtin, | |
_invoke_remote_python_udf, | |
_invoke_remote_torchscript, | |
_invoke_rpc_builtin, | |
_invoke_rpc_python_udf, | |
_invoke_rpc_torchscript, | |
_is_current_rpc_agent_set, | |
_reset_current_rpc_agent, | |
_set_and_start_rpc_agent, | |
) | |
from .internal import ( | |
PythonUDF, | |
RPCExecMode, | |
_internal_rpc_pickler, | |
_build_rpc_profiling_key, | |
) | |
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT | |
from ._utils import _group_membership_management, _update_group_membership | |
logger = logging.getLogger(__name__) | |
# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
# make sure there are no references to any RRef in the application code and
# that Python GC has done its job to delete those RRefs. This could result in
# a bad debugging experience, especially for large applications. Therefore, by
# default, we are going to ignore RRef leaks during shutdown. This is usually
# fine as shutdown means applications have done training and no longer care
# about states.
#
# To enable RRef leak checking, set this _ignore_rref_leak to False
_ignore_rref_leak = True | |
_default_pickler = _internal_rpc_pickler | |
@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    Temporarily replace the module-level pickler used to serialize Python UDFs.

    Meant to be used in a ``with`` statement. While the block is active the
    module global ``_default_pickler`` points at ``rpc_pickler``; on exit
    (normal or via an exception) it is reset to ``_internal_rpc_pickler``.
    Note that the reset target is always ``_internal_rpc_pickler``, not the
    value that was active on entry, so nested uses do not stack.

    Args:
        rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
    """
    global _default_pickler
    _default_pickler = rpc_pickler
    try:
        yield
    finally:
        # Always fall back to the built-in pickler, even if the body raised.
        _default_pickler = _internal_rpc_pickler
def _require_initialized(func): | |
def wrapper(*args, **kwargs): | |
if not _is_current_rpc_agent_set(): | |
raise RuntimeError( | |
"RPC has not been initialized. Call " | |
"torch.distributed.rpc.init_rpc first." | |
) | |
return func(*args, **kwargs) | |
return wrapper | |
class AllGatherStates:
    """Bookkeeping for a single ``_all_gather()`` round (one sequence id)."""

    def __init__(self):
        # Worker name -> object contributed by that worker; starts out empty.
        #
        # The leader is the first worker in the sorted worker-name list.
        # Every worker that enters `_all_gather()` runs `_gather_to_leader()`
        # on the leader, which records the worker's name and data obj here.
        # The leader records its own entry as well when it calls
        # `_all_gather()`.
        #
        # When `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the
        # leader broadcasts the complete dict to every follower, which then
        # overwrites its own `gathered_objects` and sets `proceed_signal`.
        self.gathered_objects = {}

        # Event every worker blocks on until the full set of gathered
        # objects is available locally.
        self.proceed_signal = threading.Event()
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer
# (it is assigned in `_init_rpc_states()` from the agent's worker infos).
_ALL_WORKER_NAMES: Set[Any] = set()
# Re-entrant lock held while reading or updating the two dicts below.
_all_gather_dict_lock = threading.RLock()
# Concatenated sorted worker names -> next sequence number to hand out for an
# `_all_gather()` round over that worker set.
_all_gather_sequence_id: Dict[str, int] = {}
# Sequence id -> per-round `AllGatherStates`. Being a defaultdict, an entry is
# created on first lookup, whichever side touches the sequence id first.
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
def _init_rpc_states(agent):
    r"""
    Record the names of all known workers and make sure ``agent`` is the
    current RPC agent.

    Args:
        agent: the RPC agent; must provide ``get_worker_infos()`` returning
            objects with a ``name`` attribute.
    """
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {info.name for info in agent.get_worker_infos()}

    # NB: backend implementation might have already set the rpc_agent.
    if _is_current_rpc_agent_set():
        return
    _set_and_start_rpc_agent(agent)
def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    r"""
    Record ``obj`` as ``worker_name``'s contribution to the gather round
    identified by ``sequence_id``; runs on the leader.

    Once contributions from every expected worker are recorded, the round's
    ``proceed_signal`` is set.

    Args:
        sequence_id (str): identifier of the gather round.
        worker_name (str): name of the contributing worker.
        obj: the object contributed by ``worker_name``.
        worker_names: the workers expected to contribute. When empty or
            ``None``, ``_ALL_WORKER_NAMES`` is used.
    """
    with _all_gather_dict_lock:
        expected_names = worker_names or _ALL_WORKER_NAMES
        assert (
            worker_name in expected_names
        ), f"{worker_name} is not expected by leader."

        round_states = _all_gather_sequence_id_to_states[sequence_id]
        collected = round_states.gathered_objects
        assert (
            worker_name not in collected
        ), f"{worker_name} reported intent sequence id {sequence_id} twice. "

        collected[worker_name] = obj
        # Everyone has reported in: release whoever waits on this round.
        if expected_names == set(collected.keys()):
            round_states.proceed_signal.set()
def _broadcast_to_followers(sequence_id, objects_map):
    r"""
    Install the leader's gathered result for round ``sequence_id`` on this
    worker and set the round's ``proceed_signal``.

    Args:
        sequence_id (str): identifier of the gather round.
        objects_map (dict): the gathered objects sent by the leader.
    """
    with _all_gather_dict_lock:
        round_states = _all_gather_sequence_id_to_states[sequence_id]
        signal = round_states.proceed_signal
        assert (
            not signal.is_set()
        ), f"Termination signal sequence id {sequence_id} got set twice."
        round_states.gathered_objects = objects_map
        signal.set()
_thread_local_var = threading.local() | |
def _wait_all(): | |
r""" | |
A context manager that collects all futures returned by ``rpc_async`` and | |
waits them on the context manager's exit; relieving the user of needing | |
to explicitly call wait. | |
Example:: | |
>>> # xdoctest: +SKIP("distributed") | |
>>> # On worker 0: | |
>>> import torch | |
>>> import torch.distributed.rpc as rpc | |
>>> rpc.init_rpc("worker0", rank=0, world_size=2) | |
>>> with rpc._wait_all(): | |
>>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) | |
>>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) | |
>>> #fut_1 and fut_2 are waited on | |
""" | |
_thread_local_var.future_list = [] | |
try: | |
yield | |
finally: | |
try: | |
torch.futures.wait_all(_thread_local_var.future_list) | |
finally: | |
del _thread_local_var.future_list | |
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but is using RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all, it will broadcast the results back to all followers. This
    function blocks until all workers have received the gathered results.

    Args:
        obj: this worker's contribution to the gather.
        worker_names: the workers taking part. When empty or ``None``,
            ``_ALL_WORKER_NAMES`` is used. The leader subtracts a set from
            it, so it must support set difference.
        timeout (float): ``UNSET_RPC_TIMEOUT`` uses the agent's RPC timeout
            and waits on the signal indefinitely;
            ``DEFAULT_SHUTDOWN_TIMEOUT`` is passed to the RPCs unchanged and
            also waits on the signal indefinitely; any other value is used
            for both the RPCs and the signal wait.

    Returns:
        The ``gathered_objects`` of this round's state: on the leader, the
        dict of worker name to contributed object; on followers, the map
        broadcast by the leader.

    Raises:
        RuntimeError: on the leader, if waiting on any follower's broadcast
            future raises ``RuntimeError``.
    """
    if not worker_names:
        assert (
            _ALL_WORKER_NAMES is not None
        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        worker_names = _ALL_WORKER_NAMES
    leader_name = min(worker_names)

    self_name = _get_current_rpc_agent().get_worker_info().name

    # Derive a round id from the participant set plus a per-set counter, so
    # repeated gathers over the same workers get distinct ids.
    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout is specified by agent for RPC calls
        rpc_timeout = get_rpc_timeout()
        # No timeout for signal
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC
        rpc_timeout = timeout
        # No timeout for signal
        signal_timeout = None
    else:
        # Signal and RPC timeout use the same timeout
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by function parameter or None (which is indefinite)
    # NOTE(review): the boolean returned by `wait()` is not checked, so a
    # timed-out wait falls through to the code below — confirm this is intended.
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcast gathered results to all followers
    # Leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout
            )
            worker_name_to_response_future_dict[follower_name] = fut

        # Wait on every follower before reporting, so one failure does not
        # hide the others.
        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up for the states using the sequence_id
    # (not reached when the leader raises above).
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects
def _barrier(worker_names):
    r"""
    Synchronizes local and remote RPC processes.

    Blocks until every RPC process listed in ``worker_names`` has reached
    this call. A ``RuntimeError`` raised by the underlying gather is logged
    and not re-raised.

    Args:
        worker_names (List[str]): The set of workers to synchronize.
    """
    try:
        participants = set(worker_names)
        _all_gather(None, participants)
    except RuntimeError as err:
        logger.error("Failed to complete barrier, got error %s", err)
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Block until all local and remote RPC processes reach this method and wait
    for all outstanding work to complete. Every RPC process must call this
    method before exit to perform a graceful shutdown. This should be used to
    terminate the RPC framework, and there is no guarantee that the RPC
    framework will work after this method returns.

    Args:
        timeout: timeout forwarded to ``_all_gather``.

    Raises:
        RuntimeError: re-raised, after being logged, when the underlying
            gather fails.
    """
    try:
        _all_gather(None, timeout=timeout)
    except RuntimeError as err:
        logger.error(
            "Failed to respond to 'Shutdown Proceed' in time, got error %s", err
        )
        raise err
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout: only used when ``graceful=True``; forwarded to the
                 all-worker synchronization step and to ``agent.join``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    # Local-only shutdown: no coordination with peers.
    if not graceful:
        _finalize_shutdown()
        return

    try:
        agent = _get_current_rpc_agent()
        is_dynamic_group = isinstance(agent, TensorPipeAgent) and not agent.is_static_group
        if is_dynamic_group:
            # This is a dynamic group so we need to grab the token for the operation
            my_worker_info = agent.get_worker_info()
            my_name = my_worker_info.name
            with _group_membership_management(agent.store, my_name, False):
                # Ask every other worker to update its group membership for
                # this worker before joining.
                for worker in agent.get_worker_infos():
                    if worker.name == my_name:
                        continue
                    rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
                agent.join(shutdown=True, timeout=timeout)
        else:
            _wait_all_workers(timeout)
            _delete_all_user_and_unforked_owner_rrefs()
            agent.join(shutdown=True, timeout=timeout)
    finally:
        # In case of errors, continue to complete the local shutdown.
        _finalize_shutdown()
def _finalize_shutdown():
    r"""
    Tear down the local RPC state.

    Destroys the RRef context, then (even if that raised) shuts down the
    current agent, cleans up the Python RPC handler, and resets the current
    RPC agent.
    """
    try:
        # This raises a `TORCH_CHECK()` exception on RRef leak detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # clean up python rpc handler in shutdown(), see comments in
        # PythonRpcHandler::cleanup(), call it in python API because the
        # cleanup() function has python dependency, it assumes python
        # interpreter exists.
        # No matter if RRef leak exception is raised, this clean-up code
        # must run to avoid destruction segfault in Python 3.5.
        #
        # future.wait() should not be called after shutdown().
        # pythonRpcHandler is cleaned up in shutdown(), after
        # shutdown(), python objects returned from rpc python call can not be
        # resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the id of the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    agent = _get_current_rpc_agent()
    if worker_name is None:
        return agent.get_worker_info()
    return agent.get_worker_info(worker_name)
def _to_worker_info(to):
    r"""
    Normalize ``to`` into a ``WorkerInfo``.

    A ``WorkerInfo`` is returned unchanged; a ``str`` or ``int`` is resolved
    through ``get_worker_info``.

    Raises:
        ValueError: if ``to`` is none of ``WorkerInfo``, ``str`` or ``int``.
    """
    if isinstance(to, WorkerInfo):
        return to
    if isinstance(to, (str, int)):
        return get_worker_info(to)
    raise ValueError(f"Cannot get WorkerInfo from name {to}")
def _rref_typeof_on_owner(rref, blocking: bool = True): | |
rref_type = type(rref.local_value()) | |
if blocking: | |
return rref_type | |
else: | |
# Wrap result into a completed Future. This is so that if blocking=`False` | |
# is specified, we return a future regardless of if this call is on user | |
# or owner. | |
future = Future[type]() | |
future.set_result(rref_type) | |
return future | |
def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
    r"""
    Ask the owner of ``rref`` for the type of the referenced value.

    Args:
        rref: the RRef whose owner is queried (via ``rref.owner()``).
        timeout (float): timeout passed to the ``rpc_async`` call.
        blocking (bool): when ``True`` wait for and return the result;
            when ``False`` return the future from ``rpc_async``.
    """
    type_future = rpc_async(
        rref.owner(),
        _rref_typeof_on_owner,
        args=(rref,),
        timeout=timeout
    )
    return type_future.wait() if blocking else type_future
# Type parameter for the generic `RRef` class below.
T = TypeVar("T")
# Alias of `Generic[T]`, used both as a base class and to obtain its
# metaclass in the fallback branch below.
GenericWithOneTypeVar = Generic[T]


if TYPE_CHECKING:
    # Static type checkers see a plain generic subclass.

    class RRef(PyRRef[T], Generic[T]):
        pass

else:
    try:
        # Combine the implementation class and the type class.
        class RRef(PyRRef, Generic[T]):
            pass

    except TypeError:
        # TypeError: metaclass conflict: the metaclass of a derived class
        # must be a (non-strict) subclass of the metaclasses of all its bases
        # Mypy doesn't understand __class__ (mypy bug #4177)
        # Resolve the conflict with a metaclass that derives from both
        # bases' metaclasses.
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
            pass

        # Combine the implementation class and the type class.
        # Types for classes expecting a certain generic parameter (mypy bug #7791)
        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
            pass
# Install docstrings from `PyRRef` to `RRef`. | |
# | |
# This is for the fact that pybind11 generates the parameter | |
# `self` as type `rpc.PyRRef`, so a `:inherited-members:` | |
# under `.. autoclass:: RRef` does not work. | |
# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. | |
# | |
def method_factory(method_name, docstring):
    r"""
    Build a forwarding method that carries ``docstring``.

    The returned function looks up ``method_name`` on ``super(RRef, self)``
    and calls it with the given arguments.

    Args:
        method_name (str): name of the parent-class method to forward to.
        docstring (str): docstring to attach to the forwarding method. When
            empty or ``None`` the method is left without a docstring.

    Returns:
        The forwarding function, suitable for ``setattr(RRef, method_name, ...)``.
    """

    def method(self, *args, **kwargs):
        return getattr(super(RRef, self), method_name)(*args, **kwargs)

    # `method` is defined without a docstring, so testing `method.__doc__`
    # would always be falsy and the docstring would never be attached. Gate on
    # the docstring that was actually supplied instead.
    if docstring:
        method.__doc__ = docstring
    return method
# Re-expose the members of `PyRRef` on `RRef` through forwarding methods
# built by `method_factory`, with `PyRRef` replaced by `RRef` in their
# docstrings. This runs at module level, so `method_name`, `method`,
# `docstring` and `new_method` remain module globals after the loop.
for method_name, method in inspect.getmembers(PyRRef):
    # Ignore magic methods, except "__str__".
    # (The check skips every name starting with an underscore.)
    if method_name.startswith("_") and method_name != "__str__":
        continue

    # Get pybind11 generated docstring.
    # It's like,
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
    Blocking call that copies the value of the RRef from the owner
    to the local node and returns it. If the current node is the
    owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."

    # Do surgery on pybind11 generated docstrings.
    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")

    # Attach user-facing RRef method with modified docstring.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.

        timeout (float, optional): timeout in seconds for this remote call. If the
                                   creation of this
                                   :class:`~torch.distributed.rpc.RRef` on worker
                                   ``to`` is not successfully processed on this
                                   worker within this timeout, then the next time
                                   there is an attempt to use the RRef (such as
                                   ``to_here()``), a timeout will be raised
                                   indicating this failure. A value of 0 indicates
                                   an infinite timeout, i.e. a timeout error will
                                   never be raised. If not provided, the default
                                   value set during initialization or with
                                   ``_set_rpc_timeout`` is used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis. This means that when remote calls initiated by
        ``remote`` fail, such as with a timeout error, we take a best-effort
        approach to error handling. This means that errors are handled and set
        on the resulting RRef on an asynchronous basis. If the RRef has not been
        used by the application before this handling (such as ``to_here`` or
        fork call), then future uses of the ``RRef`` will appropriately raise
        errors. However, it is possible that the user application will use the
        ``RRef`` before the errors are handled. In this case, errors may not be
        raised as they have not yet been handled.

    Example::

        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = _get_should_profile()

    profiler_ctx = _enable_rpc_profiler(
        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
    )
    with profiler_ctx as rf:
        call_args = args or ()
        call_kwargs = kwargs or {}

        # A function carrying `_wrapped_async_rpc_function` is sent with the
        # async-execution flag; if the wrapped target is a ScriptFunction, it
        # is what gets dispatched.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        if qualified_name is not None:
            # Builtin operator.
            rref = _invoke_remote_builtin(
                dst_worker_info, qualified_name, timeout, *call_args, **call_kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            # TorchScript function.
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *call_args,
                **call_kwargs,
            )
        else:
            # Python UDF: serialize with the currently configured pickler.
            pickled_python_udf, tensors = _default_pickler.serialize(
                PythonUDF(func, call_args, call_kwargs)
            )
            rref = _invoke_remote_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                timeout,
                is_async_exec
            )

        # attach profiling information
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            profiling_future = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(profiling_future)

    return rref
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Shared implementation behind ``rpc_sync`` and ``rpc_async``.

    Dispatches ``func`` to worker ``to`` as a builtin operator, a TorchScript
    function or a pickled Python UDF, and returns the resulting future.

    Args:
        to (str or WorkerInfo or int): destination worker.
        func (Callable): the function to run remotely.
        rpc_type: execution mode passed on to the RPC profiler.
        args (tuple): positional arguments for ``func``; ``None`` means ``()``.
        kwargs (dict): keyword arguments for ``func``; ``None`` means ``{}``.
        rpc_timeout (float): timeout handed to the underlying invoke call.

    Returns:
        The future for the call. When profiling is active this is the future
        returned by ``_call_end_callbacks_on_future``.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = _get_should_profile()

    profiler_ctx = _enable_rpc_profiler(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )
    with profiler_ctx as rf:
        call_args = args or ()
        call_kwargs = kwargs or {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        if qualified_name is not None:
            # Builtin operator: arguments are unpacked into the call.
            fut = _invoke_rpc_builtin(
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *call_args,
                **call_kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            # TorchScript function: arguments are passed as a tuple and a dict.
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                call_args,
                call_kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            # Python UDF: serialize with the currently configured pickler.
            pickled_python_udf, tensors = _default_pickler.serialize(
                PythonUDF(func, call_args, call_kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )

        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that is completed when the original future
            # completes and the profiling callbacks have been completed as well,
            # to guarantee that fut.wait() completes the profiling. This new
            # future will contain the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)

    return fut
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns the result of running ``func`` with ``args`` and ``kwargs``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    """
    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
    # Issue the call, then block on its future right away.
    pending = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    result = pending.wait()
    return result
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe. This method will immediately return a
    :class:`~torch.futures.Future` that can be awaited on.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.


    Returns:
        Returns a :class:`~torch.futures.Future` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
        object.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire. You
        need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``rpc_async`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned
        :class:`~torch.futures.Future` completes.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_async")
    pending = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)

    # The thread-local `future_list` only exists inside a `_wait_all()` block;
    # when present, record the future so it is waited on at block exit.
    collecting = hasattr(_thread_local_var, "future_list")
    if collecting:
        _thread_local_var.future_list.append(pending)

    return pending
def _get_should_profile(): | |
# Legacy profiler should be enabled. RPC profiling is not supported with | |
# Kineto profiler. | |
ActiveProfilerType = torch._C._profiler.ActiveProfilerType | |
return ( | |
torch.autograd._profiler_enabled() and | |
torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] | |
) | |
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): | |
ctx_manager = contextlib.nullcontext() | |
if should_profile: | |
# Create appropriate string representation based on type of func | |
# (builtin, script, python) | |
if qualified_name is None: | |
func_name = ( | |
torch._jit_internal._qualified_name(func) | |
if isinstance(func, torch.jit.ScriptFunction) | |
else func.__qualname__ | |
) | |
else: | |
func_name = qualified_name | |
# Build RPC profiling key. | |
rpc_profiling_key = _build_rpc_profiling_key( | |
rpc_type, | |
func_name, | |
get_worker_info().name, | |
dst_worker_info.name, | |
) | |
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) | |
# Mypy doesn't support re-def of a variable not in the same block (#1174) | |
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] | |
return ctx_manager | |