import torch
from torch.fx import GraphModule
from torch.fx import Node

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .quantize_fx import _convert_to_reference_decomposed_fx

from torch.ao.quantization.quantizer import (  # noqa: F401
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
)

from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch._inductor.constant_folding import constant_fold

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]

def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for post training quantization

    Args:
      * `model` (torch.fx.GraphModule): a model captured by the `torch.export` API.
        In the short term we are using `torch._export.capture_pre_autograd_graph`;
        in the long term we'll migrate to some `torch.export` API.
      * `quantizer`: A backend-specific quantizer that conveys how the user wants the
        model to be quantized. A tutorial on how to write a quantizer can be found here:
        https://pytorch.org/tutorials/prototype/pt2e_quantizer.html

    Return:
        A GraphModule with observers (based on quantizer annotation), ready for calibration

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future,
        # but the captured result should mostly stay the same
        example_inputs = (torch.randn(1, 5),)
        m = capture_pre_autograd_graph(float_model, *example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # backend developers will write their own Quantizer and expose methods
        # to allow users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_pt2e(m, quantizer)

        # run calibration
        # calibrate(m, sample_inference_data)
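
        # once calibrated, the prepared model can be lowered to a quantized
        # model with `convert_pt2e` (defined below in this module), e.g.:
        # quantized_model = convert_pt2e(m)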
""" | |
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = quantizer.transform_for_annotation(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model

def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for quantization aware training

    Args:
      * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
      * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`

    Return:
        A GraphModule with fake quant modules (based on quantizer annotation), ready for
        quantization aware training

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model
        float_model = M().eval()

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future,
        # but the captured result should mostly stay the same
        example_inputs = (torch.randn(1, 5),)
        m = capture_pre_autograd_graph(float_model, *example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # backend developers will write their own Quantizer and expose methods
        # to allow users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_qat_pt2e(m, quantizer)

        # run quantization aware training (train_data is a user-provided data loader)
        train_loop(m, train_data)
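
        # after training, the model can be lowered to a quantized model with
        # `convert_pt2e` (defined below in this module), e.g.:
        # quantized_model = convert_pt2e(m)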
""" | |
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    model = quantizer.transform_for_annotation(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotation to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model

_QUANT_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

def _quant_node_constraint(n: Node) -> bool:
    """If there are any pure ops between get_attr and a quantize op, they will be
    constant propagated, e.g.:

        get_attr(weight) -> transpose -> quantize -> dequantize*

    (Note: the dequantize op is not going to be constant propagated)

    This filter is added because we don't want to constant fold things that are
    not related to quantization.
    """
    return n.op == "call_function" and n.target in _QUANT_OPS
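
# A rough sketch of the effect of this constraint: when `convert_pt2e` runs
# `constant_fold(model, _quant_node_constraint)` (with `fold_quantize=True`),
# chains that start from a constant and end in one of the `_QUANT_OPS` above,
# such as
#   get_attr(weight) -> transpose -> quantize_per_tensor
# are precomputed into a single quantized constant, leaving only the
# dequantize op in the runtime graph.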

def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
    fold_quantize: bool = True,
) -> GraphModule:
    """Convert a calibrated/trained model to a quantized model

    Args:
      * `model` (torch.fx.GraphModule): calibrated/trained model
      * `use_reference_representation` (bool): boolean flag to indicate whether to produce the reference representation or not
      * `fold_quantize` (bool): boolean flag for whether to fold the quantize op or not

    Returns:
        quantized model, either in q/dq representation or reference representation

    Example::

        # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
        # `convert_pt2e` produces a quantized model that represents quantized computation with
        # quantize/dequantize ops and fp32 ops by default.
        # Please refer to
        # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
        # for a detailed explanation of the output quantized model
        quantized_model = convert_pt2e(prepared_model)
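
        # a sketch of the non-default flags: `use_reference_representation=True`
        # rewrites the q/dq ops into the reference integer representation, and
        # `fold_quantize=False` keeps fp32 weights with quantize ops in the
        # graph instead of precomputing quantized weight constants, e.g.:
        # quantized_model = convert_pt2e(prepared_model, use_reference_representation=True)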
""" # flake8: noqa | |
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
    if not isinstance(use_reference_representation, bool):
        raise ValueError(
            "Unexpected argument type for `use_reference_representation`, "
            f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e")
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)

    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

    if fold_quantize:
        constant_fold(model, _quant_node_constraint)

    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model