File size: 1,922 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List

from torch.ao.quantization.pt2e.utils import _is_sym_size_node

from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.fx import Node


def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    if quantization_annotation.input_qspec_map is None:
        quantization_annotation.input_qspec_map = {}
    quantization_annotation.input_qspec_map[input_node] = qspec
    node.meta["quantization_annotation"] = quantization_annotation


def _annotate_output_qspec(node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    quantization_annotation.output_qspec = qspec
    node.meta["quantization_annotation"] = quantization_annotation


def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
    """

    This utility is used to handle cases when dynami_shape=True tracing leads

    to symint nodes in the pattern of linear module. In those cases, we need to

    distinguish between the nodes that are in input for just extracting value of

    some dimentions (and symint nodes) vs. the one that is activation.

    For example:

    graph(x, y, weight):

       size_0 = torch.ops.aten.sym_size([x], [0])

       size_1 = torch.ops.aten.sym_size([y], [1])

       view_size = size_0 * size_1

       size_3 = torch.ops.aten.sym_size([x], [2])

       vie_out = torch.ops.aten.view(x, [view_size, size_3])

       return mm(view_out, weight)

    In the example above y node is not actual input. It exist only to extract size_1

    """
    if _is_sym_size_node(node):
        return True

    return all(
        ((user not in partition_nodes) or _is_sym_size_node(user))
        for user in node.users
    )