""" | |
This module implements observers which are used to collect statistics about | |
the values observed during calibration (PTQ) or training (QAT). | |
""" | |
import re | |
import warnings | |
from abc import ABCMeta, abstractmethod | |
from collections import OrderedDict | |
from functools import partial | |
from typing import Any, List, Tuple, Optional, Dict | |
import torch | |
import torch.nn as nn | |
from torch.ao.quantization.utils import ( | |
check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel, validate_qmin_qmax) | |
__all__ = [ | |
"default_affine_fixed_qparams_observer", | |
"default_debug_observer", | |
"default_dynamic_quant_observer", | |
"default_fixed_qparams_range_0to1_observer", | |
"default_fixed_qparams_range_neg1to1_observer", | |
"default_float_qparams_observer", | |
"default_float_qparams_observer_4bit", | |
"default_histogram_observer", | |
"default_observer", | |
"default_per_channel_weight_observer", | |
"default_placeholder_observer", | |
"default_reuse_input_observer", | |
"default_symmetric_fixed_qparams_observer", | |
"default_weight_observer", | |
"get_observer_state_dict", | |
"load_observer_state_dict", | |
"per_channel_weight_observer_range_neg_127_to_127", | |
"weight_observer_range_neg_127_to_127", | |
"FixedQParamsObserver", | |
"HistogramObserver", | |
"MinMaxObserver", | |
"MovingAverageMinMaxObserver", | |
"MovingAveragePerChannelMinMaxObserver", | |
"NoopObserver", | |
"ObserverBase", | |
"PerChannelMinMaxObserver", | |
"PlaceholderObserver", | |
"RecordingObserver", | |
"ReuseInputObserver", | |
"UniformQuantizationObserverBase", | |
] | |
class _PartialWrapper: | |
def __init__(self, p): | |
self.p = p | |
self.callable_args = {} | |
def __call__(self, *args, **keywords): | |
        # call each arg in callable_args and add the result to keywords, then run with keywords
        # skip if arg_name is already in keywords so it's possible to overwrite
for arg_name in self.callable_args: | |
if arg_name not in keywords: | |
keywords = {**keywords, arg_name: self.callable_args[arg_name]()} | |
return self.p(*args, **keywords) | |
def __repr__(self): | |
return self.p.__repr__() + self.callable_args.__repr__() | |
def with_args(self, **kwargs): | |
return _with_args(self, **kwargs) | |
def with_callable_args(self, **kwargs): | |
result = _PartialWrapper(p=self.p) | |
result.callable_args = {**self.callable_args, **kwargs} | |
return result | |
def _with_args(cls_or_self, **kwargs): | |
r"""Wrapper that allows creation of class factories. | |
    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances. Can be used in conjunction with
    _with_callable_args.
Example:: | |
>>> # xdoctest: +SKIP("Undefined vars") | |
>>> Foo.with_args = classmethod(_with_args) | |
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) | |
>>> foo_instance1 = foo_builder() | |
>>> foo_instance2 = foo_builder() | |
>>> id(foo_instance1) == id(foo_instance2) | |
False | |
""" | |
r = _PartialWrapper(partial(cls_or_self, **kwargs)) | |
return r | |
def _with_callable_args(cls_or_self, **kwargs): | |
r"""Wrapper that allows creation of class factories args that need to be | |
called at construction time. | |
This can be useful when there is a need to create classes with the same | |
constructor arguments, but different instances and those arguments should only | |
be calculated at construction time. Can be used in conjunction with _with_args | |
Example:: | |
>>> # xdoctest: +SKIP("Undefined vars") | |
>>> Foo.with_callable_args = classmethod(_with_callable_args) | |
>>> Foo.with_args = classmethod(_with_args) | |
>>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") | |
>>> foo_instance1 = foo_builder() | |
>>> # wait 50 | |
>>> foo_instance2 = foo_builder() | |
>>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) | |
False | |
""" | |
r = _PartialWrapper(partial(cls_or_self)) | |
return r.with_callable_args(**kwargs) | |
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: | |
class ObserverBase(ABC, nn.Module): | |
r"""Base observer Module. | |
Any observer implementation should derive from this class. | |
Concrete observers should follow the same API. In forward, they will update | |
the statistics of the observed Tensor. And they should provide a | |
`calculate_qparams` function that computes the quantization parameters given | |
the collected statistics. | |
Args: | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization | |
or static quantization | |
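
    A minimal sketch of a custom observer subclass (illustrative only; the
    statistic tracking and returned qparams here are placeholders)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> class PassThroughObserver(ObserverBase):
        ...     def __init__(self):
        ...         super().__init__(dtype=torch.quint8)
        ...     def forward(self, x):
        ...         # update running statistics of x here
        ...         return x
        ...     def calculate_qparams(self):
        ...         return torch.tensor([1.0]), torch.tensor([0])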
""" | |
def __init__(self, dtype, is_dynamic=False): | |
super().__init__() | |
self.dtype = dtype | |
self.is_dynamic = is_dynamic | |
def forward(self, x): | |
pass | |
def calculate_qparams(self, **kwargs): | |
pass | |
with_args = classmethod(_with_args) | |
with_callable_args = classmethod(_with_callable_args) | |
class UniformQuantizationObserverBase(ObserverBase): | |
r"""Common base for all observers using uniform quantization to calculate | |
scale and zero_point. | |
Args: | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
qscheme: Quantization scheme to be used. | |
reduce_range: Reduces the range of the quantized data type by 1 bit. | |
This is sometimes required to avoid instruction overflow. | |
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | |
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
.. warning:: | |
        :attr:`dtype` can only take ``torch.qint8``, ``torch.quint8``,
        ``torch.int8``, or ``torch.uint8``.
.. warning:: | |
:attr:`qscheme` can only take one of the following options: | |
- ``torch.per_tensor_affine`` | |
- ``torch.per_tensor_symmetric`` | |
- ``torch.per_channel_affine`` | |
- ``torch.per_channel_symmetric`` | |
""" | |
# Note: the version is shared by all observer types | |
# | |
# Version 1/None | |
# self | |
# | |
# Version 2 (base class only, does not include child class buffers) | |
# self | |
# |--- eps : Tensor | |
# | |
# Version 3 | |
# for HistogramObserver only, changed the shape of uninitialized | |
# min_val and max_val buffers from torch.Size([0]) to torch.Size([]) | |
# for PerChannelObservers, changed the name of the buffers from min_vals | |
# to min_val and from max_vals to max_val. | |
_version = 3 | |
eps: torch.Tensor | |
def __init__( | |
self, | |
dtype=torch.quint8, | |
qscheme=torch.per_tensor_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
factory_kwargs=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs, | |
) -> None: | |
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | |
super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) | |
self.qscheme = qscheme | |
if reduce_range: | |
warnings.warn( | |
"Please use quant_min and quant_max to specify the range for observers. \ | |
reduce_range will be deprecated in a future release of PyTorch." | |
) | |
self.reduce_range = reduce_range | |
self.register_buffer( | |
"eps", torch.tensor([eps], **factory_kwargs) | |
) | |
assert self.qscheme in ( | |
torch.per_tensor_affine, | |
torch.per_tensor_symmetric, | |
torch.per_channel_affine, | |
torch.per_channel_symmetric, | |
torch.per_channel_affine_float_qparams, | |
), "Default Observer only works for per_tensor_affine, \ | |
per_tensor_symmetric, per_channel_affine, \ | |
per_channel_symmetric and per_channel_float_qparams quantization scheme" | |
_ALLOWED_DTYPES = ( | |
torch.qint8, | |
torch.quint8, | |
torch.quint4x2, | |
torch.qint32, | |
torch.int8, | |
torch.uint8, | |
torch.int16, | |
torch.int32, | |
) | |
assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type" | |
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) | |
if self.has_customized_qrange: | |
validate_qmin_qmax(quant_min, quant_max) | |
self.quant_min, self.quant_max = \ | |
calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range) | |
def _load_from_state_dict( | |
self, | |
state_dict, | |
prefix, | |
local_metadata, | |
strict, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
): | |
version = local_metadata.get("version", None) | |
if version is None or version == 1: | |
# eps was moved to a buffer in version 2 | |
eps = torch.tensor([torch.finfo(torch.float32).eps]) | |
state_dict[prefix + "eps"] = eps | |
super()._load_from_state_dict( | |
state_dict, | |
prefix, | |
local_metadata, | |
strict, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
) | |
def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: | |
r"""Validates that the user-specified quantization range is properly initialized | |
and within the given bound supported by the observer dtype. | |
To accommodate lower-bit quantization with respect to the existing torch.qint8 and | |
torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing | |
in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax | |
values are used to calculate static estimates of the scale and zero point for aggressive lower-bit | |
fake quantization. These estimates are compared against parameters learned through backpropagation. | |
        The related literature on learning scale and zero point via backpropagation is as follows:
Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS | |
Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf | |
""" | |
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted | |
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. | |
assert ( | |
quant_min <= 0 <= quant_max | |
), "Used-specified quantization range must include 0." | |
assert ( | |
quant_min < quant_max | |
), "qmin must be strictly less than qmax for user-specified quantization range." | |
def _calculate_qparams( | |
self, min_val: torch.Tensor, max_val: torch.Tensor | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
r"""Calculates the quantization parameters, given min and max | |
value tensors. Works for both per tensor and per channel cases | |
Args: | |
min_val: Minimum values per channel | |
max_val: Maximum values per channel | |
Returns: | |
scales: Scales tensor of shape (#channels,) | |
zero_points: Zero points tensor of shape (#channels,) | |
""" | |
        # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable, however, and qscheme,
        # as far as I can tell, is not allowed to be passed as a parameter in torchscript functions. This makes refactoring the
        # observer to use that utility a massive pain. For now I'm opting to duplicate the code, as it
        # seems unlikely to change (last update over 1 year ago); when torchscript is fully deprecated we can refactor.
# TODO(jakeszwe, jerryzh168) | |
if not check_min_max_valid(min_val, max_val): | |
return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type) | |
quant_min, quant_max = self.quant_min, self.quant_max | |
min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) | |
max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) | |
device = min_val_neg.device | |
scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) | |
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) | |
if ( | |
self.qscheme == torch.per_tensor_symmetric | |
or self.qscheme == torch.per_channel_symmetric | |
): | |
max_val_pos = torch.max(-min_val_neg, max_val_pos) | |
scale = max_val_pos / (float(quant_max - quant_min) / 2) | |
scale = torch.max(scale, self.eps) | |
if self.dtype in [torch.quint8, torch.uint8]: | |
if self.has_customized_qrange: | |
# When customized quantization range is used, down-rounded midpoint of the range is chosen. | |
zero_point = zero_point.new_full( | |
zero_point.size(), (quant_min + quant_max) // 2 | |
) | |
else: | |
zero_point = zero_point.new_full(zero_point.size(), 128) | |
elif self.qscheme == torch.per_channel_affine_float_qparams: | |
scale = (max_val - min_val) / float(quant_max - quant_min) | |
scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) | |
            # We use the quantize function
            # Xq = Round(Xf * inv_scale + zero_point),
            # setting zero_point to (-1 * min * inv_scale) we get
            # Xq = Round((Xf - min) * inv_scale)
zero_point = -1 * min_val / scale | |
else: | |
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) | |
scale = torch.max(scale, self.eps) | |
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) | |
zero_point = torch.clamp(zero_point, quant_min, quant_max) | |
# For scalar values, cast them to Tensors of size 1 to keep the shape | |
# consistent with default values in FakeQuantize. | |
if len(scale.shape) == 0: | |
# TODO: switch to scale.item() after adding JIT support | |
scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) | |
if len(zero_point.shape) == 0: | |
# TODO: switch to zero_point.item() after adding JIT support | |
zero_point = torch.tensor( | |
[int(zero_point)], dtype=zero_point.dtype, device=device | |
) | |
if self.qscheme == torch.per_channel_affine_float_qparams: | |
zero_point = torch.tensor( | |
[float(zero_point)], dtype=zero_point.dtype, device=device | |
) | |
return scale, zero_point | |
def reset_min_max_vals(self): | |
raise NotImplementedError("Cannot reset min/max values in the given observer.") | |
# Originally, this class was called `_ObserverBase`. Keeping the old name around | |
# for backwards compatibility. | |
# TODO(after v1.13): delete this | |
_ObserverBase = UniformQuantizationObserverBase | |
class MinMaxObserver(UniformQuantizationObserverBase): | |
r"""Observer module for computing the quantization parameters based on the | |
running min and max values. | |
This observer uses the tensor min/max statistics to compute the quantization | |
parameters. The module records the running minimum and maximum of incoming | |
tensors, and uses this statistic to compute the quantization parameters. | |
Args: | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | |
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, | |
    scale :math:`s` and zero point :math:`z` are computed as follows.
The running minimum/maximum :math:`x_\text{min/max}` is computed as: | |
.. math:: | |
\begin{array}{ll} | |
x_\text{min} &= \begin{cases} | |
\min(X) & \text{if~}x_\text{min} = \text{None} \\ | |
\min\left(x_\text{min}, \min(X)\right) & \text{otherwise} | |
\end{cases}\\ | |
x_\text{max} &= \begin{cases} | |
\max(X) & \text{if~}x_\text{max} = \text{None} \\ | |
\max\left(x_\text{max}, \max(X)\right) & \text{otherwise} | |
\end{cases}\\ | |
\end{array} | |
where :math:`X` is the observed tensor. | |
The scale :math:`s` and zero point :math:`z` are then computed as: | |
.. math:: | |
\begin{aligned} | |
\text{if Symmetric:}&\\ | |
&s = 2 \max(|x_\text{min}|, x_\text{max}) / | |
\left( Q_\text{max} - Q_\text{min} \right) \\ | |
&z = \begin{cases} | |
0 & \text{if dtype is qint8} \\ | |
128 & \text{otherwise} | |
\end{cases}\\ | |
\text{Otherwise:}&\\ | |
&s = \left( x_\text{max} - x_\text{min} \right ) / | |
\left( Q_\text{max} - Q_\text{min} \right ) \\ | |
&z = Q_\text{min} - \text{round}(x_\text{min} / s) | |
\end{aligned} | |
where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and | |
maximum of the quantized data type. | |
.. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. | |
    .. note:: If the running minimum equals the running maximum, the scale
and zero_point are set to 1.0 and 0. | |
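
    A minimal usage sketch (illustrative; the actual scale/zero_point values
    depend on the observed data)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
        >>> _ = obs(torch.tensor([-1.0, 0.0, 2.0]))
        >>> scale, zero_point = obs.calculate_qparams()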
""" | |
min_val: torch.Tensor | |
max_val: torch.Tensor | |
def __init__( | |
self, | |
dtype=torch.quint8, | |
qscheme=torch.per_tensor_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
factory_kwargs=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs, | |
) -> None: | |
if not is_per_tensor(qscheme): | |
raise NotImplementedError( | |
"MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ | |
and torch.per_tensor_affine." | |
) | |
# TODO: MinMaxObserver by itself doesn't support dynamic quantization, but | |
# if it's inherited by MovingAverageObserver, and averaging_constant is 1, it | |
        # supports dynamic quantization; we may need better error checking here
# For x86 quantized kernels, we need to ensure that the vpmaddubsw | |
# instruction does not overflow. We allow for a reduce_range argument to | |
# observers that reduces the quantized range to (0,127) or (-64, 63). | |
# For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp | |
# This is not an optimal choice for non x86 backends as it loses a bit | |
# of precision for activations. | |
super().__init__( | |
dtype=dtype, | |
qscheme=qscheme, | |
reduce_range=reduce_range, | |
quant_min=quant_min, | |
quant_max=quant_max, | |
factory_kwargs=factory_kwargs, | |
eps=eps, | |
is_dynamic=is_dynamic, | |
**kwargs, | |
) | |
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | |
self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) | |
self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) | |
if ( | |
self.qscheme == torch.per_tensor_symmetric | |
and self.reduce_range | |
and self.dtype == torch.quint8 | |
): | |
raise NotImplementedError( | |
"Cannot reduce range for symmetric \ | |
quantization for quint8" | |
) | |
def forward(self, x_orig): | |
r"""Records the running minimum and maximum of ``x``.""" | |
if x_orig.numel() == 0: | |
return x_orig | |
x = x_orig.detach() # avoid keeping autograd tape | |
x = x.to(self.min_val.dtype) | |
min_val_cur, max_val_cur = torch.aminmax(x) | |
min_val = torch.min(min_val_cur, self.min_val) | |
max_val = torch.max(max_val_cur, self.max_val) | |
self.min_val.copy_(min_val) | |
self.max_val.copy_(max_val) | |
return x_orig | |
def calculate_qparams(self): | |
r"""Calculates the quantization parameters.""" | |
return self._calculate_qparams(self.min_val, self.max_val) | |
def extra_repr(self): | |
return f"min_val={self.min_val}, max_val={self.max_val}" | |
def reset_min_max_vals(self): | |
"""Resets the min/max values.""" | |
self.min_val.copy_(torch.tensor(float("inf"))) | |
self.max_val.copy_(torch.tensor(float("-inf"))) | |
class MovingAverageMinMaxObserver(MinMaxObserver): | |
r"""Observer module for computing the quantization parameters based on the | |
moving average of the min and max values. | |
This observer computes the quantization parameters based on the moving | |
averages of minimums and maximums of the incoming tensors. The module | |
records the average minimum and maximum of incoming tensors, and uses this | |
statistic to compute the quantization parameters. | |
Args: | |
averaging_constant: Averaging constant for min/max. | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | |
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
The moving average min/max is computed as follows | |
.. math:: | |
\begin{array}{ll} | |
x_\text{min} = \begin{cases} | |
\min(X) & \text{if~}x_\text{min} = \text{None} \\ | |
(1 - c) x_\text{min} + c \min(X) & \text{otherwise} | |
\end{cases}\\ | |
x_\text{max} = \begin{cases} | |
\max(X) & \text{if~}x_\text{max} = \text{None} \\ | |
(1 - c) x_\text{max} + c \max(X) & \text{otherwise} | |
\end{cases}\\ | |
\end{array} | |
    where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
    the incoming tensor, and :math:`c` is the ``averaging_constant``.
The scale and zero point are then computed as in | |
:class:`~torch.ao.quantization.observer.MinMaxObserver`. | |
.. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. | |
    .. note:: If the running minimum equals the running maximum, the scale
and zero_point are set to 1.0 and 0. | |
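
    A minimal usage sketch (illustrative; each call to ``forward`` moves the
    tracked min/max toward the statistics of the latest batch)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = MovingAverageMinMaxObserver(averaging_constant=0.01)
        >>> for batch in [torch.randn(16), torch.randn(16)]:
        ...     _ = obs(batch)
        >>> scale, zero_point = obs.calculate_qparams()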
""" | |
def __init__( | |
self, | |
averaging_constant=0.01, | |
dtype=torch.quint8, | |
qscheme=torch.per_tensor_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs | |
) -> None: | |
if not is_per_tensor(qscheme): | |
raise NotImplementedError( | |
f"MovingAverageMinMaxObserver's qscheme only support \ | |
torch.per_tensor_symmetric and torch.per_tensor_affine. \ | |
but got: {qscheme}" | |
) | |
self.averaging_constant = averaging_constant | |
if is_dynamic and self.averaging_constant != 1: | |
raise NotImplementedError( | |
"MovingAverageMinMaxObserver doesn't support dynamic quantization for " | |
f"averaging constant of {self.averaging_constant}" | |
) | |
super().__init__( | |
dtype=dtype, | |
qscheme=qscheme, | |
reduce_range=reduce_range, | |
quant_min=quant_min, | |
quant_max=quant_max, | |
eps=eps, | |
is_dynamic=is_dynamic, | |
**kwargs | |
) | |
def forward(self, x_orig): | |
if x_orig.numel() == 0: | |
return x_orig | |
x = x_orig.detach() # avoid keeping autograd tape | |
x = x.to(self.min_val.dtype) | |
min_val = self.min_val | |
max_val = self.max_val | |
if min_val == float("inf") and max_val == float("-inf"): | |
min_val, max_val = torch.aminmax(x) | |
else: | |
min_val_cur, max_val_cur = torch.aminmax(x) | |
min_val = min_val + self.averaging_constant * (min_val_cur - min_val) | |
max_val = max_val + self.averaging_constant * (max_val_cur - max_val) | |
self.min_val.copy_(min_val) | |
self.max_val.copy_(max_val) | |
return x_orig | |
class PerChannelMinMaxObserver(UniformQuantizationObserverBase): | |
r"""Observer module for computing the quantization parameters based on the | |
running per channel min and max values. | |
This observer uses the tensor min/max statistics to compute the per channel | |
quantization parameters. The module records the running minimum and maximum | |
of incoming tensors, and uses this statistic to compute the quantization | |
parameters. | |
Args: | |
ch_axis: Channel axis | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | |
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
The quantization parameters are computed the same way as in | |
:class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference | |
that the running min/max values are stored per channel. | |
Scales and zero points are thus computed per channel as well. | |
    .. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0. | |
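
    A minimal usage sketch (illustrative; with ``ch_axis=0`` a weight tensor of
    shape ``(C, ...)`` yields one scale/zero_point pair per output channel)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
        ...                                qscheme=torch.per_channel_symmetric)
        >>> _ = obs(torch.randn(8, 3, 3, 3))  # e.g. a conv weight with 8 output channels
        >>> scales, zero_points = obs.calculate_qparams()  # each has 8 elements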
""" | |
min_val: torch.Tensor | |
max_val: torch.Tensor | |
def __init__( | |
self, | |
ch_axis=0, | |
dtype=torch.quint8, | |
qscheme=torch.per_channel_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
factory_kwargs=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs, | |
) -> None: | |
if not is_per_channel(qscheme): | |
raise NotImplementedError( | |
"PerChannelMinMaxObserver's qscheme only support \ | |
torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." | |
) | |
if is_dynamic: | |
raise NotImplementedError( | |
"PerChannelMinMaxObserver doesn't support dynamic quantization" | |
) | |
super().__init__( | |
dtype=dtype, | |
qscheme=qscheme, | |
reduce_range=reduce_range, | |
quant_min=quant_min, | |
quant_max=quant_max, | |
factory_kwargs=factory_kwargs, | |
eps=eps, | |
is_dynamic=is_dynamic, | |
**kwargs, | |
) | |
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | |
self.ch_axis = ch_axis | |
self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) | |
self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) | |
if ( | |
self.qscheme == torch.per_channel_symmetric | |
and self.reduce_range | |
and self.dtype == torch.quint8 | |
): | |
raise NotImplementedError( | |
"Cannot reduce range for symmetric quantization for quint8" | |
) | |
def forward(self, x_orig): | |
return self._forward(x_orig) | |
def _forward(self, x_orig): | |
if x_orig.numel() == 0: | |
return x_orig | |
x = x_orig.detach() # avoid keeping autograd tape | |
min_val = self.min_val | |
max_val = self.max_val | |
x_dim = x.size() | |
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 | |
new_axis_list[self.ch_axis] = 0 | |
new_axis_list[0] = self.ch_axis | |
y = x.permute(new_axis_list) | |
# Need to match dtype of min/max because the updates to buffers | |
# are done in place and types need to match for comparisons | |
y = y.to(self.min_val.dtype) | |
y = torch.flatten(y, start_dim=1) | |
if min_val.numel() == 0 or max_val.numel() == 0: | |
min_val, max_val = torch.aminmax(y, dim=1) | |
else: | |
min_val_cur, max_val_cur = torch.aminmax(y, dim=1) | |
min_val = torch.min(min_val_cur, min_val) | |
max_val = torch.max(max_val_cur, max_val) | |
self.min_val.resize_(min_val.shape) | |
self.max_val.resize_(max_val.shape) | |
self.min_val.copy_(min_val) | |
self.max_val.copy_(max_val) | |
return x_orig | |
def calculate_qparams(self): | |
return self._calculate_qparams(self.min_val, self.max_val) | |
def extra_repr(self): | |
return f"min_val={self.min_val}, max_val={self.max_val}" | |
def _load_from_state_dict( | |
self, | |
state_dict: Dict[str, Any], | |
prefix: str, | |
local_metadata: Dict[str, torch.Tensor], | |
strict: bool, | |
missing_keys: List[str], | |
unexpected_keys: List[str], | |
error_msgs: List[str], | |
): | |
version = local_metadata.get("version", None) | |
if version is not None and version < 3: | |
local_state = ["min_vals", "max_vals"] | |
expected_min_name = "min_vals" | |
expected_max_name = "max_vals" | |
else: | |
local_state = ["min_val", "max_val"] | |
expected_min_name = "min_val" | |
expected_max_name = "max_val" | |
for name in local_state: | |
key = prefix + name | |
if key in state_dict: | |
val = state_dict[key] | |
# Custom handling to allow loading min_val or max_val | |
# of size N into uninitialized buffers of size 0. The | |
# buffers are resized here, and the values are copied in | |
# the default state_dict loading code of the parent. | |
if name == expected_min_name: | |
self.min_val.resize_(val.shape) | |
elif name == expected_max_name: | |
self.max_val.resize_(val.shape) | |
else: | |
warnings.warn(f"Observer load_from_state_dict got unexpected name {name}") | |
# For torchscript module we need to update the attributes here since we do not | |
# call the `_load_from_state_dict` function defined module.py | |
if torch.jit.is_scripting(): | |
if name == expected_min_name: | |
self.min_val.copy_(val) | |
elif name == expected_max_name: | |
self.max_val.copy_(val) | |
else: | |
warnings.warn(f"Observer load_from_state_dict got unexpected name {name}") | |
elif strict: | |
missing_keys.append(key) | |
if not torch.jit.is_scripting(): | |
super()._load_from_state_dict( | |
state_dict, | |
prefix, | |
local_metadata, | |
False, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
) | |
def _load_from_state_dict_script( | |
self, | |
state_dict: Dict[str, Any], | |
prefix: str, | |
local_metadata: Dict[str, torch.Tensor], | |
strict: bool, | |
missing_keys: List[str], | |
unexpected_keys: List[str], | |
error_msgs: List[str], | |
): | |
self._load_from_state_dict( | |
state_dict, | |
prefix, | |
local_metadata, | |
strict, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
) | |
def reset_min_max_vals(self): | |
"""Resets the min/max values.""" | |
# This used to be torch.ones but that does not work because | |
# JIT compiler can optimize it via common subexpression elimination | |
# in which case both min_val and max_val point to the same tensor. | |
self.min_val = torch.rand(0, ) | |
self.max_val = torch.rand(0, ) | |
class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): | |
r"""Observer module for computing the quantization parameters based on the | |
running per channel min and max values. | |
This observer uses the tensor min/max statistics to compute the per channel | |
quantization parameters. The module records the running minimum and maximum | |
of incoming tensors, and uses this statistic to compute the quantization | |
parameters. | |
Args: | |
averaging_constant: Averaging constant for min/max. | |
ch_axis: Channel axis | |
dtype: Quantized data type | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | |
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
The quantization parameters are computed the same way as in | |
:class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the | |
difference that the running min/max values are stored per channel. | |
Scales and zero points are thus computed per channel as well. | |
    .. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0. | |
""" | |
def __init__( | |
self, | |
averaging_constant=0.01, | |
ch_axis=0, | |
dtype=torch.quint8, | |
qscheme=torch.per_channel_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs | |
) -> None: | |
if not is_per_channel(qscheme): | |
raise NotImplementedError( | |
"MovingAveragePerChannelMinMaxObserver's qscheme only support \ | |
torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." | |
) | |
if is_dynamic: | |
raise NotImplementedError( | |
"MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" | |
) | |
super().__init__( | |
ch_axis=ch_axis, | |
dtype=dtype, | |
qscheme=qscheme, | |
reduce_range=reduce_range, | |
quant_min=quant_min, | |
quant_max=quant_max, | |
eps=eps, | |
is_dynamic=is_dynamic, | |
**kwargs | |
) | |
self.averaging_constant = averaging_constant | |
def forward(self, x_orig): | |
if x_orig.numel() == 0: | |
return x_orig | |
x = x_orig.detach() # avoid keeping autograd tape | |
x = x.to(self.min_val.dtype) | |
min_val = self.min_val | |
max_val = self.max_val | |
x_dim = x.size() | |
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 | |
new_axis_list[self.ch_axis] = 0 | |
new_axis_list[0] = self.ch_axis | |
y = x.permute(new_axis_list) | |
y = torch.flatten(y, start_dim=1) | |
if min_val.numel() == 0 or max_val.numel() == 0: | |
min_val, max_val = torch.aminmax(y, dim=1) | |
else: | |
min_val_cur, max_val_cur = torch.aminmax(y, dim=1) | |
min_val = min_val + self.averaging_constant * (min_val_cur - min_val) | |
max_val = max_val + self.averaging_constant * (max_val_cur - max_val) | |
self.min_val.resize_(min_val.shape) | |
self.max_val.resize_(max_val.shape) | |
self.min_val.copy_(min_val) | |
self.max_val.copy_(max_val) | |
return x_orig | |
class HistogramObserver(UniformQuantizationObserverBase): | |
r""" | |
The module records the running histogram of tensor values along with | |
min/max values. ``calculate_qparams`` will calculate scale and zero_point. | |
Args: | |
bins: Number of bins to use for the histogram | |
        upsample_rate: Factor by which the histograms are upsampled; this is
used to interpolate histograms with varying ranges across observations | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
The scale and zero point are computed as follows: | |
1. Create the histogram of the incoming inputs. | |
The histogram is computed continuously, and the ranges per bin change | |
with every new tensor observed. | |
2. Search the distribution in the histogram for optimal min/max values. | |
The search for the min/max values ensures the minimization of the | |
quantization error with respect to the floating point model. | |
3. Compute the scale and zero point the same way as in the | |
:class:`~torch.ao.quantization.MinMaxObserver` | |
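
    A minimal usage sketch (illustrative; the histogram is refined with every
    observed tensor before the min/max search runs in ``calculate_qparams``)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = HistogramObserver(bins=2048)
        >>> for batch in [torch.randn(1024), torch.randn(1024)]:
        ...     _ = obs(batch)
        >>> scale, zero_point = obs.calculate_qparams()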
""" | |
histogram: torch.Tensor | |
min_val: torch.Tensor | |
max_val: torch.Tensor | |
def __init__( | |
self, | |
bins: int = 2048, | |
upsample_rate: int = 128, | |
dtype: torch.dtype = torch.quint8, | |
qscheme=torch.per_tensor_affine, | |
reduce_range=False, | |
quant_min=None, | |
quant_max=None, | |
factory_kwargs=None, | |
eps=torch.finfo(torch.float32).eps, | |
is_dynamic=False, | |
**kwargs, | |
) -> None: | |
if not is_per_tensor(qscheme): | |
raise NotImplementedError( | |
"HistogramObserver's qscheme only support torch.per_tensor_symmetric \ | |
and torch.per_tensor_affine." | |
) | |
if is_dynamic: | |
raise NotImplementedError( | |
"HistogramObserver doesn't support dynamic quantization" | |
) | |
# bins: The number of bins used for histogram calculation. | |
super().__init__( | |
dtype=dtype, | |
qscheme=qscheme, | |
reduce_range=reduce_range, | |
quant_min=quant_min, | |
quant_max=quant_max, | |
factory_kwargs=factory_kwargs, | |
eps=eps, | |
is_dynamic=is_dynamic, | |
**kwargs | |
) | |
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | |
self.bins = bins | |
self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) | |
self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) | |
self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) | |
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits | |
self.upsample_rate = upsample_rate | |
def _get_norm( | |
self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor | |
) -> torch.Tensor: | |
r""" | |
        Compute the norm of the values uniformly distributed between
delta_begin and delta_end. | |
Currently only L2 norm is supported. | |
norm = density * (integral_{begin, end} x^2) | |
= density * (end^3 - begin^3) / 3 | |
""" | |
norm = ( | |
delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin | |
) / 3 | |
return density * norm | |
def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): | |
r""" | |
Compute the quantization error if we use start_bin to end_bin as the | |
min and max to do the quantization. | |
""" | |
bin_width = (self.max_val.item() - self.min_val.item()) / self.bins | |
dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins | |
if dst_bin_width == 0.0: | |
return 0.0 | |
src_bin = torch.arange(self.bins, device=self.histogram.device) | |
# distances from the beginning of first dst_bin to the beginning and | |
# end of src_bin | |
src_bin_begin = (src_bin - next_start_bin) * bin_width | |
src_bin_end = src_bin_begin + bin_width | |
# which dst_bins the beginning and end of src_bin belong to? | |
dst_bin_of_begin = torch.clamp( | |
torch.div(src_bin_begin, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1 | |
) | |
dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width | |
dst_bin_of_end = torch.clamp( | |
torch.div(src_bin_end, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1 | |
) | |
density = self.histogram / bin_width | |
norm = torch.zeros(self.bins, device=self.histogram.device) | |
delta_begin = src_bin_begin - dst_bin_of_begin_center | |
delta_end = dst_bin_width / 2 | |
norm += self._get_norm(delta_begin, | |
torch.ones(self.bins, device=self.histogram.device) * delta_end, | |
density) | |
norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( | |
torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density | |
) | |
dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 | |
delta_begin = -dst_bin_width / 2 | |
delta_end = src_bin_end - dst_bin_of_end_center | |
norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) | |
return norm.sum().item() | |
def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: | |
r"""Non-linear parameter search. | |
An approximation for L2 error minimization for selecting min/max. | |
By selecting new min/max, we filter out outliers in input distribution. | |
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in | |
caffe2/quantization/server/norm_minimization.cc | |
""" | |
assert self.histogram.size()[0] == self.bins, "bins mismatch" | |
bin_width = (self.max_val - self.min_val) / self.bins | |
# cumulative sum | |
total = torch.sum(self.histogram).item() | |
cSum = torch.cumsum(self.histogram, dim=0) | |
stepsize = 1e-5 # granularity | |
alpha = 0.0 # lower bound | |
beta = 1.0 # upper bound | |
start_bin = 0 | |
end_bin = self.bins - 1 | |
norm_min = float("inf") | |
while alpha < beta: | |
# Find the next step | |
next_alpha = alpha + stepsize | |
next_beta = beta - stepsize | |
# find the left and right bins between the quantile bounds | |
l = start_bin | |
r = end_bin | |
while l < end_bin and cSum[l] < next_alpha * total: | |
l = l + 1 | |
while r > start_bin and cSum[r] > next_beta * total: | |
r = r - 1 | |
# decide the next move | |
next_start_bin = start_bin | |
next_end_bin = end_bin | |
if (l - start_bin) > (end_bin - r): | |
# move the start bin | |
next_start_bin = l | |
alpha = next_alpha | |
else: | |
# move the end bin | |
next_end_bin = r | |
beta = next_beta | |
if next_start_bin == start_bin and next_end_bin == end_bin: | |
continue | |
# calculate the quantization error using next_start_bin and next_end_bin | |
norm = self._compute_quantization_error(next_start_bin, next_end_bin) | |
if norm > norm_min: | |
break | |
norm_min = norm | |
start_bin = next_start_bin | |
end_bin = next_end_bin | |
new_min = self.min_val + bin_width * start_bin | |
new_max = self.min_val + bin_width * (end_bin + 1) | |
return new_min, new_max | |
def _adjust_min_max( | |
self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int | |
) -> Tuple[torch.Tensor, torch.Tensor, int, int]: | |
# We ensure that: | |
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins) | |
# This allows us to have a common grid of resolution s, where we can align | |
# the input histogram | |
# start_idx maps min_val to the histogram bin index. | |
        # Computing the width of the histogram bins directly, as
        # hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate),
        # is the straightforward solution, but it underflows if the numerator is close to the
        # smallest positive subnormal number of FP32. Therefore, we avoid that division.
downsample_rate = int( | |
torch.ceil( | |
((combined_max - combined_min) / (self.max_val - self.min_val)) * upsample_rate | |
).item() | |
) | |
e = downsample_rate / upsample_rate * (self.max_val - self.min_val) - (combined_max - combined_min) | |
start_idx = int( | |
torch.round((self.min_val - combined_min) / (self.max_val - self.min_val) * self.bins * upsample_rate).item() | |
) | |
combined_max = combined_max + e | |
return combined_min, combined_max, downsample_rate, start_idx | |
def _combine_histograms( | |
self, | |
orig_hist: torch.Tensor, | |
new_hist: torch.Tensor, | |
upsample_rate: int, | |
downsample_rate: int, | |
start_idx: int, | |
Nbins: int, | |
) -> torch.Tensor: | |
# First up-sample the histogram with new data by a factor of L | |
        # This creates an approximate probability density that's piecewise constant
upsampled_histogram = new_hist.repeat_interleave(upsample_rate) | |
# Now insert the upsampled histogram into the output | |
# histogram, which is initialized with zeros. | |
# The offset at which the histogram is introduced is determined | |
# by the start index as the output histogram can cover a wider range | |
histogram_with_output_range = torch.zeros( | |
(Nbins * downsample_rate), device=orig_hist.device | |
) | |
histogram_with_output_range[ | |
start_idx : Nbins * upsample_rate + start_idx | |
] = upsampled_histogram | |
# Compute integral histogram, double precision is needed to ensure | |
# that there are no overflows | |
integral_histogram = torch.cumsum( | |
histogram_with_output_range, 0, dtype=torch.double | |
)[downsample_rate - 1 :: downsample_rate] | |
# Finally perform interpolation | |
shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device) | |
shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1] | |
interpolated_histogram = ( | |
integral_histogram - shifted_integral_histogram | |
) / upsample_rate | |
orig_hist = orig_hist + interpolated_histogram.to(torch.float) | |
return orig_hist | |
def forward(self, x_orig: torch.Tensor) -> torch.Tensor: | |
if x_orig.numel() == 0: | |
return x_orig | |
x = x_orig.detach() | |
x_min, x_max = torch.aminmax(x) | |
# want to ignore torch.inf since we don't actually | |
# want to make our quantization range infinite | |
# and in practice those values will be clamped | |
if x_min == -torch.inf or x_max == torch.inf: | |
warnings.warn("torch.inf detected in input tensor, ignoring input") | |
x = x[x.abs() != torch.inf] | |
if x.numel() == 0: | |
return x_orig | |
x_min, x_max = torch.aminmax(x) | |
min_val = self.min_val | |
max_val = self.max_val | |
same_values = min_val.item() == max_val.item() | |
is_uninitialized = min_val == float("inf") and max_val == float("-inf") | |
if is_uninitialized or same_values: | |
min_val, max_val = x_min, x_max | |
self.min_val.resize_(min_val.shape) | |
self.min_val.copy_(min_val) | |
self.max_val.resize_(max_val.shape) | |
self.max_val.copy_(max_val) | |
assert ( | |
min_val.numel() == 1 and max_val.numel() == 1 | |
), "histogram min/max values must be scalar." | |
torch.histc( | |
x, self.bins, min=min_val, max=max_val, out=self.histogram # type: ignore[arg-type] | |
) | |
else: | |
new_min, new_max = x_min, x_max | |
combined_min = torch.min(new_min, min_val) | |
combined_max = torch.max(new_max, max_val) | |
# combine the existing histogram and new histogram into 1 histogram | |
# We do this by first upsampling the histogram to a dense grid | |
# and then downsampling the histogram efficiently | |
( | |
combined_min, | |
combined_max, | |
downsample_rate, | |
start_idx, | |
) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate) | |
assert ( | |
combined_min.numel() == 1 and combined_max.numel() == 1 | |
), "histogram min/max values must be scalar." | |
# TODO: For some reason, this is required for it to pass torchscript test | |
# combined_min and combined_max should already have requires_grad set to False | |
combined_min, combined_max = combined_min.detach(), combined_max.detach() | |
combined_histogram = torch.histc( | |
x, self.bins, min=combined_min, max=combined_max # type: ignore[arg-type] | |
) | |
if combined_min == min_val and combined_max == max_val: | |
combined_histogram += self.histogram | |
else: | |
combined_histogram = self._combine_histograms( | |
combined_histogram, | |
self.histogram, | |
self.upsample_rate, | |
downsample_rate, | |
start_idx, | |
self.bins, | |
) | |
self.histogram.detach_().resize_(combined_histogram.shape) | |
self.histogram.copy_(combined_histogram) | |
self.min_val.detach_().resize_(combined_min.shape) | |
self.min_val.copy_(combined_min) | |
self.max_val.detach_().resize_(combined_max.shape) | |
self.max_val.copy_(combined_max) | |
return x_orig | |
def calculate_qparams(self): | |
is_uninitialized = self.min_val == float("inf") and self.max_val == float( | |
"-inf" | |
) | |
if is_uninitialized: | |
warnings.warn( | |
"must run observer before calling calculate_qparams.\ | |
Returning default scale and zero point " | |
) | |
return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor([0], device=self.min_val.device.type) | |
assert self.bins == len(self.histogram), ( | |
"The number of bins in histogram should be equal to the number of bins " | |
"supplied while making this observer" | |
) | |
new_min, new_max = self._non_linear_param_search() | |
return self._calculate_qparams(new_min, new_max) | |
def _save_to_state_dict(self, destination, prefix, keep_vars): | |
super()._save_to_state_dict(destination, prefix, keep_vars) | |
destination[prefix + "min_val"] = self.min_val | |
destination[prefix + "max_val"] = self.max_val | |
def _load_from_state_dict( | |
self, | |
state_dict, | |
prefix, | |
local_metadata, | |
strict, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
): | |
version = local_metadata.get("version", None) | |
if version is None or version < 3: | |
# if min_val and max_val are not initialized, update their shape | |
# to account for the differences between v2 and v3 | |
min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" | |
if min_val_name in state_dict: | |
if state_dict[min_val_name].shape == torch.Size([0]): | |
state_dict[min_val_name] = torch.tensor(float("inf")) | |
if max_val_name in state_dict: | |
if state_dict[max_val_name].shape == torch.Size([0]): | |
state_dict[max_val_name] = torch.tensor(float("-inf")) | |
local_state = ["min_val", "max_val"] | |
for name in local_state: | |
key = prefix + name | |
if key in state_dict: | |
val = state_dict[key] | |
setattr(self, name, val) | |
elif strict: | |
missing_keys.append(key) | |
super()._load_from_state_dict( | |
state_dict, | |
prefix, | |
local_metadata, | |
strict, | |
missing_keys, | |
unexpected_keys, | |
error_msgs, | |
) | |
def extra_repr(self): | |
return f"min_val={self.min_val}, max_val={self.max_val}" | |
class FixedQParamsObserver(ObserverBase): | |
r""" | |
Observer that simulates quantize and dequantize with fixed | |
    quantization parameters at training time. Only per tensor
quantization is supported. | |
Args: | |
`scale` (float): fixed scale for the observer | |
`zero_point` (int): fixed zero point for the observer | |
`dtype`, `qscheme`, `quant_min`, `quant_max` | |
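
    A minimal usage sketch (illustrative; ``calculate_qparams`` simply returns
    the fixed values passed at construction)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0)
        >>> scale, zero_point = obs.calculate_qparams()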
""" | |
scale: torch.Tensor | |
zero_point: torch.Tensor | |
def __init__( | |
self, | |
scale, | |
zero_point, | |
dtype=torch.quint8, | |
qscheme=torch.per_tensor_affine, | |
quant_min=0, | |
quant_max=255, | |
is_dynamic=False, | |
**kwargs, | |
): | |
if is_dynamic: | |
raise NotImplementedError( | |
"FixedQParamsObserver doesn't support dynamic quantization" | |
) | |
super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) | |
self.quant_min = quant_min | |
self.quant_max = quant_max | |
self.register_buffer('scale', torch.tensor([scale], dtype=torch.float)) | |
self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int)) | |
self.dtype = dtype | |
self.qscheme = qscheme | |
def forward(self, X): | |
return X | |
def calculate_qparams(self): | |
return self.scale, self.zero_point | |
class PlaceholderObserver(ObserverBase): | |
r""" | |
Observer that doesn't do anything and just passes its configuration to the | |
quantized module's ``.from_float()``. | |
Can be used for quantization to float16 which doesn't require determining | |
ranges. | |
Args: | |
dtype: dtype argument to the `quantize` node needed to implement the | |
reference model spec. | |
quant_min: minimum value in quantized domain (TODO: align behavior with other observers) | |
quant_max: maximum value in quantized domain | |
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation | |
(Can be used in Graph Mode Passes for special case ops). | |
compute_dtype (deprecated): if set, marks the future quantize function to use | |
dynamic quantization instead of static quantization. | |
This field is deprecated, use `is_dynamic=True` instead. | |
is_dynamic: if True, the `quantize` function in the reference model | |
representation taking stats from this observer instance will | |
use dynamic quantization. | |
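
    A minimal usage sketch (illustrative; the observer just records its
    configuration and passes inputs through unchanged)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = PlaceholderObserver(dtype=torch.float16)
        >>> y = obs(torch.randn(4))  # y is the input, untouched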
""" | |
def __init__( | |
self, dtype=torch.float32, custom_op_name="", compute_dtype=None, | |
quant_min=None, quant_max=None, qscheme=None, eps=None, | |
is_dynamic=False, | |
) -> None: | |
super().__init__(dtype=dtype, is_dynamic=is_dynamic) | |
if qscheme is None: | |
qscheme = torch.per_tensor_affine | |
if eps is None: | |
eps = torch.finfo(torch.float32).eps | |
# dtype of input of the target operator, e.g. for dynamic quantization | |
# ops, the dtype will be float32 | |
self.dtype = dtype | |
self.qscheme = qscheme | |
self.quant_min = quant_min | |
self.quant_max = quant_max | |
self.eps = eps | |
self.custom_op = custom_op_name | |
# used for configuration of computation type for dynamic quantization | |
if compute_dtype: | |
is_dynamic = True | |
warnings.warn( | |
"Please use `is_dynamic` instead of `compute_dtype`. \ | |
`compute_dtype` will be deprecated in a future release \ | |
of PyTorch." | |
) | |
def forward(self, x): | |
return x | |
def extra_repr(self): | |
return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" | |
def calculate_qparams(self): | |
raise Exception( | |
"calculate_qparams should not be called for PlaceholderObserver" | |
) | |
class RecordingObserver(ObserverBase): | |
r""" | |
    The module is mainly for debugging and records the tensor values during runtime.
Args: | |
dtype: Quantized data type | |
qscheme: Quantization scheme to be used | |
reduce_range: Reduces the range of the quantized data type by 1 bit | |
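
    A minimal usage sketch (illustrative; each forward call appends a clone of
    the input to ``tensor_val`` for later inspection)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs = RecordingObserver()
        >>> _ = obs(torch.randn(2, 2))
        >>> recorded = obs.get_tensor_value()  # list with one recorded tensor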
""" | |
__annotations__ = {"tensor_val": List[Optional[torch.Tensor]]} | |
def __init__(self, dtype=torch.quint8): | |
super().__init__(dtype=dtype, is_dynamic=False) # type: ignore[call-arg] | |
self.tensor_val = [] | |
def forward(self, x): | |
self.tensor_val.append(x.clone()) | |
return x | |
def calculate_qparams(self): | |
raise Exception("calculate_qparams should not be called for RecordingObserver") | |
def get_tensor_value(self): | |
return self.tensor_val | |
class NoopObserver(ObserverBase): | |
r""" | |
Observer that doesn't do anything and just passes its configuration to the | |
quantized module's ``.from_float()``. | |
Primarily used for quantization to float16 which doesn't require determining | |
ranges. | |
Args: | |
dtype: Quantized data type | |
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation | |
(Can be used in Graph Mode Passes for special case ops). | |
""" | |
def __init__(self, dtype=torch.float16, custom_op_name="") -> None: | |
super().__init__(dtype=dtype, is_dynamic=False) | |
self.dtype = dtype | |
self.custom_op = custom_op_name | |
def forward(self, x): | |
return x | |
def calculate_qparams(self): | |
raise Exception("calculate_qparams should not be called for NoopObserver") | |
class ReuseInputObserver(ObserverBase): | |
r""" This observer is used when we want to reuse the observer from the operator | |
that produces the input Tensor, typically used for operators like reshape, e.g. | |
``` | |
x0 = ... | |
x1 = x0.reshape() | |
``` | |
    If we configure x0 to be observed by some observer, let's say MinMaxObserver,
    and reshape is configured with ReuseInputObserver, we'll reuse the observer instance
    of x0 for x1 (the output of reshape). If x0 is not observed, we won't observe x1 either.
Note: this is only enabled in FX Graph Mode Quantization | |
""" | |
def __init__(self): | |
super().__init__(torch.quint8, is_dynamic=False) | |
def forward(self, x): | |
return x | |
def calculate_qparams(self): | |
raise Exception("calculate_qparams should not be called for ReuseInputObserver") | |
def _is_observer_script_module(mod, obs_type_name): | |
"""Returns true if given mod is an instance of Observer script module.""" | |
if isinstance(mod, torch.jit.RecursiveScriptModule): | |
# qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' | |
suffix = mod._c.qualified_name.split(".", 1)[1] | |
name = re.sub(r"\.___torch_mangle_\d+", "", suffix) | |
return obs_type_name in name | |
return False | |
def _is_activation_post_process(module): | |
return ( | |
isinstance(module, (torch.ao.quantization.ObserverBase, | |
torch.ao.quantization.FakeQuantizeBase)) or _is_observer_script_module(module, "quantization.observer") | |
) | |
def _is_per_channel_script_obs_instance(module): | |
if isinstance(module, torch.jit.RecursiveScriptModule): | |
return _is_observer_script_module( | |
module, "quantization.observer.PerChannelMinMaxObserver" | |
) or _is_observer_script_module( | |
module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" | |
) | |
return False | |
def get_observer_state_dict(mod): | |
r""" | |
Returns the state dict corresponding to the observer stats. | |
    Traverses the model state_dict and extracts the stats.
""" | |
od = OrderedDict() | |
if isinstance(mod, torch.jit.RecursiveScriptModule): | |
for k, v in mod.state_dict().items(): | |
if "observer" in k: | |
od[k] = v | |
else: | |
# path for GraphModule and nn.Module (eager mode) | |
for k, v in mod.state_dict().items(): | |
if "activation_post_process" in k: | |
od[k] = v | |
od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] | |
return od | |
def load_observer_state_dict(mod, obs_dict): | |
r""" | |
Given input model and a state_dict containing model observer stats, | |
load the stats back into the model. The observer state_dict can be saved | |
using torch.ao.quantization.get_observer_state_dict | |
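
    A minimal usage sketch (illustrative; ``calibrated_model`` and
    ``prepared_model`` are assumed to be prepared models containing observers)::

        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> obs_dict = get_observer_state_dict(calibrated_model)
        >>> torch.save(obs_dict, "observer_stats.pt")
        >>> load_observer_state_dict(prepared_model, torch.load("observer_stats.pt"))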
""" | |
missing_keys: List[str] = [] | |
unexpected_keys: List[str] = [] | |
for name, module in mod.named_modules(): | |
prefix = name + "." | |
if _is_activation_post_process(module): | |
if _is_per_channel_script_obs_instance(module): | |
# For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. | |
# However this is not called when the module is scripted and we end up calling the default one in module.py | |
module._load_from_state_dict_script( | |
obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] | |
) | |
else: | |
module._load_from_state_dict( | |
obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] | |
) | |
for k in missing_keys: | |
if "observer" in k or "activation_post_process" in k: | |
raise Exception(f"Missing keys for observer {k} in state_dict") | |
for k in unexpected_keys: | |
if "observer" in k or "activation_post_process" in k: | |
raise Exception(f"Unexpected keys for observer {k} in state_dict") | |
# Restrict activations to be in the range (0,127) | |
default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) | |
""" | |
Default observer for static quantization, usually used for debugging. | |
""" | |
default_placeholder_observer = PlaceholderObserver | |
""" | |
Default placeholder observer, usually used for quantization to torch.float16. | |
""" | |
default_debug_observer = RecordingObserver | |
""" | |
Default debug-only observer. | |
""" | |
default_weight_observer = MinMaxObserver.with_args( | |
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric | |
) | |
""" | |
Default weight observer. | |
""" | |
weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( | |
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, | |
quant_min=-127, quant_max=127, eps=2 ** -12) | |
""" | |
Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. | |
""" | |
default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) | |
""" | |
Default histogram observer, usually used for PTQ. | |
""" | |
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( | |
dtype=torch.qint8, qscheme=torch.per_channel_symmetric | |
) | |
""" | |
Default per-channel weight observer, usually used on backends where per-channel | |
weight quantization is supported, such as `fbgemm`. | |
""" | |
per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( | |
dtype=torch.qint8, qscheme=torch.per_channel_symmetric, | |
quant_min=-127, quant_max=127, eps=2 ** -12) | |
""" | |
Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. | |
""" | |
default_dynamic_quant_observer = PlaceholderObserver.with_args( | |
dtype=torch.quint8, quant_min=0, quant_max=255, is_dynamic=True, | |
) | |
""" | |
Default observer for dynamic quantization. | |
""" | |
default_float_qparams_observer = PerChannelMinMaxObserver.with_args( | |
dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 | |
) | |
""" | |
Default observer for a floating point zero-point. | |
""" | |
default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( | |
dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 | |
) | |
""" | |
Default observer for a floating point zero-point and 4 bit activations. | |
""" | |
# TODO(future PR): remove these defaults and enforce activation functions | |
# to explicitly specify their output range | |
default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( | |
scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255) | |
default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( | |
scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255) | |
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases | |
default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer | |
default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer | |
""" | |
Default observers for fixed qparams operations. | |
""" | |
default_reuse_input_observer = ReuseInputObserver | |
""" | |
Default observer for operators like reshape that reuse the observer of the input to
the operator.
""" | |