from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node

__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for the different types of quantization specs that allow users
    to specify how to quantize a Tensor (an input/output of a Node) in the model.
    """

    pass


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows users to specify how to
    quantize a Tensor; this includes dtype, quant_min, quant_max etc.
    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelHistogramObserver etc.
    # or we can attach some custom args to them
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # quant_min must be <= quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # ch_axis must be less than the number of channels,
        # but there is no way to check that here; just check that it is not negative.
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError("Ch_axis is < 0.")


@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
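

# Illustrative sketch (assumption, not defined by this module): ops like
# sigmoid have a known output range, so their qparams can be fixed up front.
# The values below follow the common uint8 convention for sigmoid (output in
# [0, 1]) but are examples, not requirements.
def _example_sigmoid_qspec() -> FixedQParamsQuantizationSpec:
    return FixedQParamsQuantizationSpec(
        dtype=torch.uint8,
        scale=1.0 / 256.0,
        zero_point=0,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )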
""" | |
The way we refer to other points of quantization in the graph will be either | |
an input edge or an output value | |
input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node] | |
output value is an fx Node | |
""" | |
EdgeOrNode = Union[Tuple[Node, Node], Node] | |
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer" | |
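
# For illustration (node names assumed, not defined in this module): in a graph
# `x -> conv -> relu`,
#   (x_node, conv_node)  is the input edge from x into conv, and
#   conv_node            is conv's output value.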


@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """
    Quantization spec for Tensors whose quantization parameters are shared with other Tensors
    """

    # the edge or node to share observer or fake quant instances with
    edge_or_node: EdgeOrNode
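

# Illustrative sketch (node names are assumptions): for a two-input op such as
# add, a backend may require both inputs to use one set of qparams. Here the
# spec reuses whatever is observed on the first input edge of the add node.
def _example_shared_qspec(input_act0: Node, add_node: Node) -> SharedQuantizationSpec:
    return SharedQuantizationSpec(edge_or_node=(input_act0, add_node))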


@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for Tensors whose quantization parameters are derived from other Tensors"""

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
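

# Illustrative sketch (a common convention, assumed rather than defined here):
# derive int32 bias qparams for a conv from its activation and weight
# observers, with scale_bias = scale_act * scale_weight and zero_point 0.
def _example_bias_qspec(
    input_act: Node, weight: Node, conv_node: Node
) -> DerivedQuantizationSpec:
    def derive_qparams_fn(
        obs_or_fqs: List[ObserverOrFakeQuantize],
    ) -> Tuple[Tensor, Tensor]:
        act_obs, weight_obs = obs_or_fqs
        act_scale, _ = act_obs.calculate_qparams()
        weight_scale, _ = weight_obs.calculate_qparams()
        scale = act_scale * weight_scale
        return scale, torch.zeros_like(scale, dtype=torch.int32)

    return DerivedQuantizationSpec(
        derived_from=[(input_act, conv_node), (weight, conv_node)],
        derive_qparams_fn=derive_qparams_fn,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )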


@dataclass
class QuantizationAnnotation:
    """How the input arguments and output of a node should be quantized,
    expressed as QuantizationSpecs; this corresponds to how a Tensor in the
    operator Graph is observed (PTQ) or fake quantized (QAT)
    """

    # a map from torch.fx.Node to a type of QuantizationSpecBase
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
        default_factory=dict
    )

    # How the output of this node is quantized, expressed as QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # For a Node: node1 and edge: (node1, node2), since they are observing the same
    # Tensor, we may want to implicitly share observers, this flag allows people to
    # turn off this behavior for the output of the node
    allow_implicit_sharing: bool = True

    # whether the node is annotated or not
    _annotated: bool = False
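

# Illustrative sketch: attaching an annotation to a node. The meta key
# "quantization_annotation" is the one read by the PT2E prepare flow; the node
# and spec arguments here are placeholders for demonstration.
def _example_annotate(
    conv_node: Node, input_act: Node, act_qspec: QuantizationSpecBase
) -> None:
    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_act: act_qspec},
        output_qspec=act_qspec,
        _annotated=True,
    )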


class Quantizer(ABC):
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Allows for user-defined transforms to run before annotating the graph.
        This lets a quantizer handle parts of the model that would otherwise not be quantizable.
        For example, a quantizer can
        a) decompose a compound operator like scaled dot product attention (sdpa)
        into bmm and softmax, if it knows how to quantize bmm/softmax but not sdpa,
        or b) transform scalars to tensors to allow quantizing scalars.

        Note: this is an optional method
        """
        return model

    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
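

# Minimal sketch of a concrete Quantizer (an illustration, not a shipped
# backend): annotate every aten.conv2d node with one caller-provided spec.
class _ExampleQuantizer(Quantizer):
    def __init__(self, act_qspec: QuantizationSpecBase) -> None:
        super().__init__()
        self.act_qspec = act_qspec

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.aten.conv2d.default
            ):
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={node.args[0]: self.act_qspec},
                    output_qspec=self.act_qspec,
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # Nothing to validate in this sketch.
        pass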