Spaces:
Running
Running
File size: 6,943 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import torch
from ..modules import Module
from . import comm
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast
from torch._utils import _get_device_index
from collections import OrderedDict
if TYPE_CHECKING:
import torch.jit
import torch.jit._state
# Only `replicate` is listed as public; the underscore-prefixed helpers in
# this file are not exported by a star-import.
__all__ = ['replicate']
def _is_script_module(module: Module) -> bool:
import torch.jit
return isinstance(module, torch.jit.ScriptModule)
def _is_script_method(module: Module) -> bool:
import torch.jit
return isinstance(module, torch._C.ScriptMethod)
def _init_script_module() -> "torch.jit.ScriptModule":
import torch.jit
return torch.jit.ScriptModule()
def _is_jit_enabled() -> "torch.jit._state.EnabledProxy":
import torch.jit._state
return torch.jit._state._enabled
# Decide whether a module tree can be replicated safely.
#
# Two kinds of module are distinguished:
#   1. plain Python modules
#   2. ScriptModules
#
# A tree is currently not replicatable when any ScriptModule in it has a plain
# Python module (kind 1) among its descendants.
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    """Return ``True`` if ``module`` and everything below it can be replicated.

    Args:
        module: Root of the (sub)tree to check.
        memo: Set of modules already visited. It is created on the first call
            and shared, and mutated, across the recursive calls.

    Returns:
        ``True`` when the JIT is disabled, or when no ScriptModule in the tree
        has a non-script descendant; ``False`` otherwise.
    """
    def strict_descendants(root: Module) -> Iterator[Module]:
        # ``root.modules()`` yields ``root`` itself first, so drop that entry.
        walker = root.modules()
        next(walker)
        return walker

    if not _is_jit_enabled():
        return True

    visited = set() if memo is None else memo
    # Record this module as visited before looking at anything below it.
    visited.add(module)

    if _is_script_module(module):
        # Everything under a ScriptModule is judged right here, so mark the
        # whole subtree as visited.
        visited.update(strict_descendants(module))
        for descendant in strict_descendants(module):
            if not _is_script_module(descendant):
                return False
        return True

    # A single non-replicatable child makes the whole check return False, so
    # children that were already visited can be skipped. The generator is
    # lazy: each ``not in visited`` test runs only after the preceding
    # recursive call has updated ``visited``.
    return all(
        _replicatable_module(child, visited)
        for child in module.children()
        if child not in visited
    )
def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    """Broadcast ``tensors`` to ``devices`` and return the copies in groups.

    Args:
        tensors: Tensors to broadcast.
        devices: Target devices, passed through to the broadcast call.
        detach: If true, use ``comm.broadcast_coalesced``; otherwise broadcast
            through the ``Broadcast`` autograd function.

    Returns:
        With ``detach`` set, whatever ``comm.broadcast_coalesced`` returns.
        Otherwise a list of consecutive groups, each ``len(tensors)`` long,
        cut from the flat ``Broadcast.apply`` result (presumably one group per
        device; confirm against ``Broadcast``). An empty ``tensors`` yields
        ``[]`` on the non-detach path.
    """
    from ._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)

    # Not detaching: broadcast through the autograd function.
    group_size = len(tensors)
    if group_size == 0:
        return []

    flat_copies = Broadcast.apply(devices, *tensors)
    return [
        flat_copies[start:start + group_size]
        for start in range(0, len(flat_copies), group_size)
    ]
# Lets ``replicate`` be annotated as returning the same Module subtype it receives.
T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    """Build one replica of ``network`` for every entry in ``devices``.

    Parameters and buffers are broadcast to the devices. Each module in the
    tree is then copied with ``_replicate_for_data_parallel`` and the copies
    are wired together with the matching child modules, parameter copies and
    buffer copies.

    Args:
        network: Root module to replicate.
        devices: Target devices; each is normalised with ``_get_device_index``.
        detach: Passed to the parameter broadcast. Buffers are broadcast
            through the non-detached path only if they require grad and
            ``detach`` is false.

    Returns:
        A list holding the root replica for each device, in ``devices`` order.
        Empty when ``devices`` is empty.

    Raises:
        RuntimeError: If ``_replicatable_module(network)`` is false.
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(dev, True) for dev in devices]

    # Parameters: remember each one's position so its copy can be found later.
    source_params = list(network.parameters())
    param_slot = {param: pos for pos, param in enumerate(source_params)}
    param_copies = _broadcast_coalesced_reshape(source_params, devices, detach)

    # Buffers: split into those broadcast with gradient tracking and the rest.
    grad_buffers: List[torch.Tensor] = []
    plain_buffers: List[torch.Tensor] = []
    for buf in network.buffers():
        bucket = grad_buffers if (buf.requires_grad and not detach) else plain_buffers
        bucket.append(buf)

    grad_buffer_slot = {buf: pos for pos, buf in enumerate(grad_buffers)}
    plain_buffer_slot = {buf: pos for pos, buf in enumerate(plain_buffers)}

    grad_buffer_copies = _broadcast_coalesced_reshape(grad_buffers, devices, detach=detach)
    plain_buffer_copies = _broadcast_coalesced_reshape(plain_buffers, devices, detach=True)

    # Pass 1: make an unwired copy of every module for every device.
    source_modules = list(network.modules())
    module_slot: Dict[Module, int] = {}
    module_copies: List[List[Module]] = [[] for _ in devices]

    for pos, source in enumerate(source_modules):
        module_slot[source] = pos
        for per_device in module_copies:
            clone = source._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `model.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, a `_former_parameters` dict is added here to support DDP.
            clone._former_parameters = OrderedDict()
            per_device.append(clone)

    # Pass 2: attach children, parameter copies and buffer copies to each copy.
    for pos, source in enumerate(source_modules):
        for name, child in source._modules.items():
            if child is not None:
                child_pos = module_slot[child]
                for per_device in module_copies:
                    setattr(per_device[pos], name, per_device[child_pos])
            else:
                for per_device in module_copies:
                    per_device[pos]._modules[name] = None

        for name, param in source._parameters.items():
            if param is not None:
                slot = param_slot[param]
                for rank, per_device in enumerate(module_copies):
                    clone = per_device[pos]
                    copied_param = param_copies[rank][slot]
                    # Parameters in replicas are no longer leaves, so they
                    # are set as plain (non-parameter) attributes.
                    setattr(clone, name, copied_param)
                    # Expose the parameter for DDP.
                    clone._former_parameters[name] = copied_param
            else:
                for per_device in module_copies:
                    per_device[pos]._parameters[name] = None

        for name, buffer in source._buffers.items():
            if buffer is None:
                for per_device in module_copies:
                    per_device[pos]._buffers[name] = None
                continue

            # Use the same predicate as the split above to pick the right group.
            if buffer.requires_grad and not detach:
                chosen_copies = grad_buffer_copies
                slot = grad_buffer_slot[buffer]
            else:
                chosen_copies = plain_buffer_copies
                slot = plain_buffer_slot[buffer]

            for rank, per_device in enumerate(module_copies):
                setattr(per_device[pos], name, chosen_copies[rank][slot])

    # The root module is first in ``network.modules()``, so it sits at index 0.
    return [cast(T, per_device[0]) for per_device in module_copies]
|