File size: 3,207 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Dict, Tuple

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph

from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]


@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    Plain container module used to rebuild an attribute path.

    It copies attributes from an original module into a new module tree, so
    that a lifted submodule can reach the attributes it uses under the same
    dotted names.

    Args:
        d: mapping from child name to the value to register on this module
            via ``add_module`` (registration happens in iteration order).
    """

    def __init__(self, d):
        super().__init__()
        for child_name, child in d.items():
            self.add_module(child_name, child)


@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for ``subgraph``, copying the attributes it needs
    from the parent graph module.

    Every ``call_module`` and ``get_attr`` node in ``subgraph`` refers to a
    dotted target such as ``"conv.weight"``. For each such target, a chain of
    ``HolderModule`` instances mirroring the target's prefix (``"conv"``) is
    created under a fresh root, if not already present. The leaf object
    (``gm.conv.weight``) is then attached at the same path. The returned
    GraphModule is built from that root and ``subgraph``.

    Args:
        gm (GraphModule): parent graph module the attributes are read from.
        subgraph (Graph): a valid subgraph that contains nodes copied from
            the parent graph.
        comp_name (str): name for the new component; used as the prefix of
            every value in the returned FQN mapping.
        class_name (str): class name given to the new GraphModule.

    Returns:
        A ``(module, fqn_mapping)`` tuple: ``module`` is the new GraphModule,
        and ``fqn_mapping`` maps each original target string to
        ``f"{comp_name}.{target}"``.
    """
    root = HolderModule({})
    fqn_mapping: Dict[str, str] = {}

    for node in subgraph.nodes:
        # Only module calls and attribute fetches reference state on ``gm``.
        if node.op != "call_module" and node.op != "get_attr":
            continue

        fqn = node.target
        assert isinstance(fqn, str)
        *parent_path, attr = fqn.split(".")

        # Walk the new holder tree and the original module in lockstep,
        # creating empty holders for any path segment not seen yet.
        holder = root
        source = gm
        for segment in parent_path:
            if not hasattr(holder, segment):
                holder.add_module(segment, HolderModule({}))
            holder = getattr(holder, segment)
            source = getattr(source, segment)

        leaf = getattr(source, attr)
        fqn_mapping[fqn] = f"{comp_name}.{fqn}"
        # nn.Module.__setattr__ registers Module/Parameter values in the
        # holder's submodule/parameter registries rather than as plain attrs.
        setattr(holder, attr, leaf)

    return GraphModule(root, subgraph, class_name), fqn_mapping


@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Return True if ``left`` and ``right`` are identical graphs, i.e. they

        - have the same number of outputs, in the same order,
        - have the same number of inputs, in the same order, and
        - have the same set of nodes, with identical connectivity.

    ``left`` is used as a ``SubgraphMatcher`` pattern, with both its output
    and its placeholders anchored, and the result is whether at least one
    match is found in ``right``.

    NOTE(review): the match is one-directional (``left`` searched for in
    ``right``). Whether extra nodes in ``right`` cause a mismatch depends on
    SubgraphMatcher's anchoring semantics — confirm before relying on
    symmetry.
    """
    pattern_matcher = SubgraphMatcher(
        left, match_output=True, match_placeholder=True
    )
    found = pattern_matcher.match(right)
    return bool(found)