"""Implements modules used to perform fake quantization."""
import torch
from torch.nn import Module
from torch.ao.quantization.observer import (
MovingAverageMinMaxObserver,
HistogramObserver,
MovingAveragePerChannelMinMaxObserver,
FixedQParamsObserver,
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
_with_args,
)
import re
from abc import ABC, abstractmethod
from typing import Any, Tuple
__all__ = [
"FakeQuantizeBase",
"FakeQuantize",
"FixedQParamsFakeQuantize",
"FusedMovingAvgObsFakeQuantize",
"disable_fake_quant",
"disable_observer",
"enable_fake_quant",
"enable_observer",
"default_fake_quant",
"default_weight_fake_quant",
"default_dynamic_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_fixed_qparams_range_0to1_fake_quant",
"default_symmetric_fixed_qparams_fake_quant",
"default_affine_fixed_qparams_fake_quant",
"default_per_channel_weight_fake_quant",
"default_embedding_fake_quant",
"default_embedding_fake_quant_4bit",
"default_histogram_fake_quant",
"default_fused_act_fake_quant",
"default_fused_wt_fake_quant",
"default_fused_per_channel_wt_fake_quant",
"fused_wt_fake_quant_range_neg_127_to_127",
"fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_channel_affine_float_qparams, ]
class FakeQuantizeBase(ABC, Module):
r"""Base fake quantize module.
Base fake quantize module
Any fake quantize implementation should derive from this class.
Concrete fake quantize module should follow the same API. In forward, they will update
the statistics of the observed Tensor and fake quantize the input. They should also provide a
`calculate_qparams` function that computes the quantization parameters given
the collected statistics.
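A minimal illustrative subclass sketch (the fixed scale/zero_point values below are
placeholders, not a recommended configuration)::
    class ConstFakeQuantize(FakeQuantizeBase):
        def __init__(self):
            super().__init__()
            self.register_buffer('scale', torch.tensor([0.1]))
            self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        def forward(self, x):
            # only fake-quantize when the toggle buffer is enabled
            if self.fake_quant_enabled[0] == 1:
                x = torch.fake_quantize_per_tensor_affine(x, self.scale, self.zero_point, 0, 255)
            return x
        def calculate_qparams(self, **kwargs):
            return self.scale, self.zero_point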
"""
fake_quant_enabled: torch.Tensor
observer_enabled: torch.Tensor
def __init__(self):
"""Set fake_quant_enabled and observer_enabled."""
super().__init__()
# fake_quant_enabled and observer_enabled are buffers to support their
# replication in DDP. Data type is uint8 because NCCL does not support
# bool tensors.
self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def calculate_qparams(self, **kwargs):
pass
@torch.jit.export
def enable_fake_quant(self, enabled: bool = True) -> None:
self.fake_quant_enabled[0] = 1 if enabled else 0
@torch.jit.export
def disable_fake_quant(self):
self.enable_fake_quant(False)
@torch.jit.export
def enable_observer(self, enabled: bool = True) -> None:
self.observer_enabled[0] = 1 if enabled else 0
@torch.jit.export
def disable_observer(self):
self.enable_observer(False)
@classmethod
def with_args(cls, **kwargs):
fake_quant_constructor = _with_args(cls, **kwargs)
# need to assign the correct module to fake_quantize
# constructors to satisfy public vs. private requirements
fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
return fake_quant_constructor
class FakeQuantize(FakeQuantizeBase):
r"""Simulate the quantize and dequantize operations in training time.
The output of this module is given by::
x_out = (
clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
) * scale
* :attr:`is_dynamic` indicates whether the fake quantize is a placeholder for dynamic quantization
operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)
* :attr:`scale` defines the scale factor used for quantization.
* :attr:`zero_point` specifies the quantized value to which 0 in floating point maps
* :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
statistics can still be updated.
* :attr:`observer_enabled` controls statistics collection on tensors
* :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
allowable values are torch.qint8 and torch.quint8.
Args:
observer (module): Module for observing statistics on input tensors and calculating scale
and zero-point.
observer_kwargs (optional): Arguments for the observer module
Attributes:
activation_post_process (Module): User provided module that collects statistics on the input tensor and
provides a method to calculate scale and zero-point.
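Example usage (an illustrative sketch; the observer settings and tensor shapes below are arbitrary)::
fq = FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8)
x = torch.randn(4, 8)
y = fq(x)  # updates observer statistics, then fake-quantizes x
scale, zero_point = fq.calculate_qparams()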
"""
scale: torch.Tensor
zero_point: torch.Tensor
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, is_dynamic=False, **observer_kwargs):
super().__init__()
# Populate quant_min/quant_max to observer_kwargs if valid
if quant_min is not None and quant_max is not None:
assert quant_min <= quant_max, \
'quant_min must be less than or equal to quant_max'
dtype = observer_kwargs.get("dtype", torch.quint8)
if hasattr(observer, "p"):
# In case observer is _PartialWrapper, dtype can be stored in
# observer.p.keywords["dtype"]
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
"dtype", dtype
)
assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic
self.activation_post_process = observer(**observer_kwargs)
# TODO: keeping self.quant_min/max for BC; remove after a couple releases
# Users should use self.activation_post_process.quant_min
self.quant_min = self.activation_post_process.quant_min
self.quant_max = self.activation_post_process.quant_max
self.is_dynamic = self.activation_post_process.is_dynamic
if _is_float_qparams(self.activation_post_process.qscheme):
zero_point_dtype = torch.float
else:
zero_point_dtype = torch.int
self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = self.activation_post_process.ch_axis \
if hasattr(self.activation_post_process, 'ch_axis') else -1
assert _is_per_channel(self.qscheme) or \
_is_per_tensor(self.qscheme), \
'Only per channel and per tensor quantization are supported in fake quantize' + \
', got qscheme: ' + str(self.qscheme)
self.is_per_channel = _is_per_channel(self.qscheme)
@torch.jit.export
def calculate_qparams(self):
return self.activation_post_process.calculate_qparams()
def forward(self, X):
if self.observer_enabled[0] == 1:
self.activation_post_process(X.detach())
_scale, _zero_point = self.calculate_qparams()
_scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
if self.scale.shape != _scale.shape:
self.scale.resize_(_scale.shape)
self.zero_point.resize_(_zero_point.shape)
self.scale.copy_(_scale)
self.zero_point.copy_(_zero_point)
if self.fake_quant_enabled[0] == 1:
if self.is_per_channel:
X = torch.fake_quantize_per_channel_affine(
X, self.scale, self.zero_point,
self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
else:
X = torch.fake_quantize_per_tensor_affine(
X, self.scale, self.zero_point,
self.activation_post_process.quant_min, self.activation_post_process.quant_max)
return X
@torch.jit.export
def extra_repr(self):
return 'fake_quant_enabled={}, observer_enabled={}, ' \
'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
'scale={}, zero_point={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.activation_post_process.quant_min, self.activation_post_process.quant_max,
self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
def _save_to_state_dict(self, destination, prefix, keep_vars):
# We cannot currently register scalar values as buffers, so need to manually
# specify serialization here.
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'scale'] = self.scale
destination[prefix + 'zero_point'] = self.zero_point
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
# Removing this function throws an error that the size of the loaded tensor does not match the original size,
# i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
local_state = ['scale', 'zero_point']
for name in local_state:
key = prefix + name
if key in state_dict:
val = state_dict[key]
# Custom handling to allow loading scale and zero_point
# of size N into uninitialized buffers of size 0. The
# buffers are resized here, and the values are copied in
# the default state_dict loading code of the parent.
if name == 'scale':
self.scale.resize_(val.shape)
else:
assert name == 'zero_point'
self.zero_point.resize_(val.shape)
# For torchscript modules we need to update the attributes here since we do not
# call the `_load_from_state_dict` function defined in module.py
if torch.jit.is_scripting():
if name == 'scale':
self.scale.copy_(val)
else:
assert name == 'zero_point'
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
class FixedQParamsFakeQuantize(FakeQuantize):
"""Simulate quantize and dequantize in training time.
Simulate quantize and dequantize with fixed quantization
parameters in training time. Only per tensor quantization
is supported.
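Example usage (illustrative; uses the 0-to-1 fixed-qparams observer defined in this package)::
fq = FixedQParamsFakeQuantize(observer=default_fixed_qparams_range_0to1_observer)
y = fq(torch.sigmoid(torch.randn(2, 3)))  # sigmoid outputs have a fixed [0, 1] range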
"""
# TODO: rename observer to observer_ctr
def __init__(self, observer):
super().__init__(observer=observer)
assert type(self.activation_post_process) == FixedQParamsObserver, \
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
self._observer_ctr = observer
self.scale = self.activation_post_process.scale
self.zero_point = self.activation_post_process.zero_point
assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)
@torch.jit.export
def calculate_qparams(self):
return self.scale, self.zero_point
@torch.jit.export
def extra_repr(self):
"""Define a string representation of the object's attributes."""
return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
self.fake_quant_enabled, self.observer_enabled,
self.scale, self.zero_point, self.dtype,
self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme)
class FusedMovingAvgObsFakeQuantize(FakeQuantize):
r"""Define a fused module to observe the tensor.
Fused module that is used to observe the input tensor (compute min/max), compute
scale/zero_point and fake_quantize the tensor.
This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
to compute the min/max values in order to compute the scale/zero_point.
The qscheme input in the observer is used to differentiate between symmetric/affine
quantization scheme.
The output of this module is given by
x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
base class.
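Example usage (illustrative; the default arguments emulate quint8 activations)::
fused_fq = FusedMovingAvgObsFakeQuantize()
y = fused_fq(torch.randn(4, 8))  # observe, compute qparams and fake-quantize in one fused op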
"""
def __init__(
self,
observer: Any = MovingAverageMinMaxObserver,
quant_min: int = 0,
quant_max: int = 255,
**observer_kwargs: Any
) -> None:
super().__init__(observer, quant_min, quant_max, **observer_kwargs)
assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
@torch.jit.export
def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
return self.activation_post_process.calculate_qparams()
@torch.jit.export
def extra_repr(self) -> str:
return (
"fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
"dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
self.fake_quant_enabled,
self.observer_enabled,
self.scale,
self.zero_point,
self.dtype,
self.activation_post_process.quant_min,
self.activation_post_process.quant_max,
self.qscheme,
self.activation_post_process.reduce_range,
)
)
def forward(self, X: torch.Tensor) -> torch.Tensor:
return torch.fused_moving_avg_obs_fake_quant(
X,
self.observer_enabled,
self.fake_quant_enabled,
self.activation_post_process.min_val,
self.activation_post_process.max_val,
self.scale,
self.zero_point,
self.activation_post_process.averaging_constant,
self.activation_post_process.quant_min,
self.activation_post_process.quant_max,
self.ch_axis,
self.is_per_channel,
self.is_symmetric_quant,
)
default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
"""
Default fake_quant for activations.
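Example (illustrative; ``with_args`` returns a partial constructor, so call it to build the module)::
act_fq = default_fake_quant()
y = act_fq(torch.randn(2, 5))  # observe activation statistics, then fake-quantize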
"""
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
"""
Default fake_quant for weights.
Observer is memoryless since averaging_constant is 1.
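Example (illustrative)::
w_fq = default_weight_fake_quant()
w_q = w_fq(torch.randn(16, 8))  # symmetric per-tensor fake-quantization of a weight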
"""
default_dynamic_fake_quant = FakeQuantize.with_args(
observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, is_dynamic=True,
dtype=torch.quint8, averaging_constant=1)
"""
Default dynamic fake_quant for activations.
"""
default_fixed_qparams_range_neg1to1_fake_quant = (
FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
reduce_range=False,
ch_axis=0)
"""
Default fake_quant for per-channel weights.
Observer is memoryless since averaging_constant is 1.
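Example (illustrative; one scale/zero_point pair is computed per output channel, ch_axis=0)::
w_fq = default_per_channel_weight_fake_quant()
w_q = w_fq(torch.randn(16, 8))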
"""
default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
qscheme=torch.per_channel_affine_float_qparams,
dtype=torch.quint8,
quant_min=0,
quant_max=255,
ch_axis=0,
averaging_constant=1)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""
default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0,
dtype=torch.quint4x2,
averaging_constant=1)
default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
quant_min=0,
quant_max=255,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=True)
"""
Fake_quant for activations using a histogram.
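Example (illustrative; histogram calibration is more expensive than min/max observation)::
act_fq = default_histogram_fake_quant()
y = act_fq(torch.randn(2, 5))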
"""
default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255,
dtype=torch.quint8,)
"""
Fused version of `default_fake_quant`, with improved performance.
"""
default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_tensor_symmetric)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""
default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""
fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=-127,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_tensor_symmetric,
eps=2 ** -12)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""
fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-127,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
eps=2 ** -12)
"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""
def _is_fake_quant_script_module(mod):
"""Return true if given mod is an instance of FakeQuantize script module."""
if isinstance(mod, torch.jit.RecursiveScriptModule):
# qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
suffix = mod._c.qualified_name.split('.', 1)[1]
name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
return False
def disable_fake_quant(mod):
"""Disable fake quantization for the module.
Disable fake quantization for this module, if applicable. Example usage::
# model is any PyTorch model
model.apply(torch.ao.quantization.disable_fake_quant)
"""
if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
mod.disable_fake_quant()
def enable_fake_quant(mod):
"""Enable fake quantization for the module.
Enable fake quantization for this module, if applicable. Example usage::
# model is any PyTorch model
model.apply(torch.ao.quantization.enable_fake_quant)
"""
if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
mod.enable_fake_quant()
def disable_observer(mod):
"""Disable observation for this module.
Disable observation for this module, if applicable. Example usage::
# model is any PyTorch model
model.apply(torch.ao.quantization.disable_observer)
"""
if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
mod.disable_observer()
def enable_observer(mod):
"""Enable observation for this module.
Enable observation for this module, if applicable. Example usage::
# model is any PyTorch model
model.apply(torch.ao.quantization.enable_observer)
"""
if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
mod.enable_observer()
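# Illustrative quantization-aware-training toggling sketch (assumes ``model`` is a
# PyTorch model that already contains fake-quant modules, e.g. after
# torch.ao.quantization.prepare_qat):
#
#   model.apply(torch.ao.quantization.enable_observer)
#   model.apply(torch.ao.quantization.enable_fake_quant)
#   # ... run calibration or a few training steps ...
#   model.apply(torch.ao.quantization.disable_observer)  # freeze quantization parameters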