Spaces:
Running
Running
import copy | |
import torch.nn as nn | |
from torch.ao.quantization.fuser_method_mappings import get_fuser_method | |
# for backward compatibility | |
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401 | |
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401 | |
from torch.nn.utils.parametrize import type_before_parametrizations | |
from typing import List, Optional | |
# Public names exported by this module via ``from <module> import *``.
__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]
# Generalization of getattr | |
def _get_module(model, submodule_key):
    """Resolve a dotted attribute path starting from ``model``.

    Args:
        model: Object on which the lookup starts.
        submodule_key: Attribute path with components separated by ``'.'``,
            e.g. ``"block.conv"``.

    Returns:
        The object reached after applying ``getattr`` once per path component.

    Raises:
        AttributeError: If any component of the path does not exist.
    """
    target = model
    for attr_name in submodule_key.split('.'):
        target = getattr(target, attr_name)
    return target
# Generalization of setattr | |
def _set_module(model, submodule_key, module):
    """Assign ``module`` at a dotted attribute path below ``model``.

    Every component except the last is resolved with ``getattr``; the last
    component is then set on the resulting parent with ``setattr``.

    Args:
        model: Object on which the lookup starts.
        submodule_key: Attribute path with components separated by ``'.'``.
        module: Value to store under the final path component.

    Raises:
        AttributeError: If an intermediate component of the path does not exist.
    """
    # ``str.split`` always yields at least one element, so the unpacking
    # below cannot fail; ``parent_path`` is empty for an undotted key.
    *parent_path, leaf_name = submodule_key.split('.')
    parent = model
    for attr_name in parent_path:
        parent = getattr(parent, attr_name)
    setattr(parent, leaf_name, module)
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Fuse a sequence of modules and return the replacement modules.

    The fuser method is looked up from the (pre-parametrization) types of the
    modules in ``mod_list``. The documented sequences are:

        conv, bn
        conv, bn, relu
        conv, relu
        linear, bn
        linear, relu

    Args:
        mod_list: Modules to fuse, in execution order.
        is_qat: Forwarded as the first argument to the fuser method.
        additional_fuser_method_mapping: Optional extra mapping forwarded to
            ``get_fuser_method``.

    Returns:
        A list with the same length as ``mod_list``. The first element is the
        fused module; every remaining element is an ``nn.Identity`` whose
        ``training`` flag is copied from the first input module.

    Raises:
        NotImplementedError: If ``get_fuser_method`` returns ``None`` for the
            given module types.
    """
    types = tuple(type_before_parametrizations(mod) for mod in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"Cannot fuse modules: {types}")

    fused = fuser_method(is_qat, *mod_list)
    first_mod = mod_list[0]
    last_mod = mod_list[-1]

    # NOTE: only the two hook groups handled below survive fusion; any other
    # forward hooks on the original modules are lost.

    # Pre-forward hooks of the first module now run on the fused module.
    for pre_hook_fn in first_mod._forward_pre_hooks.values():
        fused.register_forward_pre_hook(pre_hook_fn)
    first_mod._forward_pre_hooks.clear()

    # Forward hooks of the last module now run on the fused module.
    for hook_fn in last_mod._forward_hooks.values():
        fused.register_forward_hook(hook_fn)
    last_mod._forward_hooks.clear()

    # The fused module takes the first slot; every other slot becomes a no-op
    # placeholder that mirrors the train/eval state of the first module.
    new_mod: List[Optional[nn.Module]] = [fused]
    for _ in mod_list[1:]:
        placeholder = nn.Identity()
        placeholder.training = first_mod.training
        new_mod.append(placeholder)
    return new_mod
def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Fuse one group of named submodules of ``model`` in place.

    Args:
        model: Model that owns the submodules; it is modified in place.
        modules_to_fuse: Dotted names of the submodules forming one group.
        is_qat: Forwarded to ``fuser_func``.
        fuser_func: Callable invoked as
            ``fuser_func(mod_list, is_qat, additional_fuser_method_mapping)``
            that returns the replacement modules, indexed like ``mod_list``.
        fuse_custom_config_dict: Optional dict; only its
            ``"additional_fuser_method_mapping"`` entry is read here.
    """
    config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict
    extra_mapping = config.get("additional_fuser_method_mapping", {})

    # Resolve every name to the module currently installed under it.
    originals = [_get_module(model, name) for name in modules_to_fuse]

    # Let the fuser build the replacement modules.
    replacements = fuser_func(originals, is_qat, extra_mapping)

    # Install each replacement under the name of the module it replaces.
    for position, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[position])
def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Shared implementation behind ``fuse_modules`` and ``fuse_modules_qat``.

    Args:
        model: Model containing the modules to fuse.
        modules_to_fuse: Either a single list of module names, or a list of
            such lists.
        is_qat: Forwarded to the fusion helper.
        inplace: When false, a deep copy of ``model`` is fused and returned and
            the input model is left untouched.
        fuser_func: Forwarded to the fusion helper.
        fuse_custom_config_dict: Forwarded to the fusion helper.

    Returns:
        The fused model: ``model`` itself when ``inplace`` is true, otherwise
        the fused deep copy.
    """
    fused_model = model if inplace else copy.deepcopy(model)

    # A flat list of names is one group; otherwise every element is a group.
    is_single_group = all(isinstance(entry, str) for entry in modules_to_fuse)
    groups = [modules_to_fuse] if is_single_group else modules_to_fuse

    for group in groups:
        _fuse_modules_helper(fused_model, group, is_qat, fuser_func, fuse_custom_config_dict)
    return fused_model
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""Fuse a list of modules into a single module.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Sequences for which no fuser method is found are not silently skipped:
    the default ``fuser_func`` (``fuse_known_modules``) raises an error for
    them.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. It is called as
                    ``fuser_func(mod_list, is_qat, additional_fuser_method_mapping)``.
                    For example, for ``mod_list = [convModule, BNModule]`` it returns
                    the list ``[ConvBNModule, nn.Identity()]``.
                    Defaults to torch.ao.quantization.fuse_known_modules
        fuse_custom_config_dict: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if inplace=False;
        with inplace=True the given model is modified and returned.

    Examples::

            >>> # xdoctest: +SKIP
            >>> m = M().eval()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = M().eval()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    # Non-QAT entry point: identical to ``fuse_modules_qat`` except for is_qat.
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict)
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """QAT version for `fuse_modules`.

    Takes the same arguments as ``fuse_modules`` and forwards them to the
    shared implementation with ``is_qat`` set to ``True``.

    Returns:
        The fused model; the input model itself when ``inplace`` is true,
        otherwise a fused deep copy.
    """
    # Positional order follows the ``_fuse_modules`` signature:
    # (model, modules_to_fuse, is_qat, inplace, fuser_func, fuse_custom_config_dict)
    return _fuse_modules(model, modules_to_fuse, True, inplace, fuser_func, fuse_custom_config_dict)