from __future__ import annotations
from collections import OrderedDict
from typing import Any, Callable, Dict, Tuple, Union, List

import torch

from .fake_quantize import (
    default_weight_fake_quant,
    FixedQParamsFakeQuantize,
)
from .observer import (
    _PartialWrapper,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    default_placeholder_observer,
    default_weight_observer,
)
from .qconfig import (
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    get_default_qconfig,
    get_default_qat_qconfig,
    QConfig,
    QConfigAny,
    default_quint8_weight_qconfig
)


__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]


# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

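# Ops in this map have a fixed, known output range (e.g. sigmoid always produces
# values in [0, 1] and tanh in [-1, 1]), so their quantization parameters can be
# fixed ahead of time instead of being derived from observed statistics.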
# TODO: derive this map from the BackendConfig
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}


def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping:
    """

    Return the default QConfigMapping for the given quantization type and backend.

    """
    if is_qat:
        qconfig = get_default_qat_qconfig(backend, version)
    else:
        qconfig = get_default_qconfig(backend, version)
    default_weight = default_weight_fake_quant if is_qat else default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with the fbgemm backend,
    # so we have to swap the weight observer to default_weight_observer or another
    # supported per-tensor observer.
    # see https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
    else:
        qconfig_transpose = qconfig

    # currently layernorm only supports float weights
    # we have to add this because otherwise there will be an extra quantize-dequantize pair
    qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)

    qconfig_mapping = QConfigMapping() \
        .set_global(qconfig) \
        .set_object_type("reshape", default_reuse_input_qconfig) \
        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
        .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig)

    # Use special observers for ops with fixed qparams
    fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
        if observer in fixed_qparams_observer_to_qconfig:
            fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
        else:
            if is_qat:
                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
            else:
                activation = observer
            fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight)
            fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
        qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)

    # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
    #      Need to be able to support fusion of ops with different qconfigs

    return qconfig_mapping

def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping:
    """

    Return the default QConfigMapping for post training quantization.



    Args:

      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be

         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]

      * ``version`` (int) : the version for the default qconfig mapping

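    Example (a minimal sketch; ``model`` and ``example_inputs`` stand in for an
    eager-mode module and matching sample inputs)::

        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        qconfig_mapping = get_default_qconfig_mapping("x86")
        # prepare_fx inserts observers into the model according to the mapping
        prepared = prepare_fx(model, qconfig_mapping, example_inputs)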
    """
    # TODO: add assert for backend choices
    return _get_default_qconfig_mapping(False, backend, version)

def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping:
    """

    Return the default QConfigMapping for quantization aware training.



    Args:

      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be

         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]

      * ``version`` (int) : the version for the default qconfig mapping

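    Example (a minimal sketch; ``model`` and ``example_inputs`` stand in for a
    training-mode module and matching sample inputs)::

        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_qat_fx

        qconfig_mapping = get_default_qat_qconfig_mapping("x86")
        # prepare_qat_fx inserts fake-quantize modules according to the mapping
        prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)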
    """
    return _get_default_qconfig_mapping(True, backend, version)

def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
    """

    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`

    as the default QConfig.

    """
    default_qconfig = default_symmetric_qnnpack_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(False, "qnnpack", default_qconfig)

def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
    """

    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`

    as the default QConfig.

    """
    default_qconfig = default_symmetric_qnnpack_qat_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(True, "qnnpack", default_qconfig)

def _get_default_qconfig_mapping_with_default_qconfig(
    is_qat: bool,
    backend: str,
    default_qconfig: QConfig,
) -> QConfigMapping:
    """
    Return a QConfigMapping that uses the provided qconfig as the default QConfig.
    """
    if is_qat:
        qconfig_mapping = get_default_qat_qconfig_mapping(backend)
    else:
        qconfig_mapping = get_default_qconfig_mapping(backend)
    qconfig_mapping.set_global(default_qconfig)
    for pattern in qconfig_mapping.object_type_qconfigs.keys():
        if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
            qconfig_mapping.set_object_type(pattern, default_qconfig)
    return qconfig_mapping

_QCONFIG_STYLE_ORDER: List[str] = [
    "global_qconfig",
    "object_type_qconfigs",
    "module_name_regex_qconfigs",
    "module_name_qconfigs",
    "module_name_object_type_order_qconfigs",
]

class QConfigMapping:
    """

    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.



    The user can specify QConfigs using the following methods (in increasing match priority):



        ``set_global`` : sets the global (default) QConfig



        ``set_object_type`` : sets the QConfig for a given module type, function, or method name



        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string



        ``set_module_name`` : sets the QConfig for modules matching the given module name



        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination

        of the given module name, object type, and the index at which the module appears



    Example usage::



        qconfig_mapping = QConfigMapping()

            .set_global(global_qconfig)

            .set_object_type(torch.nn.Linear, qconfig1)

            .set_object_type(torch.nn.ReLU, qconfig1)

            .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)

            .set_module_name_regex("foo.*", qconfig2)

            .set_module_name("module1", qconfig1)

            .set_module_name("module2", qconfig2)

            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)



    """

    def __init__(self):
        # In increasing match priority:
        self.global_qconfig: QConfigAny = None
        self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = OrderedDict()
        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_object_type_order_qconfigs: OrderedDict[Tuple[str, Callable, int], QConfigAny] =\
            OrderedDict()

    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
        """

        Set the global (default) QConfig.

        """
        self.global_qconfig = global_qconfig
        return self

    def set_object_type(self, object_type: Union[Callable, str], qconfig: QConfigAny) -> QConfigMapping:
        """

        Set the QConfig for a given module type, function, or method name.

        If the QConfig for an existing object type was already set, the new QConfig will override the old one.

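        Example (an illustrative sketch; ``qconfig_mapping`` is an existing mapping
        and ``qconfig`` a placeholder)::

            qconfig_mapping.set_object_type(torch.nn.Linear, qconfig)             # module type
            qconfig_mapping.set_object_type(torch.nn.functional.linear, qconfig)  # function
            qconfig_mapping.set_object_type("reshape", qconfig)                   # method name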
        """
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(self, module_name_regex: str, qconfig: QConfigAny) -> QConfigMapping:
        """

        Set the QConfig for modules matching the given regex string.



        Regexes will be matched in the order in which they are registered through this method.

        Thus, the caller should register more specific patterns first, e.g.::



            qconfig_mapping = QConfigMapping()

                .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)

                .set_module_name_regex("foo.*bar.*", qconfig2)

                .set_module_name_regex("foo.*", qconfig3)



        In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2,

        and "foo.baz.relu" would match qconfig3.



        If the QConfig for an existing module name regex was already set, the new QConfig will override the

        old one while preserving the order in which the regexes were originally registered.

        """
        self.module_name_regex_qconfigs[module_name_regex] = qconfig
        return self

    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
        """

        Set the QConfig for modules matching the given module name.

        If the QConfig for an existing module name was already set, the new QConfig will override the old one.

        """
        self.module_name_qconfigs[module_name] = qconfig
        return self

    def set_module_name_object_type_order(
            self,
            module_name: str,
            object_type: Callable,
            index: int,
            qconfig: QConfigAny) -> QConfigMapping:
        """

        Set the QConfig for modules matching a combination of the given module name, object type,

        and the index at which the module appears.



        If the QConfig for an existing (module name, object type, index)  was already set, the new QConfig

        will override the old one.

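        Example (an illustrative sketch; "foo.bar" and ``qconfig`` are placeholders)::

            # Give a special QConfig only to the first (index 0) call of
            # torch.nn.functional.linear inside the submodule named "foo.bar".
            qconfig_mapping.set_module_name_object_type_order(
                "foo.bar", torch.nn.functional.linear, 0, qconfig)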
        """
        self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig
        return self

    def __repr__(self) -> str:
        output = self.__class__.__name__ + " ("
        for style_name in _QCONFIG_STYLE_ORDER:
            output += f"\n {style_name}"
            qconfigs = getattr(self, style_name)
            if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0:
                for key, qconfig in qconfigs.items():
                    output += f"\n  {key}: {qconfig}"
            else:
                output += f"\n  {qconfigs}"
        return output + "\n)"

    # TODO: remove this
    def to_dict(self) -> Dict[str, Any]:
        """

        Convert this ``QConfigMapping`` to a dictionary with the following keys:



            "" (for global QConfig)



            "object_type"



            "module_name_regex"



            "module_name"



            "module_name_object_type_order"



        The values of this dictionary are lists of tuples.

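        Example (a sketch of the expected shape; ``qconfig`` is a placeholder)::

            {
                "": qconfig,
                "object_type": [(torch.nn.Linear, qconfig), ...],
                "module_name_regex": [("foo.*", qconfig), ...],
                "module_name": [("module1", qconfig), ...],
                "module_name_object_type_order": [
                    ("foo.bar", torch.nn.functional.linear, 0, qconfig), ...
                ],
            }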
        """
        return {
            _GLOBAL_DICT_KEY: self.global_qconfig,
            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
                (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items()
            ],
        }

    # TODO: remove this
    @classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """

        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):



            "" (for global QConfig)



            "object_type"



            "module_name_regex"



            "module_name"



            "module_name_object_type_order"



        The values of this dictionary are expected to be lists of tuples.

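        Example (a minimal sketch; ``qconfig`` is a placeholder)::

            qconfig_mapping = QConfigMapping.from_dict({
                "": qconfig,
                "object_type": [(torch.nn.Linear, qconfig)],
                "module_name": [("module1", qconfig)],
            })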
        """
        conf = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            conf.set_object_type(object_type, qconfig)
        for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []):
            conf.set_module_name_regex(module_name_regex, qconfig)
        for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            conf.set_module_name(module_name, qconfig)
        for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
            conf.set_module_name_object_type_order(module_name, object_type, index, qconfig)
        return conf