import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process

# TODO remove this once BC is no longer required to avoid a SEV
from torch.ao.quantization.observer import (   # noqa: F401
    _is_activation_post_process as is_activation_post_process
)

__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}

def get_default_custom_config_dict():
    r"""Defines the default custom config dict.

    """
    return _DEFAULT_CUSTOM_CONFIG_DICT

def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                        this config when there is no specified config for current
                        module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to child if child is non-traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )

def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`

    attribute on each leaf module



    Args:

        module: input module

        qconfig_dict: dictionary that maps from name or type of submodule to

            quantization configuration, qconfig applies to all submodules of a

            given module unless qconfig for the submodules are specified (when

            the submodule already has qconfig attribute)

        prepare_custom_config_dict: dictionary for custom handling of modules

            see docs for :func:`~torch.ao.quantization.prepare_fx`



    Return:

        None, module is modified inplace with qconfig attached
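
    A minimal usage sketch (the toy model and the root-prefix entry below are
    illustrative, not part of the API):

    .. code-block:: python

       import torch.nn as nn
       from torch.ao.quantization import propagate_qconfig_, default_qconfig

       model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
       # '' is the prefix of the root module, so the qconfig is inherited by all children
       propagate_qconfig_(model, qconfig_dict={'': default_qconfig})
       assert hasattr(model[0], 'qconfig')  # attached in place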

    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)

def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output

    """
    return self.activation_post_process(output)

def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the output

    """
    return self.activation_post_process(input[0])

def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        handle = module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        handle = module.register_forward_hook(
            _observer_forward_hook, prepend=True
        )


def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.



    This function insert observer module to all leaf child module that

    has a valid qconfig attribute.



    Args:

        module: input module with qconfig attributes for all the leaf modules that we want to quantize

        qconfig_propagation_list: a list of quantizable modules that will have observers added to them

            if they are leaf nodes

        device: parent device, if any

        non_leaf_module_list: list of non-leaf modules we want to add observer



    Return:

        None, module is modified inplace with added observer modules and forward_hooks

    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert len(devices) <= 1, (
            f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register

        a pre or post hook that calls the module

        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        # TODO: remove the Dropout special case once the codebase is stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)):
            if needs_observation(child):
                assert hasattr(child, "activation_post_process"), (
                    f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                )
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module; the input is observed by QuantStub
    if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
       and type_before_parametrizations(module) in qconfig_propagation_list:
        insert_activation_post_process(module)

def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}

def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig

    Note that this function will modify the children of module inplace and it

    can return a new module which wraps the input module as well.



    Args:

        module: input module with qconfig attributes for all the leaf modules

        that we want to quantize



    Return:

        Either the inplace modified module with submodules wrapped in

        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which

        wraps the input module, the latter case only happens when the input

        module is a leaf module and we want to quantize it.
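
    A minimal usage sketch (the toy module below is illustrative):

    .. code-block:: python

       import torch.nn as nn
       from torch.ao.quantization import QuantWrapper, add_quant_dequant, default_qconfig

       m = nn.Linear(4, 4)
       m.qconfig = default_qconfig
       # a leaf module with a valid qconfig comes back wrapped in QuantWrapper
       wrapped = add_quant_dequant(m)
       assert isinstance(wrapped, QuantWrapper)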

    """
    if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module

def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.



    Quantization configuration should be assigned preemptively

    to individual submodules in `.qconfig` attribute.



    The model will be attached with observer or fake quant modules, and qconfig

    will be propagated.



    Args:

        `model`: input model to be modified in-place

        `inplace`: carry out model transformations in-place, the original module is mutated

        `allow_list`: list of quantizable modules

        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer

        `prepare_custom_config_dict`: customization configuration dictionary for prepare function



    .. code-block:: python



       # Example of prepare_custom_config_dict:

       prepare_custom_config_dict = {

           # user will manually define the corresponding observed

           # module class which has a from_float class method that converts

           # float custom module to observed custom module

           "float_to_observed_custom_module_class": {

               CustomModule: ObservedCustomModule

           }

        }
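
    A minimal end-to-end sketch (the toy model and calibration input below are
    illustrative; any float model with a `.qconfig` assigned can be prepared the
    same way):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import QuantStub, DeQuantStub, default_qconfig, prepare

       model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).eval()
       model.qconfig = default_qconfig
       prepared = prepare(model)      # observers are attached to quantizable leaf modules
       prepared(torch.randn(2, 4))    # calibration pass records activation ranges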



    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    _add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model

def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
       _is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)

# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be

    propagated.



    Args:

        module: module to be cleaned up

    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)

def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.



    First it will prepare the model for calibration, then it calls

    `run_fn` which will run the calibration step, after that we will

    convert the model to a quantized model.



    Args:

        model: input float model

        run_fn: a calibration function for calibrating the prepared model

        run_args: positional arguments for `run_fn`

        inplace: carry out model transformations in-place, the original module is mutated

        mapping: correspondence between original module types and quantized counterparts



    Return:

        Quantized model.
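
    A minimal end-to-end sketch (the toy model, qconfig choice and calibration
    data are illustrative; running it assumes a backend with quantized kernels
    such as fbgemm or qnnpack is available):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import QuantStub, DeQuantStub, default_qconfig, quantize

       float_model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub())
       float_model.qconfig = default_qconfig

       def calibrate(model, data):
           # run representative inputs through the prepared model
           for x in data:
               model(x)

       data = [torch.randn(2, 4) for _ in range(8)]
       quantized_model = quantize(float_model, calibrate, [data])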

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model

def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.



    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.



    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization

    by default is performed for layers with large weights size - i.e. Linear and RNN variants.



    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.

    If `qconfig` is provided, the `dtype` argument is ignored.



    Args:

        model: input model

        qconfig_spec: Either:



            - A dictionary that maps from name or type of submodule to quantization

              configuration, qconfig applies to all submodules of a given

              module unless qconfig for the submodules are specified (when the

              submodule already has qconfig attribute). Entries in the dictionary

              need to be QConfig instances.



            - A set of types and/or submodule names to apply dynamic quantization to,

              in which case the `dtype` argument is used to specify the bit-width



        inplace: carry out model transformations in-place, the original module is mutated

        mapping: maps type of a submodule to a type of corresponding dynamically quantized version

            with which the submodule needs to be replaced
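
    A minimal usage sketch (the toy model below is illustrative):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import quantize_dynamic

       float_model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
       # replace every nn.Linear with its dynamically quantized counterpart (qint8 weights)
       dq_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
       out = dq_model(torch.randn(2, 4))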



    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear : default_dynamic_qconfig,
                nn.LSTM : default_dynamic_qconfig,
                nn.GRU : default_dynamic_qconfig,
                nn.LSTMCell : default_dynamic_qconfig,
                nn.RNNCell : default_dynamic_qconfig,
                nn.GRUCell : default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear : float16_dynamic_qconfig,
                nn.LSTM : float16_dynamic_qconfig,
                nn.GRU : float16_dynamic_qconfig,
                nn.LSTMCell : float16_dynamic_qconfig,
                nn.RNNCell : float16_dynamic_qconfig,
                nn.GRUCell : float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
                nn.Embedding : float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please")
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model

def prepare_qat(model, mapping=None, inplace=False):
    r"""

    Prepares a copy of the model for quantization calibration or

    quantization-aware training and converts it to quantized version.



    Quantization configuration should be assigned preemptively

    to individual submodules in `.qconfig` attribute.



    Args:

        model: input model to be modified in-place

        mapping: dictionary that maps float modules to quantized modules to be

                 replaced.

        inplace: carry out model transformations in-place, the original module

                 is mutated
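
    A minimal usage sketch (the toy model is illustrative; ``"fbgemm"`` is just
    one backend choice for the QAT qconfig):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import (
           QuantStub, DeQuantStub, get_default_qat_qconfig, prepare_qat, convert,
       )

       model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).train()
       model.qconfig = get_default_qat_qconfig("fbgemm")
       prepare_qat(model, inplace=True)   # fake-quant modules are inserted
       # ... run the usual training loop on the fake-quantized model here ...
       model.eval()
       quantized = convert(model)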

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model

def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model



    Args:

        model: input model

        run_fn: a function for evaluating the prepared model, can be a

                function that simply runs the prepared model or a training

                loop

        run_args: positional arguments for `run_fn`



    Return:

        Quantized model.
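
    A minimal usage sketch (the toy model and `run_fn` below are illustrative; a
    real `run_fn` would execute a training loop):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qat_qconfig, quantize_qat

       model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub())
       model.qconfig = get_default_qat_qconfig("fbgemm")

       def run_fn(model, data):
           # a real run_fn would also compute a loss and step an optimizer
           for x in data:
               model(x)

       data = [torch.randn(2, 4) for _ in range(8)]
       quantized_model = quantize_qat(model, run_fn, [data])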

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model

def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`

    by calling `from_float` method on the target module class. And remove qconfig at the

    end if remove_qconfig is set to True.



    Args:

        `module`: prepared and calibrated module

        `mapping`: a dictionary that maps from source module type to target

                   module type, can be overwritten to allow swapping user defined

                   Modules

        `inplace`: carry out model transformations in-place, the original module

                   is mutated

        `convert_custom_config_dict`: custom configuration dictionary for convert function



    .. code-block:: python



       # Example of convert_custom_config_dict:

       convert_custom_config_dict = {

           # user will manually define the corresponding quantized

           # module class which has a from_observed class method that converts

           # observed custom module to quantized custom module

           "observed_to_quantized_custom_module_class": {

               ObservedCustomModule: QuantizedCustomModule

           }

       }
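
    A minimal usage sketch (continues the eager-mode flow; the toy model is
    illustrative and running it assumes a backend with quantized kernels, e.g.
    fbgemm or qnnpack, is available):

    .. code-block:: python

       import torch
       import torch.nn as nn
       from torch.ao.quantization import QuantStub, DeQuantStub, default_qconfig, prepare, convert

       model = nn.Sequential(QuantStub(), nn.Linear(4, 4), DeQuantStub()).eval()
       model.qconfig = default_qconfig
       prepared = prepare(model)
       prepared(torch.randn(2, 4))      # calibration
       quantized = convert(prepared)    # swapped to quantized modules via the default static mapping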



    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module, mapping, inplace=True, is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict)
    if remove_qconfig:
        _remove_qconfig(module)
    return module

def _convert(
        module, mapping=None, inplace=False,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`

    by calling `from_float` method on the target module class



    Args:

        module: input module

        mapping: a dictionary that maps from source module type to target

                 module type, can be overwritten to allow swapping user defined

                 Modules

        inplace: carry out model transformations in-place, the original module

                 is mutated

        is_reference: a flag to enable quantized reference module



    """
    if mapping is None:
        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
           type_before_parametrizations(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     is_reference, convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module

def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an

    `observer` attached.



    Args:

        mod: input module

        mapping: a dictionary that maps from nn module to nnq module



    Return:

        The corresponding quantized module of `mod`

    """
    new_mod = mod
    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert len(devices) <= 1, (
                f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod

def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.

    This is mainly used for quantization accuracy debug

    Args:

        mod: the top module we want to save all observers

        prefix: the prefix for the current module

        target_dict: the dictionary used to save all the observers

    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)