import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process

# TODO remove this once BC is no longer required to avoid a SEV
from torch.ao.quantization.observer import (  # noqa: F401
    _is_activation_post_process as is_activation_post_process
)

__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}

def get_default_custom_config_dict():
    r"""Defines the default custom config dict.
    """
    return _DEFAULT_CUSTOM_CONFIG_DICT

def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                        this config when there is no specified config for current
                        module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to child if child is non traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )

def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
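
    Example (a minimal sketch; `MyModel` is a hypothetical float module)::

        model = MyModel()
        # a qconfig assigned on the root propagates to every submodule
        model.qconfig = torch.ao.quantization.default_qconfig
        torch.ao.quantization.propagate_qconfig_(model)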
""" | |
if qconfig_dict is None: | |
qconfig_dict = {} | |
if prepare_custom_config_dict is None: | |
prepare_custom_config_dict = {} | |
_propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict) | |

def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)

def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input
    """
    return self.activation_post_process(input[0])

def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        module.register_forward_hook(
            _observer_forward_hook, prepend=True
        )

def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function inserts observer modules into every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        non_leaf_module_list: list of non-leaf modules we want to add observer
        device: parent device, if any

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert len(devices) <= 1, (
            f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        # TODO remove Dropout special after codebase stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)):
            if needs_observation(child):
                assert hasattr(child, "activation_post_process"), (
                    f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                )
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            # Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
            and type_before_parametrizations(module) in qconfig_propagation_list:
        insert_activation_post_process(module)

def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}

def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
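
    Example (a minimal sketch; `model` stands for any float module whose
    leaves carry a qconfig)::

        model.qconfig = torch.ao.quantization.default_qconfig
        # leaf submodules with a qconfig get wrapped in QuantWrapper,
        # which adds QuantStub/DeQuantStub around their forward
        model = torch.ao.quantization.add_quant_dequant(model)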
""" | |
if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig: | |
return QuantWrapper(module) | |
for name, child in module.named_children(): | |
module._modules[name] = add_quant_dequant(child) | |
return module | |

def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }
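
    A typical calibration flow (a minimal sketch; `MyModel` and
    `calibration_data` are hypothetical)::

        model = MyModel().eval()
        model.qconfig = torch.ao.quantization.default_qconfig
        prepared = torch.ao.quantization.prepare(model)
        # run representative inputs through the observers
        for data in calibration_data:
            prepared(data)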
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.prepare") | |
if prepare_custom_config_dict is None: | |
prepare_custom_config_dict = get_default_custom_config_dict() | |
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) | |
if not inplace: | |
model = copy.deepcopy(model) | |
# TODO: remove allow_list | |
qconfig_propagation_list = allow_list | |
if allow_list is None: | |
qconfig_propagation_list = get_default_qconfig_propagation_list() | |
propagate_qconfig_(model, qconfig_dict=None) | |
# sanity check common API misusage | |
if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()): | |
warnings.warn("None of the submodule got qconfig applied. Make sure you " | |
"passed correct configuration through `qconfig_dict` or " | |
"by assigning the `.qconfig` attribute directly on submodules") | |
_add_observer_( | |
model, qconfig_propagation_list, observer_non_leaf_module_list, | |
custom_module_class_mapping=custom_module_class_mapping) | |
return model | |

def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
            _is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)

# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)

def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
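
    Example (a minimal sketch; `MyModel`, `calibrate` and `data_loader`
    are hypothetical)::

        def calibrate(model, data_loader):
            for data in data_loader:
                model(data)

        float_model = MyModel().eval()
        float_model.qconfig = torch.ao.quantization.default_qconfig
        quantized = torch.ao.quantization.quantize(
            float_model, run_fn=calibrate, run_args=[data_loader])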
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.quantize") | |
if mapping is None: | |
mapping = get_default_static_quant_module_mappings() | |
if not inplace: | |
model = copy.deepcopy(model) | |
model.eval() | |
prepare(model, inplace=True) | |
run_fn(model, *run_args) | |
convert(model, mapping, inplace=True) | |
return model | |

def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig_spec` and `mapping` that act similarly to `quantize()`.
    If `qconfig_spec` is provided as a dictionary of QConfig instances, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced
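
    Example (a minimal sketch; `MyModel` is a hypothetical float model)::

        # int8 weight-only quantization of every nn.Linear submodule
        quantized = torch.ao.quantization.quantize_dynamic(
            MyModel().eval(), {nn.Linear}, dtype=torch.qint8)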
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") | |
if qconfig_spec is None: | |
if dtype == torch.qint8: | |
qconfig_spec = { | |
nn.Linear : default_dynamic_qconfig, | |
nn.LSTM : default_dynamic_qconfig, | |
nn.GRU : default_dynamic_qconfig, | |
nn.LSTMCell : default_dynamic_qconfig, | |
nn.RNNCell : default_dynamic_qconfig, | |
nn.GRUCell : default_dynamic_qconfig, | |
} | |
elif dtype == torch.float16: | |
qconfig_spec = { | |
nn.Linear : float16_dynamic_qconfig, | |
nn.LSTM : float16_dynamic_qconfig, | |
nn.GRU : float16_dynamic_qconfig, | |
nn.LSTMCell : float16_dynamic_qconfig, | |
nn.RNNCell : float16_dynamic_qconfig, | |
nn.GRUCell : float16_dynamic_qconfig, | |
} | |
elif dtype == torch.quint8: | |
qconfig_spec = { | |
nn.EmbeddingBag : float_qparams_weight_only_qconfig, | |
nn.Embedding : float_qparams_weight_only_qconfig, | |
} | |
elif dtype == torch.quint4x2: | |
qconfig_spec = { | |
nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit, | |
} | |
else: | |
raise ValueError( | |
f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please") | |
elif isinstance(qconfig_spec, set): | |
if dtype is torch.qint8: | |
default_qconfig = default_dynamic_qconfig | |
elif dtype is torch.float16: | |
default_qconfig = float16_dynamic_qconfig | |
elif dtype is torch.quint8: | |
default_qconfig = float_qparams_weight_only_qconfig | |
elif dtype is torch.quint4x2: | |
default_qconfig = float_qparams_weight_only_qconfig_4bit | |
else: | |
raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype)) | |
qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) | |
if mapping is None: | |
mapping = get_default_dynamic_quant_module_mappings() | |
if not inplace: | |
model = copy.deepcopy(model) | |
model.eval() | |
propagate_qconfig_(model, qconfig_spec) | |
convert(model, mapping, inplace=True) | |
return model | |

def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to quantized version.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
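
    Example (a minimal sketch; `MyModel` is a hypothetical float model)::

        model = MyModel().train()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        model = torch.ao.quantization.prepare_qat(model)
        # ... train with fake-quantized modules, then call convert() ...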
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") | |
assert model.training, "prepare_qat only works on models in training mode" | |
if mapping is None: | |
mapping = get_default_qat_module_mappings() | |
if not inplace: | |
model = copy.deepcopy(model) | |
propagate_qconfig_(model, qconfig_dict=None) | |
convert(model, mapping=mapping, inplace=True, remove_qconfig=False) | |
prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) | |
return model | |

def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module
                is mutated

    Return:
        Quantized model.
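
    Example (a minimal sketch; `MyModel`, `train_one_epoch` and
    `train_loader` are hypothetical)::

        model = MyModel()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        quantized = torch.ao.quantization.quantize_qat(
            model, run_fn=train_one_epoch, run_args=[train_loader])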
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") | |
if not inplace: | |
model = copy.deepcopy(model) | |
model.train() | |
prepare_qat(model, inplace=True) | |
run_fn(model, *run_args) | |
convert(model, inplace=True) | |
return model | |

def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class, and removes qconfig at
    the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `remove_qconfig`: whether to remove the qconfig attributes from the
                   module after conversion
        `is_reference`: a flag to enable quantized reference module
        `convert_custom_config_dict`: custom configuration dictionary for convert function

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }
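
    A typical eager-mode flow ends with convert (a minimal sketch, continuing
    from a calibrated `prepared` model)::

        quantized = torch.ao.quantization.convert(prepared)
        # observed modules are now swapped for their quantized counterparts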
""" | |
torch._C._log_api_usage_once("quantization_api.quantize.convert") | |
if not inplace: | |
module = copy.deepcopy(module) | |
_convert( | |
module, mapping, inplace=True, is_reference=is_reference, | |
convert_custom_config_dict=convert_custom_config_dict) | |
if remove_qconfig: | |
_remove_qconfig(module) | |
return module | |

def _convert(
        module, mapping=None, inplace=False,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module
    """
    if mapping is None:
        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
                type_before_parametrizations(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     is_reference, convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module

def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module
        custom_module_class_mapping: a dictionary that maps observed custom
            module classes to their quantized counterparts

    Return:
        The corresponding quantized module of `mod`
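
    Example (a minimal sketch; `observed_linear` is a calibrated, observed
    `nn.Linear` with a qconfig)::

        qlinear = torch.ao.quantization.swap_module(
            observed_linear,
            torch.ao.quantization.get_default_static_quant_module_mappings(),
            {})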
""" | |
new_mod = mod | |
if hasattr(mod, 'qconfig') and mod.qconfig is not None: | |
swapped = False | |
if type_before_parametrizations(mod) in custom_module_class_mapping: | |
new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod) | |
swapped = True | |
elif type_before_parametrizations(mod) in mapping: | |
qmod = mapping[type_before_parametrizations(mod)] | |
if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE: | |
assert mod.qconfig is not None | |
weight_post_process = mod.qconfig.weight() | |
weight_post_process(mod.weight) | |
weight_qparams = get_qparam_dict(weight_post_process) | |
new_mod = qmod.from_float(mod, weight_qparams) | |
else: | |
new_mod = qmod.from_float(mod) | |
swapped = True | |
if swapped: | |
# Preserve module's pre forward hooks. They'll be called on quantized input | |
for pre_hook_fn in mod._forward_pre_hooks.values(): | |
new_mod.register_forward_pre_hook(pre_hook_fn) | |
# Preserve module's post forward hooks except _observer_forward_hook | |
# After convert they'll work with quantized output | |
for hook_fn in mod._forward_hooks.values(): | |
if hook_fn is not _observer_forward_hook: | |
new_mod.register_forward_hook(hook_fn) | |
# respect device affinity when swapping modules | |
devices = _get_unique_devices_(mod) | |
assert len(devices) <= 1, ( | |
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" | |
) | |
device = next(iter(devices)) if len(devices) > 0 else None | |
if device: | |
new_mod.to(device) | |
return new_mod | |

def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debug

    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
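
    A minimal sketch, on a `prepared` (observed) model::

        observers = {}
        _get_observer_dict(prepared, observers)
        # keys look like 'layer1.activation_post_process'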
""" | |
def get_prefix(prefix): | |
return prefix if prefix == "" else prefix + '.' | |
if hasattr(mod, 'activation_post_process'): | |
target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process | |
for name, child in mod.named_children(): | |
module_prefix = get_prefix(prefix) + name if prefix else name | |
_get_observer_dict(child, target_dict, module_prefix) | |