File size: 6,923 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import copy

import torch.nn as nn

from torch.ao.quantization.fuser_method_mappings import get_fuser_method
# for backward compatibility
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
from torch.nn.utils.parametrize import type_before_parametrizations

from typing import List, Optional

__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]

# Generalization of getattr
def _get_module(model, submodule_key):
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens:
        cur_mod = getattr(cur_mod, s)
    return cur_mod

# Generalization of setattr
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)

    setattr(cur_mod, tokens[-1], module)

def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Fuse a sequence of modules with the fuser method registered for it.

    The types of the entries in ``mod_list`` (obtained through
    ``type_before_parametrizations``) form the key handed to
    ``get_fuser_method``. The fuser that comes back is invoked as
    ``fuser_method(is_qat, *mod_list)``; its result occupies the first slot
    of the returned list and every remaining slot holds an ``nn.Identity()``
    whose ``training`` flag is copied from the first input module.

    Hook handling: the forward pre-hooks of the first module and the forward
    hooks of the last module are re-registered on the fused module and then
    cleared from the originals. Forward hooks on any other module are not
    carried over.

    Args:
        mod_list: Modules to fuse, in the order they are applied.
        is_qat: Passed through as the first argument of the fuser method.
        additional_fuser_method_mapping: Optional extra mapping forwarded to
            ``get_fuser_method``.

    Returns:
        A list with the same length as ``mod_list``: the fused module followed
        by ``nn.Identity()`` placeholders.

    Raises:
        NotImplementedError: If ``get_fuser_method`` returns ``None`` for the
            module types.
    """
    module_types = tuple(type_before_parametrizations(mod) for mod in mod_list)
    fuser_method = get_fuser_method(module_types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"Cannot fuse modules: {module_types}")
    fused = fuser_method(is_qat, *mod_list)

    # Only index into mod_list after the fuser lookup has succeeded.
    head, tail = mod_list[0], mod_list[-1]

    # NOTE: forward hooks that are not handled below are lost after fusion.
    # Pre-forward hooks of the first module move to the fused module.
    for pre_hook_fn in head._forward_pre_hooks.values():
        fused.register_forward_pre_hook(pre_hook_fn)
    head._forward_pre_hooks.clear()
    # Post-forward hooks of the last module move to the fused module.
    for hook_fn in tail._forward_hooks.values():
        fused.register_forward_hook(hook_fn)
    tail._forward_hooks.clear()

    result: List[Optional[nn.Module]] = [fused]
    for _ in mod_list[1:]:
        placeholder = nn.Identity()
        placeholder.training = head.training
        result.append(placeholder)

    return result

def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Fuse one group of named submodules of ``model`` in place.

    Args:
        model: Object whose submodules are looked up and replaced.
        modules_to_fuse: Dotted names of the submodules forming one group.
        is_qat: Forwarded to ``fuser_func``.
        fuser_func: Callable invoked as
            ``fuser_func(modules, is_qat, additional_fuser_method_mapping)``;
            its result is indexed once per name in ``modules_to_fuse``.
        fuse_custom_config_dict: Optional dict; only its
            ``"additional_fuser_method_mapping"`` entry is read here.

    Returns:
        None. ``model`` is modified in place.
    """
    custom_config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict
    extra_fuser_mapping = custom_config.get("additional_fuser_method_mapping", {})

    # Resolve every name to the module it currently refers to.
    originals = [_get_module(model, name) for name in modules_to_fuse]

    # Ask the fuser for the replacement of each position in the group.
    replacements = fuser_func(originals, is_qat, extra_fuser_mapping)

    # Write the replacements back under the original names.
    for index, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[index])

def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Shared implementation behind ``fuse_modules`` and ``fuse_modules_qat``.

    Args:
        model: Model containing the modules to be fused.
        modules_to_fuse: Either a list of dotted module names (one group) or
            a list of such lists (several groups). An empty list means there
            is nothing to fuse.
        is_qat: Forwarded to the fuser for every group.
        inplace: When False (the default) the fusion is applied to a deep copy
            of ``model``; when True ``model`` itself is modified.
        fuser_func: Callable that turns a list of modules into the list of
            their replacements.
        fuse_custom_config_dict: Optional custom configuration forwarded to
            ``_fuse_modules_helper``.

    Returns:
        The model the fusion was applied to: the deep copy when ``inplace`` is
        False, otherwise ``model`` itself.
    """
    if not inplace:
        model = copy.deepcopy(model)

    # all() is vacuously True for an empty sequence, so without this guard an
    # empty request would be treated as a single group of zero modules and the
    # fuser would be asked to fuse nothing. Nothing requested, nothing to do.
    if not modules_to_fuse:
        return model

    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # modules_to_fuse is a single group given as a flat list of names
        _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
    else:
        # modules_to_fuse is a list of groups, each a list of names
        for module_list in modules_to_fuse:
            _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
    return model

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""Fuse a list of modules into a single module.

    This is the non-QAT entry point: it forwards every argument to
    ``_fuse_modules`` with ``is_qat=False``.

    For each group of names, every name is reassigned to the corresponding
    entry of the list returned by ``fuser_func``. With the default
    ``fuse_known_modules`` the first item of the group becomes the fused
    module and the remaining items become ``nn.Identity()``.

    The default fuser is documented as handling these sequences:

    * conv, bn
    * conv, bn, relu
    * conv, relu
    * linear, relu
    * bn, relu

    NOTE(review): earlier wording claimed that all other sequences are left
    unchanged, but the default ``fuse_known_modules`` raises
    ``NotImplementedError`` when ``get_fuser_method`` returns ``None`` for a
    sequence. The authoritative list of supported sequences lives in the
    fuser method mapping, not in this function - confirm there.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.ao.quantization.fuse_known_modules
        `fuse_custom_config_dict`: custom configuration for fusion; its
                    ``"additional_fuser_method_mapping"`` entry is passed on to ``fuser_func``

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new (deep) copy is created and returned if
        inplace=False, which is the default; with inplace=True the model that
        was passed in is modified and returned.

    Examples::

            >>> # xdoctest: +SKIP
            >>> m = M().eval()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = M().eval()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict)

def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """QAT version for `fuse_modules`.

    Takes the same arguments as ``fuse_modules`` and forwards them to
    ``_fuse_modules`` with ``is_qat=True``.

    Args:
        model: Model containing the modules to be fused.
        modules_to_fuse: List of module names, or list of such lists.
        inplace: When False (the default) the fusion is applied to a deep
            copy of ``model``; when True ``model`` itself is modified.
        fuser_func: Callable that turns a list of modules into the list of
            their replacements.
        fuse_custom_config_dict: Optional custom configuration for fusion.

    Returns:
        The model with fused modules, as returned by ``_fuse_modules``.
    """
    fused_model = _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=True,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
    return fused_model