import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.prepare import (
    _insert_obs_or_fq,
    _save_state,
    _is_activation_post_process_node,
    _create_obs_or_fq_from_qspec,
)
from torch.fx import (
    GraphModule,
    Graph,
    Node,
)
from torch.fx.node import Argument
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
)
from torch.ao.quantization import ObserverOrFakeQuantize

# TODO: make pt2e folder private?
__all__ = [
    "prepare",
]

def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
    """Find the root node for the sharing tree

    Args:
        edge_or_node: edge/node that we want to find the root for
        shared_with_map: each edge/node points to its parent, the root node points to itself

    Returns:
        root edge/node
    """
    parent = shared_with_map[edge_or_node]
    if parent == edge_or_node:
        return edge_or_node
    root = _find_root_edge_or_node(parent, shared_with_map)
    # path compression
    shared_with_map[edge_or_node] = root
    return root

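# Illustrative sketch (not part of the module): with string stand-ins for
# EdgeOrNode keys,
#   shared_with_map = {"a": "a", "b": "a", "c": "b"}
# _find_root_edge_or_node("c", shared_with_map) returns "a", and path
# compression rewrites the map so that shared_with_map["c"] == "a" afterwards,
# making later lookups for "c" a single hop.
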
def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge the subtree for `child` with `parent`, the order is important here
    """
    root_parent = _find_root_edge_or_node(parent, shared_with_map)
    root_child = _find_root_edge_or_node(child, shared_with_map)
    # union the two trees by pointing the root of child to the root of parent
    shared_with_map[root_child] = root_parent

def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Update the `shared_with_map` based on the qspec. This applies the `SharedQuantizationSpec`
    configuration and establishes the relationship between `edge_or_node` and the edge/node that it
    is pointing to; we'll use this information in the end to get the group id
    """
    if isinstance(qspec, SharedQuantizationSpec):
        parent = qspec.edge_or_node
        # we point from edge_or_node to the node that it is sharing with, e.g.
        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
        _union(parent, child, shared_with_map)

def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non-SharedQuantizationSpec).
    If qspec is a SharedQuantizationSpec:
        (1) find the root edge or node for the node that the qspec points to
        (2) recursively find the root qspec based on the qspec for the root node
    """
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        root = _find_root_edge_or_node(sharing_with, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec

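# Illustrative sketch (not part of the module): if `a` is annotated with
# SharedQuantizationSpec(b), `b` with SharedQuantizationSpec(c), and `c` with a
# concrete int8 qspec, then unwrapping the qspec for `a` follows a -> b -> c
# and returns the concrete int8 qspec.
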
def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "dtype") and
        hasattr(qspec_b, "dtype") and
        qspec_a.dtype == qspec_b.dtype
    )


def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "is_dynamic") and
        hasattr(qspec_b, "is_dynamic") and
        qspec_a.is_dynamic == qspec_b.is_dynamic
    )

def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
    """
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
    for n in model.graph.nodes:
        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
            qa = n.meta["quantization_annotation"]
            for input_to_n, qspec in qa.input_qspec_map.items():
                input_edge = (input_to_n, n)
                edge_or_node_to_qspec[input_edge] = qspec
            if qa.output_qspec is not None:
                output_node = n
                qspec = qa.output_qspec
                edge_or_node_to_qspec[output_node] = qspec
    return edge_or_node_to_qspec

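# Illustrative sketch (not part of the module): for a graph `x -> conv` where
# a quantizer annotated conv's input and output, the returned map contains one
# edge key and one node key, roughly:
#   {(x_node, conv_node): input_qspec, conv_node: output_qspec}
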
def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
    """Union input edge with another edge or node. Used in implicit sharing to point the current input
    edge to other user edges of the producer node, or to the output of the producer node, since these
    all refer to the same Tensor
    """
    root_qspec = None
    if edge_or_node in edge_or_node_to_qspec:
        qspec = edge_or_node_to_qspec[edge_or_node]
        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    # TODO: add assertions for types of root qspecs
    if (
        root_qspec is not None and
        _has_same_dtype(root_qspec, input_edge_root_qspec) and
        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
    ):
        # the input arg to the node should reuse the existing output observer for arg
        # since dtype is the same (we may want to extend this to be a more strict check
        # in the future)
        # so we point from `input_edge` to `arg` (output of the argument)
        _union(edge_or_node, input_edge, shared_with_map)

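# Illustrative sketch (not part of the module): conversely, if op1's output
# root qspec is dynamic (or has a different dtype) while the (op1, op2) input
# edge wants static int8, the dtype/is_dynamic checks above fail, no union
# happens, and the edge gets its own observer instead of reusing op1's output
# observer.
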
def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations;
    edges/nodes with the same group ID should use the same observer/fake_quant instance

    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
    There is another implicit sharing that's built into the quantization flow, when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and the input of (op1 -> op2) since these are the same Tensor.

    Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes that
        belong to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
         op1 /   op3 /

        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything ends up in the same group because cat1 and (cat1, cat2) are implicitly shared, which
        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
    """
    # means the observer of the key should be shared with the observer of the value; by default each
    # edge/node is shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)

            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here, we first share with other users and then share with the
                # previous output because the reverse order could cause a circular dependency
                # e.g. node1 -> node2
                #          \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1,
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point (node1, node2) to its other user (node1, node3),
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
                # have a circular dependency
                # the following order works around this issue, but it does not allow arbitrary configurations
                # of sharing, so it might break in a different case in the future; when it breaks,
                # the quantizer writer can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map
                    )

                # sharing with output of producer node
                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)

            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we have the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id

def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the map from EdgeOrNode to observer/fake_quant instances.
    Makes sure that edges/nodes with the same group_id get the same observer or fake_quant
    instance
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id not in group_id_to_obs_or_fq:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
    return obs_or_fq_map

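# Illustrative sketch (not part of the module): continuing the cat example in
# _get_edge_or_node_to_group_id, all eight edges/nodes share one group id, so
# a single observer (or FakeQuantize, when is_qat=True) is created once and
# every key in obs_or_fq_map points at that same instance.
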
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge,
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node
    existing_obs_node = None
    # skip inserting a new observer if the same observer instance was already inserted for another user
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #              \ -> conv3
    #
    # instead of inserting a new observer we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                               \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    return new_arg

def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node
    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    # map from old arg to new arg, used for updating the numeric debug handle map
    remap = {}
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        new_args.append(new_arg)
        remap[arg] = new_arg

    if "numeric_debug_handle" in node.meta:

        def remap_fn(x):
            return remap.get(x, x)

        numeric_debug_handle = node.meta["numeric_debug_handle"]
        node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}

    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
    # that persist in the exported graph. This is just a workaround for these.
    assert (
        node.target == torch.ops.aten.clone.default or
        node.target == torch.ops.aten.zeros_like.default or
        len(node.kwargs) == 0
    ), "expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, inplace
    node.args = tuple(new_args)

def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    if node in obs_or_fq_map:
        output_act_obs_or_fq = obs_or_fq_map[node]
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    return None

def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    this_node_quantization_annotation = node.meta["quantization_annotation"] if "quantization_annotation" in node.meta else None
    if this_node_quantization_annotation is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
    if not output_is_a_tensor:
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

    if maybe_output_obs_node is None:
        return
    # Update users of the original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    #   cur_node -> obs
    #
    # to
    #
    #                  next_node
    #                 /
    #   cur_node -> obs
    #
    # We need to save orig users before updating uses because
    # the list of users will change as we update uses
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)

def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # At a high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance.
    # All edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we just query this map to get the correct
    # observer_or_fake_quant instance.
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set()  # observed_node_names
    )
    return model

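# Typical usage (a hedged sketch, not part of this module): `prepare` is
# normally reached through torch.ao.quantization.quantize_pt2e.prepare_pt2e,
# which runs a Quantizer's annotation pass on a pre-autograd exported graph
# (obtained via torch.export or capture_pre_autograd_graph, depending on the
# PyTorch version) and then calls prepare():
#
#   from torch.ao.quantization.quantize_pt2e import prepare_pt2e
#   prepared_model = prepare_pt2e(exported_model, quantizer)
#   prepared_model(*example_inputs)  # calibrate observers (PTQ) or train (QAT)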