Spaces:
Running
Running
from typing import List | |
from torch.ao.quantization.pt2e.utils import _is_sym_size_node | |
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation | |
from torch.fx import Node | |
def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): | |
quantization_annotation = node.meta.get( | |
"quantization_annotation", QuantizationAnnotation() | |
) | |
if quantization_annotation.input_qspec_map is None: | |
quantization_annotation.input_qspec_map = {} | |
quantization_annotation.input_qspec_map[input_node] = qspec | |
node.meta["quantization_annotation"] = quantization_annotation | |
def _annotate_output_qspec(node: Node, qspec): | |
quantization_annotation = node.meta.get( | |
"quantization_annotation", QuantizationAnnotation() | |
) | |
quantization_annotation.output_qspec = qspec | |
node.meta["quantization_annotation"] = quantization_annotation | |
def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): | |
""" | |
This utility is used to handle cases when dynami_shape=True tracing leads | |
to symint nodes in the pattern of linear module. In those cases, we need to | |
distinguish between the nodes that are in input for just extracting value of | |
some dimentions (and symint nodes) vs. the one that is activation. | |
For example: | |
graph(x, y, weight): | |
size_0 = torch.ops.aten.sym_size([x], [0]) | |
size_1 = torch.ops.aten.sym_size([y], [1]) | |
view_size = size_0 * size_1 | |
size_3 = torch.ops.aten.sym_size([x], [2]) | |
vie_out = torch.ops.aten.view(x, [view_size, size_3]) | |
return mm(view_out, weight) | |
In the example above y node is not actual input. It exist only to extract size_1 | |
""" | |
if _is_sym_size_node(node): | |
return True | |
return all( | |
((user not in partition_nodes) or _is_sym_size_node(user)) | |
for user in node.users | |
) | |