import inspect
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
from collections import OrderedDict
import logging

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node

if TYPE_CHECKING:
    import sympy  # noqa: F401

__all__ = ["Partition", "split_module"]
_LOGGER = logging.getLogger(__name__)

@compatibility(is_backward_compatible=True)
class Partition:
    def __init__(self, name: str):
        self.name: str = name
        self.submod_name = f"submod_{name}"
        self.node_names: List[str] = []
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        self.environment: Dict[Node, Node] = {}
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions depended on: {self.dependencies},\n"
            f" partition dependents: {self.dependents}"
        )


# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the topological order of the newly constructed GraphModule
        keep_original_node_name: Optional[bool]: if True, node names from the original
            module are kept in the split graphs instead of being regenerated.

    Returns:
        GraphModule: the module after split.

    Example:

        This is a sample setup:

            import torch
            from torch.fx import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. The original graph is broken into partitions:

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1];  submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        The output of the split module is the same as the output of the input
        traced module. This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
    """

    def construct_graph(
        node: Node,
        base_mod_env: Dict[str, Node],
        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
    ):
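        # Copy a placeholder or get_attr node from the original graph into the
        # base (unsplit) module graph, preserving default values, node metadata,
        # and the attribute values the targets refer to.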
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = () if default_value is inspect.Signature.empty else (default_value,)
                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
    symbol_to_node: Dict["sympy.Symbol", Node] = {}

    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
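        # When def_node and use_node live in different partitions (or use_node is
        # the graph output), record def_node as an output of the defining partition
        # and as an input of the using partition, along with any SymInt symbols
        # free in its example_value, and record the dependency edge between the
        # two partitions.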
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)
        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        use_partition.inputs.setdefault(symbol_to_node[s].name)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node):
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global State Nodes are nodes which, through their global state side effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    import sympy  # noqa: F811

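    # First pass over the graph: remember which placeholder defines each SymInt
    # symbol, assign every non-placeholder/attr/output node to a partition via
    # split_callback, and record which partitions each grad / autocast region
    # spans.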
    for node in m.graph.nodes:
        if node.op in ["placeholder", "get_attr", "output"]:
            if (
                node.op == "placeholder" and
                (val := node.meta.get("example_value")) is not None and
                isinstance(val, torch.SymInt) and
                isinstance(val.node.expr, sympy.Symbol)
            ):
                symbol_to_node[val.node.expr] = node
            continue

        instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = set({split_callback(node)})
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = set({split_callback(node)})
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        if active_grad is not None:
            grad_regions[active_grad].add(split_callback(node))

        for a in active_autocasts:
            autocast_regions[a].add(split_callback(node))

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

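    # The prelude/epilogue insertion for grad/autocast regions below assumes each
    # region covers an increasing range of partition ids, so when such regions
    # exist we check that split_callback is monotonically increasing over the
    # graph order.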
    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = split_callback(node)
            assert highest_partition <= pid, \
                ("autocast or set_grad_enabled require monotonically increasing partitions: "
                 f"highest: {highest_partition}, this node's: {pid}")
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

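    # Remember the order in which partitions were first encountered; this is the
    # emission order used when keep_original_order is set.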
    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude: re-emit each region's entering global-state node
    # (_set_grad_enabled / _enter_autocast) at the start of every partition the
    # region spans after the first one.
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(arg for arg in node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # target = target_atoms[-1]
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue: emit the matching _exit_autocast at the end of every
    # partition an autocast region spans except the last one.
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: Dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

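    # get_attr/placeholder nodes already copied into the base graph; only used in
    # the keep_original_order path, where they are emitted lazily per partition.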
    already_constructed_attr_nodes = set()
    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # first get the attr nodes required by this partition
            orig_mod_attr_nodes: List[Node] = [
                orig_mod_env[key] for key in partition.inputs
            ]
            # Construct GraphModule for this partition
            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

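    # Finally, emit the base module's output, remapping the original output
    # arguments to the values produced by the submodule calls above.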
    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)