Spaces:
Running
Running
File size: 7,567 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Type
import torch.nn as nn
from torch.distributed._composable_state import _State
def generate_state_key(string="__composable_api_state_key"):
    """Return ``string`` followed by an underscore and a fresh random UUID4.

    Args:
        string: Prefix placed in front of the generated UUID.

    Returns:
        A new key of the form ``"<string>_<uuid4>"``; every call produces a
        different UUID suffix.
    """
    unique_suffix = uuid.uuid4()
    # Formatting a UUID in an f-string yields the same text as ``str(uuid)``.
    return f"{string}_{unique_suffix}"
# Randomised keys, generated once when this module is imported, under which
# ``contract`` stores its bookkeeping in ``module.__dict__``:
#   * STATE_KEY    -> dict mapping each applied API function to its state object
#   * REGISTRY_KEY -> dict mapping each applied API name to a ``RegistryItem``
STATE_KEY, REGISTRY_KEY = generate_state_key(), generate_state_key()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    """Entry recorded in a module's registry for each composable API applied.

    The class currently carries no data; ``contract`` stores one instance per
    applied API, keyed by the API function's name.
    """
def _check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str) -> None:
    r"""
    Verify that two lists of fully-qualified names (FQNs) are identical,
    including their order.

    Args:
        orig_fqns: FQNs collected before the composable API was applied.
        new_fqns: FQNs collected after the composable API was applied.
        check_key: Prefix for the error message naming what is being checked.

    Raises:
        RuntimeError: If names were added or removed, or if the same names
            appear in a different order.
    """
    if orig_fqns == new_fqns:
        return

    orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
    orig_only = orig_fqn_set - new_fqn_set
    new_only = new_fqn_set - orig_fqn_set
    if orig_only or new_only:
        raise RuntimeError(
            f"{check_key}"
            "Composable distributed API implementations cannot modify "
            "FQNs.\n"
            f"Only in original FQNs: {orig_only},\n"
            f"Only in new FQNs: {new_only}"
        )

    # Both sides hold the same names, so only the ordering differs. Report the
    # full ordered lists: the set differences are empty here and would carry no
    # information.
    raise RuntimeError(
        f"{check_key}"
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: Class instantiated (with no arguments) to create the state
            object stored for ``func`` on each module it is applied to.

    Returns:
        A decorator. The decorated function returns the module produced by
        ``func``, or the input module when ``func`` returns ``None``.

    Raises:
        AssertionError: Raised by the decorated function if the same API has
            already been applied to the module, if the state or registry
            stored on the module is not a ``dict``, or if ``func`` returns
            something other than ``None`` or an :class:`nn.Module`.
        RuntimeError: Raised by the decorated function if ``func`` changes the
            parameter, buffer or sub-module FQNs or their order.

    Example::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # get existing global states
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # get global registry
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )
            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # install states specific to the wrapped ``func``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            # Snapshot the FQNs before ``func`` runs so they can be compared
            # with the FQNs of the module it returns.
            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)

            if updated is None:
                updated = module

            # Validate the return type before touching any nn.Module API on it,
            # so a wrong return value produces this message rather than an
            # AttributeError from the ``named_*`` calls below.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            _check_fqn(
                list(orig_named_params.keys()),
                list(new_named_params.keys()),
                "Check parameters, ",
            )
            _check_fqn(
                list(orig_named_buffers.keys()),
                list(new_named_buffers.keys()),
                "Check buffer, ",
            )
            _check_fqn(
                list(orig_named_modules.keys()),
                list(new_named_modules.keys()),
                "Check modules, ",
            )

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            # Returns the state object stored for ``func`` on ``module``, or
            # ``None`` when ``func`` has not been applied to it. Note that the
            # lookup itself installs an empty dict under STATE_KEY if absent.
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Look up the registry of composable APIs that have been applied to
    ``module``.

    Args:
        module: The module whose registry is requested.

    Returns:
        The mapping stored on ``module`` under ``REGISTRY_KEY`` (API name to
        :class:`RegistryItem`), or ``None`` when the module has no such
        attribute, i.e. no API has been applied to it.
    """
    applied_apis = getattr(module, REGISTRY_KEY, None)
    return applied_apis
|