from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node

__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for different types of quantization specs that allow users to
    specify how to quantize a Tensor (the input or output of a Node) in the model.
    """


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows user to specify how to

    quantize a Tensor, this includes dtype, quant_min, quant_max etc.

    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelMinMaxObserver, etc.,
    # possibly with custom args attached,
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # quant_min must be <= quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # ch_axis must be less than the number of channels, which cannot be
        # validated here; just check that it is non-negative.
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError(f"ch_axis {self.ch_axis} must be non-negative.")

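
# Illustrative sketch (not part of the original file): one way a backend might
# build a per-tensor activation spec and a per-channel weight spec.
# HistogramObserver and PerChannelMinMaxObserver come from
# torch.ao.quantization.observer; the numeric choices here are assumptions.
def _example_act_and_weight_specs() -> Tuple[QuantizationSpec, QuantizationSpec]:
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
    )

    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
    )
    return act_qspec, weight_qspec
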

@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None


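
# Illustrative sketch (an assumption, not taken from this file): operators such
# as sigmoid have a fixed output range, so backends typically pin their qparams
# instead of observing them; scale=1/256 and zero_point=0 are the conventional
# choices for a uint8 sigmoid output.
_EXAMPLE_SIGMOID_QSPEC = FixedQParamsQuantizationSpec(
    dtype=torch.uint8,
    scale=1.0 / 256.0,
    zero_point=0,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
)
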
"""

The way we refer to other points of quantization in the graph will be either

an input edge or an output value

input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]

output value is an fx Node

"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"

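# For example (hypothetical node names): the edge ``(x_node, conv_node)``
# refers to the Tensor flowing from ``x_node`` into ``conv_node``, while the
# bare node ``conv_node`` refers to the output value of ``conv_node``.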

@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """

    Quantization spec for the Tensors whose quantization parameters are shared with other Tensors

    """

    # the edge or node to share observer or fake quant instances with
    edge_or_node: EdgeOrNode

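
# Illustrative sketch (hypothetical node names): make the second input and the
# output of an add share observers with its first input edge, so all three see
# identical qparams. ``act_qspec`` is assumed to be built elsewhere.
def _example_share_add_inputs(
    x_node: Node, y_node: Node, add_node: Node, act_qspec: QuantizationSpecBase
) -> "QuantizationAnnotation":
    shared = SharedQuantizationSpec((x_node, add_node))
    return QuantizationAnnotation(
        input_qspec_map={x_node: act_qspec, y_node: shared},
        output_qspec=shared,
        _annotated=True,
    )
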

@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None

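
# Illustrative sketch (an assumption about typical usage): derive the qparams
# of a conv bias from the already-observed input activation and weight, using
# the classic rule bias_scale = act_scale * weight_scale with zero_point = 0.
def _derive_bias_qparams(
    obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[Tensor, Tensor]:
    act_obs, weight_obs = obs_or_fqs
    act_scale, _ = act_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    derived_scale = (act_scale * weight_scale).to(torch.float32)
    derived_zero_point = torch.zeros_like(derived_scale, dtype=torch.int32)
    return derived_scale, derived_zero_point


# ``input_node``, ``weight_node`` and ``conv_node`` below are hypothetical.
def _example_bias_qspec(
    input_node: Node, weight_node: Node, conv_node: Node
) -> DerivedQuantizationSpec:
    return DerivedQuantizationSpec(
        derived_from=[(input_node, conv_node), (weight_node, conv_node)],
        derive_qparams_fn=_derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )
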

@dataclass
class QuantizationAnnotation:
    """How are input arguemnt or output should be quantized,

    expressed as QuantizationSpec, this corresponds to how a Tensor in the

    operator Graph is observed (PTQ) or fake quantized (QAT)

    """

    # a map from each input torch.fx.Node to the QuantizationSpecBase for that input
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
        default_factory=dict
    )

    # How the output of this node is quantized, expressed as QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # A node ``node1`` and an edge ``(node1, node2)`` observe the same Tensor,
    # so we may want to implicitly share observers between them. This flag lets
    # users turn off that behavior for the output of the node.
    allow_implicit_sharing: bool = True

    # whether the node is annotated or not
    _annotated: bool = False

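
# Illustrative sketch (hypothetical node names; the ``"quantization_annotation"``
# meta key follows the convention used by the PT2E quantization flow): attach an
# annotation to a conv node describing its input, weight and output quantization.
def _example_annotate_conv(
    conv_node: Node,
    input_node: Node,
    weight_node: Node,
    act_qspec: QuantizationSpecBase,
    weight_qspec: QuantizationSpecBase,
) -> None:
    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_node: act_qspec, weight_node: weight_qspec},
        output_qspec=act_qspec,
        _annotated=True,
    )
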

class Quantizer(ABC):
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Allows user-defined transforms to run before the graph is annotated.

        This lets a quantizer handle parts of the model that would otherwise not
        be quantizable. For example, a quantizer can
        a) decompose a compound operator such as scaled dot product attention
           into bmm and softmax, if it knows how to quantize bmm/softmax but not
           sdpa, or
        b) transform scalars to tensors to allow quantizing scalars.

        Note: this is an optional method.
        """
        return model

    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
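

# Illustrative sketch (not part of the original file): a minimal Quantizer
# subclass that annotates every aten conv2d node, reusing the hypothetical
# example specs defined above.
class _ExampleConv2dQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        act_qspec, weight_qspec = _example_act_and_weight_specs()
        for node in model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.aten.conv2d.default
            ):
                input_node, weight_node = node.args[0], node.args[1]
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        input_node: act_qspec,
                        weight_node: weight_qspec,
                    },
                    output_qspec=act_qspec,
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # A real implementation would check annotations against backend
        # constraints; this sketch accepts everything.
        pass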