File size: 7,567 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Type

import torch.nn as nn
from torch.distributed._composable_state import _State


def generate_state_key(string="__composable_api_state_key"):
    """Return ``string`` followed by ``_`` and a freshly generated random UUID4.

    Every call draws a new UUID, so two calls with the same ``string`` yield
    different keys.
    """
    unique_suffix = uuid.uuid4()
    return f"{string}_{unique_suffix}"


# Keys under which ``contract`` keeps its bookkeeping in ``module.__dict__``.
# Both carry a random UUID suffix produced by ``generate_state_key``.
# STATE_KEY: entry mapping each applied API function to its state object.
STATE_KEY = generate_state_key()
# REGISTRY_KEY: entry mapping each applied API's ``__name__`` to a ``RegistryItem``.
REGISTRY_KEY = generate_state_key()


# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    """Marker object stored in a module's registry for each applied composable API.

    It currently carries no data of its own.
    """


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: Class instantiated with no arguments to hold the state that
            the decorated API keeps on each module it is applied to.

    Returns:
        A decorator. Applying it to ``func`` yields a wrapper that calls
        ``func`` and returns the resulting module (the input module itself
        when ``func`` returns ``None``).

    Raises:
        AssertionError: (from the wrapper) if the bookkeeping stored on the
            module is not a ``dict``, if the same API is applied to a module
            twice, or if ``func`` returns something that is neither ``None``
            nor an :class:`nn.Module`.
        RuntimeError: (from the wrapper) if ``func`` adds, removes or reorders
            parameter, buffer or sub-module FQNs.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # get existing global states
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # get global registry
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )

            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # install states specific to the wrapped ``func``; this happens
            # before calling ``func`` so that ``func`` itself can already use
            # ``func.state(module)``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            # snapshot the FQNs before ``func`` runs
            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)

            if updated is None:
                updated = module

            # Validate the return type before touching any nn.Module API on it,
            # so that a bad return value fails with this message rather than an
            # unrelated AttributeError.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            # snapshot the FQNs after ``func`` ran
            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
                """Raise ``RuntimeError`` unless both FQN lists are identical.

                ``check_key`` is prepended to the error message to tell which
                kind of FQN (parameters, buffers, modules) was being checked.
                """
                if orig_fqns == new_fqns:
                    return

                orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
                orig_only = orig_fqn_set - new_fqn_set
                new_only = new_fqn_set - orig_fqn_set
                if orig_only or new_only:
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify "
                        "FQNs.\n"
                        f"Only in original FQNs: {orig_only},\n"
                        f"Only in new FQNs: {new_only}"
                    )
                else:
                    # Same names, different order: the set differences are
                    # empty here, so report the ordered lists themselves.
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify "
                        "the order of FQNs.\n"
                        f"Original FQNs: {orig_fqns}\n"
                        f"New FQNs: {new_fqns}"
                    )

            check_fqn(
                list(orig_named_params.keys()),
                list(new_named_params.keys()),
                "Check parameters, ",
            )
            check_fqn(
                list(orig_named_buffers.keys()),
                list(new_named_buffers.keys()),
                "Check buffer, ",
            )
            check_fqn(
                list(orig_named_modules.keys()),
                list(new_named_modules.keys()),
                "Check modules, ",
            )

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state ``func`` installed on ``module``, or ``None``.

            Note that when ``module`` has no state mapping yet, an empty one is
            stored on it as a side effect of the lookup.
            """
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner


def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Look up the registry of composable APIs that were applied to ``module``.

    The registry maps each applied API's name to its ``RegistryItem``. When
    ``module`` has no registry attribute, ``None`` is returned instead.
    """
    registry = getattr(module, REGISTRY_KEY, None)
    return registry