import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef

from .utils import functional_optim_map

__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithread concurrent compiling.
    # request_callback might invoke concurrent compiling, so we
    # serialize the compiling with a lock
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply the functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]
        self.optim.step(grads)
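
# Illustrative note (sketch, not part of the original API surface): the
# functional optimizer taken from `functional_optim_map` consumes gradients
# explicitly instead of reading them from `param.grad`, roughly:
#
#   grads = [g0, None, g2]        # aligned with self._local_params; None = no grad
#   functional_optim.step(grads)
#
# which is what allows the whole update to be compiled with TorchScript and run
# without holding the Python GIL.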


# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
# we have converted all to functional optimizer in distributed.optim
class _LocalOptimizer:
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer but
    # they will all optimize the same parameters on each worker)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()
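
# Illustrative note: the eager path above mirrors ordinary torch.optim usage --
# gradients recovered from the distributed autograd context are attached to
# `param.grad` and a regular `optimizer.step()` runs under `global_lock`. Unlike
# the TorchScript path, this update executes in Python and so holds the GIL.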


def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)

    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)


def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


def _wait_for_all(rpc_futs):
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results
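
# Minimal sketch of `_wait_for_all` semantics (illustrative only, using plain
# `torch.futures.Future` objects in place of real RPC futures):
#
#   f_ok, f_err = torch.futures.Future(), torch.futures.Future()
#   f_ok.set_result(1)
#   f_err.set_exception(RuntimeError("worker failed"))
#   _wait_for_all([f_ok, f_err])  # waits on *both*, then re-raises the RuntimeError
#
# i.e. every future is waited on before any collected exception is re-raised.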


class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
    to retrieve the gradients for specific parameters.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    either from the same or different clients, will
    be serialized on each worker -- as each worker's optimizer can only work
    on one set of gradients at a time. However, there is no guarantee that
    the full forward-backward-optimizer sequence will execute for one client
    at a time. This means that the gradients being applied may not correspond
    to the latest forward pass executed on a given worker. Also, there is no
    guaranteed ordering across workers.

    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
    by default, so that optimizer updates are not blocked by the Python Global
    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
    Model Parallel). This feature is currently enabled for most optimizers. You
    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
    for your own custom optimizers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>     # Forward pass.
        >>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>     loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>     # Backward pass.
        >>>     dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>     # Optimizer.
        >>>     dist_optim = DistributedOptimizer(
        >>>         optim.SGD,
        >>>         [rref1, rref2],
        >>>         lr=0.05,
        >>>     )
        >>>     dist_optim.step(context_id)

    __ https://github.com/pytorch/tutorials/pull/1465
    """

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)
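        # e.g. with hypothetical worker names, parameter RRefs owned by "worker1"
        # and "worker2" are grouped per owner:
        #   {worker1_info: [rref_a, rref_c], worker2_info: [rref_b]}
        # so that exactly one local optimizer is constructed on each owning worker.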

        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = optim_ctor != optimizer_class

        if self.is_functional_optim:
            optimizer_new_func = _new_script_local_optimizer
        else:
            logger.warning(
                "Creating the optimizer %s without TorchScript support; "
                "this might result in slow computation time in a multithreaded "
                "environment (i.e. Distributed Model Parallel training on CPU) "
                "due to Python's Global Interpreter Lock (GIL). Please file an "
                "issue if you need this optimizer in TorchScript.",
                optimizer_class,
            )
            optimizer_new_func = _new_local_optimizer

        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                optimizer_new_func,
                args=(optim_ctor, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)
        self.remote_optimizers = _wait_for_all(remote_optim_futs)

    def step(self, context_id):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The provided ``context_id`` will be used to retrieve the
        corresponding :class:`~torch.distributed.autograd.context` that
        contains the gradients that should be applied to the parameters.

        Args:
            context_id: the autograd context id for which we should run the
                optimizer step.
        """
        dist_autograd._is_valid_context(context_id)
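
        # One RPC is issued per worker that owns parameters: the TorchScript path
        # uses `_script_local_optimizer_step`, the eager path `_local_optimizer_step`,
        # each invoked with (optimizer RRef, context_id); `_wait_for_all` below
        # blocks until every worker has applied its update.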
        optimizer_step_func = (
            _script_local_optimizer_step
            if self.is_functional_optim
            else _local_optimizer_step
        )

        rpc_futs = []
        for optimizer in self.remote_optimizers:
            rpc_futs.append(
                rpc.rpc_async(
                    optimizer.owner(),
                    optimizer_step_func,
                    args=(optimizer, context_id),
                )
            )
        _wait_for_all(rpc_futs)