# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import decimal

import numpy as np
import torch
from torch import nn
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)


class QuantEmbedding(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Embedding`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Embedding`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply
    def forward(self, x, positions=None, incremental_state=None):
        if not self.quant_mode:
            return (
                nn.functional.embedding(
                    x,
                    self.weight,
                    self.padding_idx,
                    self.max_norm,
                    self.norm_type,
                    self.scale_grad_by_freq,
                    self.sparse,
                ),
                None,
            )

        w = self.weight
        w_transform = w.data.detach()
        w_min = w_transform.min().expand(1)
        w_max = w_transform.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = nn.functional.embedding(
            x,
            self.weight_integer,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
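
# Illustrative usage sketch (comment only; the vocabulary/hidden sizes below are hypothetical,
# not part of this module): every quantized layer in this file follows the same convention of
# returning a pair (dequantized_output, scaling_factor), so the integer tensor can always be
# recovered as dequantized_output / scaling_factor.
#
#   embedding = QuantEmbedding(30522, 768, padding_idx=0, weight_bit=8, quant_mode=True)
#   hidden_states, hidden_scaling_factor = embedding(input_ids)
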
class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (:obj:`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.
        channel_len (:obj:`int`, `optional`):
            Specify the channel length when `per_channel` is set to True.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if not self.per_channel:
            self.register_buffer("x_min", torch.zeros(1))
            self.register_buffer("x_max", torch.zeros(1))
            self.register_buffer("act_scaling_factor", torch.zeros(1))
            self.x_min -= 1e-5
            self.x_max += 1e-5
        else:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
            f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
            f"Act_max: {self.x_max.item():.2f})"
        )
    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):
        x_act = x if identity is None else identity + x

        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."
            x_min = x_act.data.min()
            x_max = x_act.data.max()

            assert (
                x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            # Initialization
            if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
                self.x_min = self.x_min + x_min
                self.x_max = self.x_max + x_max

            # exponential moving average (EMA)
            # use momentum to prevent the quantized values from changing greatly at every iteration
            elif self.act_range_momentum == -1:
                self.x_min = torch.min(self.x_min, x_min)
                self.x_max = torch.max(self.x_max, x_max)
            else:
                self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
                self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        if not self.quant_mode:
            return x_act, None

        x_min = self.x_min if specified_min is None else specified_min
        x_max = self.x_max if specified_max is None else specified_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, x_min, x_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is None:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
        else:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )

        correct_output_scale = self.act_scaling_factor.view(-1)

        return quant_act_int * correct_output_scale, self.act_scaling_factor
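
# Illustrative chaining sketch (comment only; tensor names are hypothetical): when the incoming
# tensor is already the scaled output of a quantized layer, its scale is passed as
# `pre_act_scaling_factor` so requantization goes through FixedPointMul instead of quantizing
# from floating point again.
#
#   input_act = QuantAct(8, quant_mode=True)
#   hidden_states, act_scale = input_act(embedding_out, pre_act_scaling_factor=embedding_scale)
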
class QuantLinear(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Linear`. Adds quantization-specific arguments on top of :obj:`torch.nn.Linear`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        bias_bit (:obj:`int`, `optional`, defaults to :obj:`32`):
            Bitwidth for the quantized bias.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))
        else:
            # keep the attributes defined so that `forward` can pass `bias=None` through
            self.bias = None
            self.bias_integer = None

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
        return s
    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return nn.functional.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )
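
# How the integer matmul above recovers the floating-point result (comment only): with
# x ~= x_int * s_x and W ~= W_int * s_w, the layer computes
#   (x / s_x) @ W_int.T * (s_w * s_x) ~= x @ W.T,
# and the bias is quantized with the same combined scale s_w * s_x so it can be added directly
# to the integer accumulator before the final rescaling.
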
class IntGELU(nn.Module):
    """
    Quantized version of :obj:`torch.nn.GELU`. Adds quantization-specific arguments on top of :obj:`torch.nn.GELU`.

    Args:
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142  # sqrt(2)
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x+b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor ** 2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor ** 2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2 ** self.const)
        scaling_factor = scaling_factor * 2 ** self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor
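
# Sanity-check sketch for the i-GELU approximation above (comment only, approximate values):
# in float, GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))). The clipped second-order polynomial in
# `int_erf` gives erf(1 / sqrt(2)) ~= -0.2888 * (0.707 - 1.769)**2 + 1 ~= 0.674 versus the true
# value 0.683, so for x = 1.0 the integer path lands near 0.84, close to nn.GELU()(x) ~= 0.841.
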
class IntSoftmax(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Softmax`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Softmax`.

    Args:
        output_bit (:obj:`int`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor ** 2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor ** 2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2 ** self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.Softmax(dim=-1)(x), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2 ** self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2 ** self.output_bit
        return exp_int * scaling_factor, scaling_factor
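
# Sketch of the integer exponential used above (comment only): a non-positive input is written
# as x = r - q * ln(2) with q = floor(x / -ln(2)) and r in (-ln(2), 0], so exp(x) = exp(r) * 2**(-q).
# exp(r) is approximated by the second-order polynomial a*r**2 + b*r + c stored in `self.coef`,
# and the 2**(-q) factor is applied with the integer shift 2**(self.const - q) folded into the
# scaling factor.
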
class IntLayerNorm(nn.Module):
    """
    Quantized version of :obj:`torch.nn.LayerNorm`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.LayerNorm`.

    Args:
        output_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.quant_mode = quant_mode
        if force_dequant in ["nonlinear", "layernorm"]:
            logger.info("Force dequantize layernorm")
            self.quant_mode = False

        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        with torch.no_grad():
            y_sq_int = y_int ** 2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(var_int / 2 ** self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
        y_sq_int = y_int_shifted ** 2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
            var = torch.mean(y ** 2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance (std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
        y_sq_int = y_int_shifted ** 2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)

        # overflow handling in training time
        if self.training:
            # if overflow is detected
            if var_int.max() >= 2 ** self.max_bit:
                var_int = self.overflow_fallback(y_int)
                assert var_int.max() < 2 ** self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )

        # To be replaced with integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2 ** self.shift
        factor = floor_ste.apply(2 ** 31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        scaling_factor = self.dim_sqrt / 2 ** 30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / scaling_factor)

        y_int = y_int + bias_int
        scaling_factor = scaling_factor * self.weight
        x = y_int * scaling_factor

        return x, scaling_factor
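
# Note on the shift handling above (comment only): `var_int` has to stay below 2**32, so `y_int`
# is pre-scaled by 2**shift before squaring and the factor is restored after the square root
# (`std_int = floor(sqrt(var_int)) * 2**shift`); `set_shift` enlarges `shift` whenever an
# overflow is observed during training.
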
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
    """
    Calculate the percentile min and max values of a given tensor.

    Args:
        input (:obj:`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (:obj:`float`):
            If 0.1, the value separating the smallest 0.1% of values is returned as the percentile min.
        upper_percentile (:obj:`float`):
            If 99.9, the value separating the largest 0.1% of values is returned as the percentile max.
        output_tensor (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of `input`
    """
    input_length = input.shape[0]

    lower_index = round(input_length * (1 - lower_percentile * 0.01))
    upper_index = round(input_length * upper_percentile * 0.01)

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        lower_bound = upper_bound * 0
        # lower_index += 1
    else:
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound
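
# Worked example (comment only): for a 1-D tensor with 1000 values,
# get_percentile_min_max(x, 0.1, 99.9) returns roughly the 2nd-smallest and 2nd-largest entries
# (order statistics taken with torch.kthvalue), i.e. about 0.1% of outliers are clipped on each side.
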
def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (:obj:`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (:obj:`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (:obj:`torch.Tensor`):
            Shift for quantization.
        inplace (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to compute inplace or not.

    Returns:
        :obj:`torch.Tensor`: Linearly quantized value of `input` according to `scale` and `zero_point`.
    """
    # reshape scale and zeropoint for convolutional weights and activation
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zeropoint for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)
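
# Worked example (comment only): with scale = torch.tensor(0.1) and zero_point = torch.tensor(0.),
# linear_quantize(torch.tensor([0.24, -0.26]), scale, zero_point) evaluates
# round([2.4, -2.6]) = [2., -3.], i.e. quantized = round(input / scale + zero_point).
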
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        num_bits (:obj:`int`):
            Quantization bitwidth.
        saturation_min (:obj:`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (:obj:`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.

    Returns:
        :obj:`torch.Tensor`: Scaling factor that linearly quantizes the given range between `saturation_min` and
        `saturation_max`.
    """
    # this computation does not require any gradients; torch.no_grad() enforces that
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n
        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale
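
# Worked example (comment only): for num_bits = 8 and a symmetric range of +/-0.5,
# n = 2**7 - 1 = 127 and scale = 0.5 / 127 ~= 0.00394; SymmetricQuantFunction below then clamps
# the resulting integers to [-n, n - 1] = [-127, 126].
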
class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (:obj:`torch.Tensor`):
                Floating point tensor to be quantized.
            k (:obj:`int`):
                Quantization bitwidth.
            percentile_mode (:obj:`bool`):
                Whether or not to use percentile calibration.
            scale (:obj:`torch.Tensor`):
                Pre-calculated scaling factor for `x`. Note that the current implementation of
                SymmetricQuantFunction requires a pre-calculated scaling factor.

        Returns:
            :obj:`torch.Tensor`: Symmetric-quantized value of `input`.
        """
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale and zeropoint for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        return grad_output.clone() / scale, None, None, None, None
class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()
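
# Straight-through estimation in one line (comment only): floor and round have zero gradient
# almost everywhere, so the two Functions above treat the op as the identity in the backward
# pass (dy/dx = 1), which keeps quantization-aware training differentiable.
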
def batch_frexp(inputs, max_bit=31):
    """
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (:obj:`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
    """
    shape_of_input = inputs.size()

    # transform the input into a 1-d tensor
    inputs = inputs.view(-1)

    output_m, output_e = np.frexp(inputs.cpu().numpy())
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
            decimal.Decimal(m * (2 ** max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    output_m = np.array(tmp_m)
    output_e = float(max_bit) - output_e

    return (
        torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
        torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
    )
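
# Worked example (comment only): np.frexp(0.75) = (0.75, 0), so with max_bit = 31 batch_frexp
# returns m = round(0.75 * 2**31) = 1610612736 and e = 31 - 0 = 31, and indeed
# 1610612736 / 2**31 = 0.75. FixedPointMul below uses this to replace a floating-point multiply
# with an integer multiply followed by a right shift.
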
class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (:obj:`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the input tensor `pre_act`.
        bit_num (:obj:`int`):
            Quantization bitwidth.
        z_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (:obj:`torch.Tensor`, `optional`):
            Identity tensor, if exists.
        identity_scaling_factor (:obj:`torch.Tensor`, `optional`):
            Scaling factor of the identity tensor `identity`, if exists.

    Returns:
        :obj:`torch.Tensor`: Output tensor (`pre_act` if `identity` is not given, otherwise the addition of `pre_act`
        and `identity`), whose scale is rescaled to `z_scaling_factor`.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0 ** e))

            if identity is not None:
                # needs addition of identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0 ** e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        # one gradient per forward input: pre_act, pre_act_scaling_factor, bit_num,
        # z_scaling_factor, identity, identity_scaling_factor
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, identity_grad, None
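
# End-to-end intuition for FixedPointMul (comment only): rescaling an integer tensor from scale
# s_a to scale s_z is the floating-point multiply z_int = round(a_int * s_a / s_z); batch_frexp
# rewrites s_a / s_z as m / 2**e with integer m, so the same rescaling becomes the pure integer
# computation round(a_int * m / 2**e), which is what an integer-only hardware kernel executes.
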