from collections import OrderedDict
from typing import TYPE_CHECKING, cast, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union

import torch
from torch._utils import _get_device_index

from ..modules import Module
from . import comm

if TYPE_CHECKING:
    import torch.jit
    import torch.jit._state

__all__ = ['replicate']


def _is_script_module(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "torch.jit.ScriptModule":
    import torch.jit

    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "torch.jit._state.EnabledProxy":
    import torch.jit._state

    return torch.jit._state._enabled


# Check if we can safely replicate the module.
# There are two types of modules:
# 1. Python modules
# 2. ScriptModules
#
# Currently, a module cannot be replicated properly if any descendant of a
# ScriptModule is a Python module (type 1 above).
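#
# Illustrative sketch of the rule (assuming the JIT is enabled; otherwise the
# check below passes unconditionally): a tree built only from Python modules
# is replicatable because no ScriptModule is involved, e.g.
#
#     _replicatable_module(torch.nn.Sequential(torch.nn.Linear(2, 2)))  # -> True
#
# whereas a ScriptModule whose descendants include a plain Python module makes
# the check return False.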
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    # module.modules() contains module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True


def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
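    """Broadcast ``tensors`` to every device in ``devices`` and regroup the
    copies per device.

    Returns one list per device, each holding copies of all input tensors in
    their original order. As an illustrative sketch, broadcasting three
    tensors ``[a, b, c]`` to two devices yields
    ``[[a_0, b_0, c_0], [a_1, b_1, c_1]]``, where ``x_j`` is the copy of ``x``
    on ``devices[j]``. With ``detach=False`` the copies come from the
    ``Broadcast`` autograd function, so gradients flow back to the originals;
    with ``detach=True`` the copies are detached.
    """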
    from ._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function to broadcast if not detach
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []


T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
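    """Replicate ``network`` onto each device in ``devices``.

    Parameters and buffers are broadcast to every device through
    ``_broadcast_coalesced_reshape``; the module objects themselves are copied
    via ``Module._replicate_for_data_parallel``. When ``detach`` is ``False``,
    the replicated parameters stay connected to the autograd graph, so
    gradients computed on a replica flow back to ``network``'s original
    parameters. Returns one replica per entry in ``devices`` (an empty list if
    ``devices`` is empty).
    """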
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    # Buffers that require grad are broadcast through autograd only when
    # detach is False; all other buffers are broadcast detached.
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)
    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)
    # Wire each replica's child modules, parameters, and buffers to the
    # corresponding replicated copies.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]
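

# Minimal usage sketch (an illustrative assumption: at least two CUDA devices
# are available; this mirrors how torch.nn.DataParallel replicates its module
# during forward):
#
#     model = torch.nn.Linear(4, 4).cuda(0)
#     replicas = replicate(model, [0, 1])
#     # replicas[0] lives on cuda:0 and replicas[1] on cuda:1; with the
#     # default detach=False, gradients computed through a replica propagate
#     # back to model's original parameters.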