File size: 10,581 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import copy
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from ._remove_effect_tokens_pass import _remove_effect_tokens

from .exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    OutputKind,
)


@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, *args, **kwargs):
    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)

    if received_spec != self._in_spec:
        raise ValueError(  # noqa: TRY200
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{self._in_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}"
        )

    return _check_input_constraints_for_graph(
        [node for node in self.graph.nodes if node.op == "placeholder"],
        flat_args_with_path,
        self.range_constraints,
    )


def _unlift_inputs_as_getattr(

    gm: torch.fx.GraphModule,

    lifted_inputs: List[Optional[str]],

) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
    """

    Unlift inputs referring to params/buffers/constants as getattr nodes in the

    graph

    """
    unlifted_name_to_node = {}
    input_name_to_node = {}

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is None:
            input_name_to_node[input_node.name] = input_node

        else:
            with gm.graph.inserting_after(input_node):
                getattr_node = gm.graph.get_attr(lifted_node)
                input_node.replace_all_uses_with(getattr_node)
                metadata = input_node.meta
                gm.graph.erase_node(input_node)
                getattr_node.meta = metadata
                unlifted_name_to_node[lifted_node] = getattr_node

    return unlifted_name_to_node, input_name_to_node


def _insert_copy_for_mutations(

    gm: torch.fx.GraphModule,

    mutated_outputs: List[Optional[str]],

    unlifted_name_to_node: Dict[str, torch.fx.Node],

    input_name_to_node: Dict[str, torch.fx.Node],

) -> None:
    """

    Find the all the buffers and inputs that were mutated and insert copy_

    operators to reflect mutations.

    """
    output_node = None
    for node in gm.graph.nodes:
        if node.op == "output":
            output_node = node
            break
    assert output_node is not None
    outputs = pytree.tree_flatten(output_node.args)[0]
    assert len(outputs) == len(mutated_outputs)

    user_output_nodes = []
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

        if mutated_node_name in unlifted_name_to_node:
            mutated_node = unlifted_name_to_node[mutated_node_name]
        elif mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with gm.graph.inserting_before(output_node):
            _ = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )

    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(tuple(user_output_nodes))
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)


def _get_codegen(

    in_spec: pytree.TreeSpec,

    out_spec: Optional[pytree.TreeSpec],

) -> _PyTreeCodeGen:
    """

    Create the codegen for the graph module based on the in/out specs

    """
    if (
        in_spec.type == tuple
        and in_spec.num_children == 2
        and in_spec.children_specs[0].type == tuple
        and in_spec.children_specs[1].type == dict
    ):
        # if in_spec contains the args (tuple) and kwargs (dict)
        names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
        # add kwarg names
        names.extend(in_spec.children_specs[1].context)
    else:
        names = [f"arg_{i}" for i in range(in_spec.num_children)]

    return _PyTreeCodeGen(
        _PyTreeInfo(
            names,
            in_spec,
            out_spec,
        )
    )


def _unlift(
    gm: torch.fx.GraphModule,
    lifted_inputs: List[Optional[str]],
    mutated_outputs: List[Optional[str]],
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    state_dict: Dict[str, Any],
    constants: Dict[str, Any],
):
    """
    Unlift ``gm`` in place and return it.

    Args:
        gm: graph module to rewrite.
        lifted_inputs: A list matching the graph module's input nodes. For an
            input node that refers to a lifted parameter/buffer, this list
            contains the fqn of the corresponding attribute; otherwise it
            contains None. Used to turn the lifted inputs into get_attr nodes.
        mutated_outputs: A list matching the graph module's output nodes. For
            an output node that refers to a mutated buffer or user input, this
            list contains the name of that buffer or user input; otherwise it
            contains None. Used to re-insert an in-place copy_ operator that
            copies the mutated value back to the original node.
        in_spec: input tree spec used to build the graph's codegen.
        out_spec: output tree spec used to build the graph's codegen.
        state_dict: not read by this function.
        constants: not read by this function.

    Returns:
        The same ``gm`` object, linted, dead-code-eliminated and recompiled.
    """
    attr_nodes, user_input_nodes = _unlift_inputs_as_getattr(gm, lifted_inputs)
    _insert_copy_for_mutations(gm, mutated_outputs, attr_nodes, user_input_nodes)

    graph = gm.graph
    graph._codegen = _get_codegen(in_spec, out_spec)
    graph.lint()
    graph.eliminate_dead_code()
    gm.recompile()
    return gm


def _register_attrs_to_new_gm(

    new_gm: torch.fx.GraphModule,

    graph_signature: ExportGraphSignature,

    state_dict: Dict[str, Any],

    constants: Dict[str, Any],

) -> None:
    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
    for name in graph_signature.buffers:
        if name in non_persistent_buffers:
            persistent = False
            value = constants[name]
        else:
            persistent = True
            value = state_dict[name]
        _assign_attr(
            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
        )
    for name in graph_signature.parameters:
        value = state_dict[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )

    for name in chain(
        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
    ):
        value = constants[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.CONSTANT,
        )


class _StatefulGraphModuleFactory(type):
    """

    Metaclass that ensures a private constructor for _StatefulGraphModule

    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
        )

    def _create(cls, root, graph, range_constraints=None):
        return super().__call__(
            root,
            graph,
            range_constraints=range_constraints,
        )


class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    """
    ``torch.fx.GraphModule`` that additionally stores ``range_constraints``.

    Its metaclass rejects direct construction; instances are built with
    ``_StatefulGraphModule._create(root, graph, range_constraints=...)``.
    """

    def __init__(self, root, graph, range_constraints=None):
        super().__init__(root, graph)
        # Non-persistent buffers still need to be fixed up after construction;
        # this initializer only records the range constraints. Any falsy
        # value (including None) is stored as an empty list.
        self.range_constraints = range_constraints if range_constraints else []


def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    # TODO(suo): this should not be optional, but it is because we still have
    # capture_pre_autograd_graph.
    graph_signature: Optional[ExportGraphSignature] = None,
):
    """
    Wrap ``plain_graph_module`` in a ``_StatefulGraphModule``.

    The new module is created from ``plain_graph_module`` and its graph, stores
    ``range_constraints`` and gets ``_check_input_constraints_pre_hook``
    registered as a forward pre-hook (with kwargs).

    Args:
        plain_graph_module: module to wrap.
        range_constraints: forwarded to the ``_StatefulGraphModule``.
        graph_signature: when given, every name in its
            ``non_persistent_buffers`` is re-registered on the new module as a
            non-persistent buffer.

    Returns:
        The newly created stateful graph module.
    """
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    if graph_signature is not None:
        # torch.fx does not distinguish between persistent and non-persistent
        # buffers, so that distinction has to be restored explicitly here.
        for buffer_name in graph_signature.non_persistent_buffers:
            original_buffer = plain_graph_module.get_buffer(buffer_name)
            _assign_attr(
                original_buffer,
                stateful_gm,
                buffer_name,
                attr_kind=_AttrKind.BUFFER,
                persistent=False,
            )

    return stateful_gm


def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
    """
    Turn an ``ExportedProgram`` into a stateful module with unlifted state.

    Steps, in order: remove effect tokens from ``ep``; build a fresh
    ``GraphModule`` from ``ep.graph_module`` and a deep copy of ``ep.graph``;
    register the attributes named by the graph signature; unlift the inputs
    and re-insert mutation copies via ``_unlift``; wrap the result with
    ``_create_stateful_graph_module``; finally copy ``ep.graph_module.meta``
    into the returned module's ``meta``.
    """
    ep = _remove_effect_tokens(ep)
    signature = ep.graph_signature

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, signature, ep.state_dict, ep.constants)

    # Input kinds that are turned back into attributes of the module.
    lifted_kinds = (
        InputKind.BUFFER,
        InputKind.CONSTANT_TENSOR,
        InputKind.PARAMETER,
        InputKind.CUSTOM_OBJ,
    )
    lifted_inputs: List[Optional[str]] = [
        spec.target if spec.kind in lifted_kinds else None
        for spec in signature.input_specs
    ]

    # Output kinds that stand for an in-place update rather than a user output.
    mutation_kinds = (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
    mutated_outputs: List[Optional[str]] = [
        spec.target if spec.kind in mutation_kinds else None
        for spec in signature.output_specs
    ]

    call_spec = ep.call_spec
    new_gm = _unlift(
        new_gm,
        lifted_inputs,
        mutated_outputs,
        call_spec.in_spec,
        call_spec.out_spec,
        ep.state_dict,
        ep.constants,
    )

    unlift_gm = _create_stateful_graph_module(
        new_gm, ep.range_constraints, ep.graph_signature
    )
    unlift_gm.meta.update(ep.graph_module.meta)
    return unlift_gm