Spaces:
Running
Running
File size: 32,727 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 |
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
from .fx.tracer import QuantizationTracer
from .fx.tracer import ( # noqa: F401
Scope,
ScopeContextManager
)
from .fx.fuse import fuse # noqa: F401
from .fx.prepare import prepare # noqa: F401
from .fx.convert import convert
from .backend_config import ( # noqa: F401
BackendConfig,
get_tensorrt_backend_config,
)
from .fx.graph_module import ObservedGraphModule # noqa: F401
from .fx.custom_config import (
ConvertCustomConfig,
FuseCustomConfig,
PrepareCustomConfig,
)
from .fx.utils import get_custom_module_class_keys # noqa: F401
from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping
def attach_preserved_attrs_to_model(
    model: Union[GraphModule, torch.nn.Module],
    preserved_attrs: Dict[str, Any],
) -> None:
    """Record user-preserved attributes on ``model`` and re-expose them.

    A shallow copy of ``preserved_attrs`` is stored in ``model.meta`` under
    ``_USER_PRESERVED_ATTRIBUTES_KEY`` (so that the attributes can be preserved
    during deepcopy), and every entry is also set as a plain attribute on
    ``model``.

    Args:
        model: object exposing a mutable ``meta`` mapping.
        preserved_attrs: attribute name -> value mapping to attach; the
            mapping itself is not modified.
    """
    stored_attrs = copy.copy(preserved_attrs)
    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = stored_attrs  # type: ignore[operator, index, assignment]
    # Mirror each entry onto the model so callers can keep using `model.attr`
    # exactly as they did before fx graph mode quantization.
    for name in stored_attrs:
        setattr(model, name, stored_attrs[name])
def _check_is_graph_module(model: torch.nn.Module) -> None:
    """Validate that ``model`` is a ``torch.fx.GraphModule``.

    Raises:
        ValueError: if ``model`` is not a ``GraphModule``; the message includes
            the actual type of ``model``.
    """
    if isinstance(model, GraphModule):
        return
    raise ValueError(
        f"input model must be a GraphModule, Got type:{type(model)!s}"
        " Please make sure to follow the tutorials."
    )
def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
    """Give every node of ``model.graph`` a ``meta`` attribute.

    ``meta`` holds per-node information such as dtype and shape of the node's
    output. It exists when the program is captured by make_fx (used in the
    quantize_pt2e flow) but may be missing when the program is captured by
    torch.fx symbolic tracing. Adding an empty dict here avoids having to
    check for the attribute everywhere else.

    Nodes that already carry ``meta`` are left untouched.
    """
    for fx_node in model.graph.nodes:
        if hasattr(fx_node, "meta"):
            continue
        fx_node.meta = {}
def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r"""Recursively replace FloatFunctional children with FXFloatFunctional.

    Direct children that are ``FloatFunctional`` instances are replaced by a
    freshly constructed ``FXFloatFunctional``; every other child is visited
    recursively. ``model`` is modified in place.
    """
    float_functional_cls = torch.ao.nn.quantized.FloatFunctional
    pending = []
    for child_name, child in model.named_children():
        if not isinstance(child, float_functional_cls):
            _swap_ff_with_fxff(child)
            continue
        pending.append(child_name)

    # Replacement happens after iteration so `_modules` is not mutated while
    # `named_children()` is walking it. The entry is removed and re-inserted,
    # which places the replacement at the end of the `_modules` ordering.
    for child_name in pending:
        model._modules.pop(child_name)
        model._modules[child_name] = torch.ao.nn.quantized.FXFloatFunctional()
def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Internal helper that fuses modules in preparation for quantization.

    Args:
        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
        is_qat: forwarded unchanged to ``fuse``
        fuse_custom_config: forwarded unchanged to ``fuse``
        backend_config: forwarded unchanged to ``fuse``

    Returns:
        Whatever ``fuse`` returns for the validated ``model``.

    Raises:
        ValueError: if ``model`` is not a ``GraphModule``.
    """
    _check_is_graph_module(model)
    fused_module = fuse(
        model, is_qat, fuse_custom_config, backend_config)  # type: ignore[operator]
    return fused_module
def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    r""" Internal helper function for prepare_fx

    Steps, in order: normalize the configs, swap FloatFunctional modules for
    their FX counterparts (in place on `model`), capture the preserved
    attributes, symbolically trace with `QuantizationTracer`, fuse, run
    `prepare`, and finally re-attach the preserved attributes to the result.

    Args:
      `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
      see docs for :func:`~torch.ao.quantization.prepare_fx`
      `is_qat`: forwarded to the fusion and prepare steps
      `example_inputs`: forwarded to the prepare step
      `backend_config`: forwarded to the fusion and prepare steps
      `is_standalone_module`: a boolean flag indicates whether we are
      quantizing a standalone module or not, a standalone module
      is a submodule of the parent module that is not inlined in the
      forward graph of the parent module,
      the way we quantize standalone module is described in:
      :func:`~torch.ao.quantization._prepare_standalone_module_fx`

    Return:
      The prepared GraphModule, with the preserved attributes attached.
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(prepare_custom_config, dict):
        # stacklevel=3 attributes the warning to the caller of the public
        # wrapper (prepare_fx / prepare_qat_fx) instead of this helper.
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.",
            stacklevel=3,
        )
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(
        prepare_custom_config, is_standalone_module
    )

    # Capture the attribute values from the original model before tracing;
    # names the model does not have are silently skipped.
    preserved_attr_names = prepare_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }

    # symbolically trace the model
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
    graph_module = GraphModule(model, tracer.trace(model))
    _attach_meta_to_node_if_not_exist(graph_module)

    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
        prepare_custom_config.preserved_attributes
    )
    graph_module = _fuse_fx(
        graph_module,
        is_qat,
        fuse_custom_config,
        backend_config,
    )
    prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )  # type: ignore[operator]

    attach_preserved_attrs_to_model(prepared, preserved_attrs)
    return prepared
def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
    parent module.

    standalone_module means it is a submodule that is not inlined in the parent module,
    and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.

    This function only forwards its arguments to `_prepare_fx` with
    `is_standalone_module=True`; the attributes listed below are produced by
    the downstream prepare step.

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes in
          model.meta:

            * `standalone_module_input_quantized_idxs(List[Int])`: a list of
              indexes for the graph input that is expected to be quantized,
              same as input_quantized_idxs configuration provided
              for the standalone module
            * `standalone_module_output_quantized_idxs(List[Int])`: a list of
              indexes for the graph output that is quantized,
              same as output_quantized_idxs configuration provided
              for the standalone module
    """
    prepared_standalone = _prepare_fx(
        model,
        qconfig_mapping,
        is_qat,
        example_inputs,
        prepare_custom_config=prepare_custom_config,
        backend_config=backend_config,
        is_standalone_module=True,
    )
    return prepared_standalone
def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details.
            Passing a dict is deprecated; it is converted with
            `FuseCustomConfig.from_dict` after emitting a warning.
        * `backend_config` (BackendConfig): forwarded to the fusion step

    Return:
        The fused GraphModule; attributes named in
        `fuse_custom_config.preserved_attributes` that exist on `model` are
        re-attached to it.

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        # stacklevel=2 attributes the warning to the caller of fuse_fx.
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")

    # Capture the attribute values from the original model before tracing;
    # names the model does not have are silently skipped.
    preserved_attr_names = fuse_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }

    graph_module = torch.fx.symbolic_trace(model)
    _attach_meta_to_node_if_not_exist(graph_module)
    # fuse_fx always fuses in non-QAT mode (is_qat=False).
    graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)

    attach_preserved_attrs_to_model(graph_module, preserved_attrs)
    return graph_module
def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Prepare a model for post training quantization

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model

      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
         quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
         for more details

      * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
         Tuple of positional args (keyword args can be passed as positional args as well)

      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
          See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details

      * `_equalization_config`: config for specifying how to perform equalization on the model

      * `backend_config` (BackendConfig): config that specifies how operators are quantized
         in a backend, this includes how the operators are observed,
         supported fusion patterns, how quantize/dequantize ops are
         inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
      A GraphModule with observer (configured by qconfig_mapping), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # `activation` and `weight` are constructors of observer module

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways.
        # e.g. set the global qconfig, which means we will use the same qconfig for
        # all operators in the model, this can be overwritten by other settings
        # qconfig_mapping = QConfigMapping().set_global(qconfig)
        # e.g. quantize the linear submodule with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
        # e.g. quantize all nn.Linear modules with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping`
        # argument

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 5),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config. If the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert observer modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        #
        # Example:
        # in qconfig_mapping, user sets linear module to be quantized with quint8 for
        # activation and qint8 for weight:
        # qconfig = torch.ao.quantization.QConfig(
        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # Note: current qconfig api does not support setting output observer, but
        # we may extend this to support these more fine grained control in the
        # future
        #
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # in backend config, linear module also supports in this configuration:
        # weighted_int8_dtype_config = DTypeConfig(
        #   input_dtype=torch.quint8,
        #   output_dtype=torch.quint8,
        #   weight_dtype=torch.qint8,
        #   bias_dtype=torch.float)

        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        #    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        #    .add_dtype_config(weighted_int8_dtype_config) \
        #    ...

        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
        # `prepare_fx` will check that the setting requested by user in qconfig_mapping
        # is supported by the backend_config and insert observers and fake quant modules
        # in the model
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    # Post training quantization: delegate with is_qat disabled.
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat=False,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
    )
def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Prepare a model for quantization aware training

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model
      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_qat_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().train()
        # (optional, but preferred) load the weights from pretrained model
        # float_model.load_weights(...)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
        # `activation` and `weight` are constructors of observer module

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways, please take a look at
        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
        # to configure this

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 5),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config, if the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
        # how qconfig_mapping interacts with backend_config
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training
        train_loop(prepared_model, train_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    # Quantization aware training: delegate with is_qat enabled.
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat=True,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        backend_config=backend_config,
    )
def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> GraphModule:
    """ Internal helper shared by the convert entry points in this file.

    Normalizes `convert_custom_config`, validates that `graph_module` is a
    GraphModule, runs `convert`, and re-attaches the preserved attributes to
    the result.

    Args:
      `graph_module`: model to convert; must be a GraphModule
      `is_reference`, `qconfig_mapping`, `backend_config`, `is_decomposed`:
      forwarded unchanged to `convert`
      `convert_custom_config`: a ConvertCustomConfig, or (deprecated) a dict
      that is converted with `ConvertCustomConfig.from_dict`
      `is_standalone_module`: see docs in
      :func:`~torch.ao.quantization._prepare_standalone_module_fx`
      `_remove_qconfig`: forwarded to `convert` as `_remove_qconfig_flag`

    Return:
      The converted GraphModule, with the preserved attributes attached.

    Raises:
      ValueError: if `graph_module` is not a GraphModule.
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        # stacklevel=3 attributes the warning to the caller of the wrapper
        # that invoked this helper instead of to this helper itself.
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            stacklevel=3,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    _check_is_graph_module(graph_module)

    # Capture the attribute values before conversion; names the module does
    # not have are silently skipped.
    preserved_attr_names = convert_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(graph_module, attr)
        for attr in preserved_attr_names
        if hasattr(graph_module, attr)
    }

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=is_decomposed,
    )

    attach_preserved_attrs_to_model(quantized, preserved_attrs)
    return quantized
def convert_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.

           The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
           with the same values or `None`. Additional keys can be specified with values set to `None`.

          For each entry whose value is set to None, we skip quantizing that entry in the model::

            qconfig_mapping = (
                QConfigMapping()
                .set_global(qconfig_from_prepare)
                .set_object_type(torch.nn.functional.add, None)  # skip quantizing torch.nn.functional.add
                .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
                .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
            )

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend, this includes quantization
            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
            observer placement for each operators and fused operators.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # convert_fx converts a calibrated/trained model to a quantized model for the
        # target hardware, this includes converting the model first to a reference
        # quantized model, and then lower the reference quantized model to a backend
        # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and
        # they share the same set of quantized operators, so we are using the same
        # lowering procedure
        #
        # backend_config defines the corresponding reference quantized module for
        # the weighted modules in the model, e.g. nn.Linear
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        quantized_model = convert_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    # Non-reference conversion; every user-supplied option is passed through.
    forwarded_options = dict(
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )
    return _convert_fx(graph_module, is_reference=False, **forwarded_options)
def convert_to_reference_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Convert a calibrated or trained model to a reference quantized model,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
    reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization, it can be further lowered to run on the target
    hardware, like accelerators

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = convert_to_reference_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
    # Same pass-through as convert_fx, but requesting the reference form.
    forwarded_options = dict(
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )
    return _convert_fx(graph_module, is_reference=True, **forwarded_options)
def _convert_to_reference_decomposed_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Convert a calibrated or trained model to a reference quantized model, with
    decomposed representation for quantized Tensor
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
    reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization, it can be further lowered to run on the target
    hardware, like accelerators

    Note: this is not public API

    Unlike :func:`~torch.ao.quantization.quantize_fx.convert_to_reference_fx`,
    this function takes no `_remove_qconfig` argument; it always forwards
    `_remove_qconfig=False` together with `is_decomposed=True`.

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
    decomposed_reference = _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=False,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=True,
    )
    return decomposed_reference
def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    to a quantized model

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    converted_standalone = _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module=True,
    )
    return converted_standalone
|