import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quant_type import QuantType
from torch.jit._recursive import wrap_cpp_module

__all__ = [
    "script_qconfig",
    "script_qconfig_dict",
    "fuse_conv_bn_jit",
    "prepare_jit",
    "prepare_dynamic_jit",
    "convert_jit",
    "convert_dynamic_jit",
    "quantize_jit",
    "quantize_dynamic_jit",
]


def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError('input must be a script module, got: ' + str(type(model)))


def _check_forward_method(model):
    if not model._c._has_method('forward'):
        raise ValueError('input script module does not have forward method')


def script_qconfig(qconfig):
    r"""Instantiate the activation and weight observer modules and script
    them; these observer module instances will be deep-copied during the
    `prepare_jit` step.
    """
    return QConfig(
        activation=torch.jit.script(qconfig.activation())._c,
        weight=torch.jit.script(qconfig.weight())._c)
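

# Hedged usage sketch (illustrative, not part of the module's API): shows how
# `script_qconfig` might be driven with a default qconfig. Assumes the standard
# `get_default_qconfig` helper; the returned QConfig holds scripted observer
# instances rather than observer factories.
def _example_script_qconfig():
    from torch.ao.quantization import get_default_qconfig
    qconfig = get_default_qconfig('fbgemm')
    # activation/weight fields are now scripted observer modules
    return script_qconfig(qconfig)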


def script_qconfig_dict(qconfig_dict):
    r"""Helper function used by `prepare_jit`.
    Apply `script_qconfig` to all entries in `qconfig_dict` that are
    not None.
    """
    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
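

# Hedged usage sketch: building the scripted qconfig dict that `prepare_jit`
# consumes. The empty-string key is the global default and a None value skips
# a submodule; the submodule name 'fc' is hypothetical.
def _example_script_qconfig_dict():
    from torch.ao.quantization import get_default_qconfig
    qconfig_dict = {'': get_default_qconfig('fbgemm'), 'fc': None}
    return script_qconfig_dict(qconfig_dict)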


def fuse_conv_bn_jit(model, inplace=False):
    r"""Fuse conv-bn module pairs.
    Works for eval models only.
    Args:
        model: TorchScript model from scripting or tracing
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit")
    model_c = model._c
    model_c = torch._C._jit_pass_fold_convbn(model_c)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model
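

# Hedged usage sketch: fusing conv-bn pairs in a scripted eval-mode model.
# Fusion folds the batch-norm statistics into the preceding conv's weights;
# `float_model` and `example_input` are assumed placeholders.
def _example_fuse_conv_bn(float_model, example_input):
    ts_model = torch.jit.trace(float_model.eval(), example_input)
    return fuse_conv_bn_jit(ts_model)  # returns a new module since inplace=False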


def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    _check_forward_method(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names (str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observers(model._c,
                                                  'forward',
                                                  scripted_qconfig_dict,
                                                  inplace,
                                                  quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def _prepare_ondevice_jit(model, qconfig_dict, method_name='forward', inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names (str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    method_graph = model._c._get_method(method_name).graph
    torch._C._jit_pass_inline(method_graph)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(model._c,
                                                                         method_name,
                                                                         scripted_qconfig_dict,
                                                                         inplace,
                                                                         quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def prepare_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)


def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
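

# Hedged usage sketch of the manual prepare step: `prepare_jit` inserts
# observers according to qconfig_dict (after conv-bn fusion), and running
# calibration data through the prepared model populates them. `float_model`
# and `calibration_batch` are assumed placeholders.
def _example_prepare(float_model, calibration_batch):
    from torch.ao.quantization import get_default_qconfig
    ts_model = torch.jit.script(float_model.eval())
    prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
    prepared(calibration_batch)  # observers record activation statistics
    return prepared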


def _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    return _prepare_ondevice_jit(model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC)


def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
                 preserved_attrs=None):
    _check_is_script_module(model)
    model.eval()
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
    if not debug:
        is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
        if not is_xpu:
            # Moving model parameters to CPU since quantized operators
            # are only supported on CPU and XPU right now
            model.cpu()
        if preserved_attrs is None:
            preserved_attrs = []
        model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    assert quant_type == QuantType.DYNAMIC, "This API, while it should work for static quant, is only tested for dynamic quant."
    assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
    observe_method_name = "observe_" + method_name
    quantize_method_name = "quantize_" + method_name
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
    model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(model_c, QuantType.DYNAMIC, quantize_method_name)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)


def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)
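

# Hedged usage sketch of the manual convert step, completing the
# prepare -> calibrate -> convert flow. `prepared_model` is assumed to come
# from `prepare_jit` (or `prepare_dynamic_jit`) and, for the static case, to
# have already been run on calibration data.
def _example_convert(prepared_model, dynamic=False):
    if dynamic:
        return convert_dynamic_jit(prepared_model)
    return convert_jit(prepared_model)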


def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
    return model


def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
    # Always convert in place because the model has already been
    # copied in prepare_jit when inplace is False
    if quant_type == QuantType.DYNAMIC:
        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
        model = convert_dynamic_jit(model, True, debug)
    else:
        assert run_fn, "Must provide calibration function for post training static quantization"
        assert run_args, "Must provide calibration dataset for post training static quantization"
        model = prepare_jit(model, qconfig_dict, inplace)
        run_fn(model, *run_args)
        model = convert_jit(model, True, debug)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training static quantization.
    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, and after that we
    convert the model to a quantized model.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as keys and
            the qconfig for that module as the value; the empty key means the qconfig will be
            applied to the whole model unless it is overridden by a more specific configuration;
            the qconfig for each module is either found in the dictionary or falls back to
            the qconfig of the parent module.

            Right now qconfig_dict is the only way to configure how the model is quantized,
            and it is done at the granularity of modules, that is, we only support one type
            of qconfig for each torch.nn.Module; the qconfig for a sub module will
            override the qconfig for the parent module, and the empty string means global
            configuration.
        `run_fn`: a calibration function for calibrating the prepared model
        `run_args`: positional arguments for `run_fn`
        `inplace`: carry out model transformations in-place; the original module is
            mutated
        `debug`: flag for producing a debug friendly model (preserves weight attributes)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization import quantize_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)
    quantized_model = quantize_jit(
        ts_model,
        {'': qconfig},
        calibrate,
        [data_loader_test])
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
    return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)


def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as keys and
            the qconfig for that module as the value, please see detailed
            descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `inplace`: carry out model transformations in-place; the original module is
            mutated
        `debug`: flag for producing a debug friendly model (preserves weight attributes)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization import quantize_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = per_channel_dynamic_qconfig
    quantized_model = quantize_dynamic_jit(
        ts_model,
        {'': qconfig})
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
    return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    r"""Prepares the input float TorchScript model for
    *on-device* post training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as keys and
            the qconfig for that module as the value, please see detailed
            descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `method_name`: name of the method within the model to be prepared for quantization
        `inplace`: carry out model transformations in-place; the original module is
            mutated

    Return:
        TorchScript model that is ready for on-device quantization.
        This means that the returned model has:
        - The method inlined.
        - Observer modules inserted in the model.
        - Packed params inserted in the model. However they are empty, meaning they
          don't contain valid quantized weights.
        - observe_<method_name> added, which observes the values to be quantized.
        - reset_observers_<method_name> added, which resets the observers.
        - quantize_<method_name> added to the model. This method:
          - Extracts scales and zero points.
          - Quantizes the observed weights.
          - Creates packed params from them and updates the model's attributes with
            the new values for the packed params.
          - Resets the original fp32 weights to an empty tensor using SetAttr.
        - quantized_<method_name> added to the model. This method:
          - Uses quantized weights and quantized linear ops instead of fp32 ops.
          - Should be used for inference post PTQ.
        - Note that all added methods' signatures are the same as method_name's.

        Later on device:
        - Run reset_observers_<method_name>
        - Run observe_<method_name>
        - Run quantize_<method_name>
        - Now the model can be saved and loaded later.
        - Run the model with quantized_<method_name>

    Example:
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = per_channel_dynamic_qconfig
    quant_ready_model = _quantize_ondevice_dynamic_jit(
        ts_model,
        {'': qconfig},
        'forward',
        True)
    ```
    """
    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
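

# Hedged on-device usage sketch: after `_quantize_ondevice_dynamic_jit`, the
# returned model carries reset_observers_/observe_/quantize_/quantized_
# variants of the original method (see the docstring above). Names below
# assume method_name='forward'; `quant_ready_model` and `sample_input` are
# hypothetical placeholders, and the exact signatures of the generated
# methods follow the docstring's description.
def _example_ondevice_flow(quant_ready_model, sample_input):
    quant_ready_model.reset_observers_forward()
    quant_ready_model.observe_forward(sample_input)
    quant_ready_model.quantize_forward(sample_input)
    # Post-PTQ inference goes through the quantized method:
    return quant_ready_model.quantized_forward(sample_input)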