import enum
import operator

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq

toq = torch.ops.quantized
from typing import Tuple, Callable, Dict, Set, List, Optional, Union

from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.observer import _is_activation_post_process

from .ns_types import NSNodeTargetType, NSResultsType


# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We will likely need some kind of dtype
    # propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """
    Returns the expected dtype of the first input and of the output of
    `node`, as an (input_type, output_type) tuple, using the lookup tables
    in `node_type_to_io_type_map` and recursing into preceding nodes when
    a node passes its input dtype through.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # `to` is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of `to` and return the correct output type.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
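

# Illustrative sketch (not part of the original module): infers the I/O dtypes
# for the `call_module` node of a tiny float model. It assumes
# `get_node_type_to_io_type_map` from `torch.ao.ns.fx.mappings` to build the
# lookup tables, and uses a do-nothing class as a stand-in for a real numeric
# suite logger class.
def _demo_get_node_first_input_and_output_type() -> None:
    from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map

    class _DummyLogger(nn.Module):  # hypothetical stand-in for a logger class
        pass

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    gm = torch.fx.symbolic_trace(M())
    io_map = get_node_type_to_io_type_map()
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # nn.Linear is a known fp32 module, so this prints
            # (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
            print(get_node_first_input_and_output_type(n, gm, _DummyLogger, io_map))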


def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
    """
    prev_node = get_normalized_nth_input(node, gm, 0)

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale_obj = getattr_from_fqn(gm, scale_node.target)
        zp_obj = getattr_from_fqn(gm, zp_node.target)
        return (scale_obj, zp_obj)

    if prev_node.op == "call_function":

        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input
        return None

    elif prev_node.op == "call_module":

        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None
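

# Illustrative sketch (not part of the original module): reads the qparams of
# the input to a `dequantize` node. The scale and zero point are stored as
# buffers so that tracing produces `get_attr` nodes for them, mirroring how
# converted models store qparams; `get_node_type_to_io_type_map` from
# `torch.ao.ns.fx.mappings` supplies the lookup tables.
def _demo_get_node_input_qparams() -> None:
    from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("s", torch.tensor(0.1))
            self.register_buffer("zp", torch.tensor(0))

        def forward(self, x):
            xq = torch.quantize_per_tensor(x, self.s, self.zp, torch.quint8)
            return xq.dequantize()

    gm = torch.fx.symbolic_trace(M())
    deq = next(n for n in gm.graph.nodes if n.target == "dequantize")
    # prints (tensor(0.1000), tensor(0)), read off the quantize node's args
    print(get_node_input_qparams(deq, gm, get_node_type_to_io_type_map()))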


def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if _is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node
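

# Illustrative sketch (not part of the original module), assuming the FX graph
# mode quantization entry points `prepare_fx` and `get_default_qconfig_mapping`:
# after prepare, the node feeding the graph output is an observer, and this
# helper walks past it back to the `fc` node.
def _demo_return_first_non_observer_node() -> None:
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    mp = prepare_fx(M().eval(), get_default_qconfig_mapping(), (torch.randn(1, 2),))
    output_node = next(n for n in mp.graph.nodes if n.op == "output")
    obs_node = output_node.args[0]
    assert isinstance(obs_node, Node)
    print(return_first_non_observer_node(obs_node, mp))  # the `fc` node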


def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first.  Returns the number of
    non-param args expected for a node.  For example, for

        F.linear(x, weight, bias)

    returns 1, because `x` is a non-param arg and `weight` and `bias` are
    params.  For

        lstm_mod(x, hid)

    returns 2, because both `x` and `hid` are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if isinstance(node_obj, nn.LSTM):
            return 2

    # default is 1
    return 1
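

# Illustrative sketch (not part of the original module): an LSTM call node has
# two non-param args (the input and the hidden state); everything else
# defaults to one.
def _demo_get_number_of_non_param_args() -> None:
    gm = torch.fx.symbolic_trace(nn.Sequential(nn.LSTM(2, 2)))
    lstm_node = next(n for n in gm.graph.nodes if n.op == "call_module")
    print(get_number_of_non_param_args(lstm_node, gm))  # 2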


def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)), returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        result = []
        for i in range(2):
            if isinstance(node.args[i], Node):
                result.append(i)
        return result
    return [0]
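

# Illustrative sketch (not part of the original module): for a traced `x + 1`,
# only arg 0 is a Node, so input logging targets index 0 only.
def _demo_get_arg_indices_of_inputs_to_log() -> None:
    class M(nn.Module):
        def forward(self, x):
            return x + 1

    gm = torch.fx.symbolic_trace(M())
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(get_arg_indices_of_inputs_to_log(add_node))  # [0]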


def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type
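

# Illustrative sketch (not part of the original module): prints the target
# type string for each node of a tiny traced model; placeholder and output
# nodes map to ''.
def _demo_get_target_type_str() -> None:
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return torch.relu(self.fc(x))

    gm = torch.fx.symbolic_trace(M())
    for n in gm.graph.nodes:
        print(n.op, repr(get_target_type_str(n, gm)))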


def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

        {'base_op_1_0': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    into

        {'linear1': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models.  This is why we extract
    the results first and rekey afterwards.
    """
    new_results = {}
    for old_layer_name, result_type_to_results in results.items():
        new_layer_name = None
        for model_name_to_results in result_type_to_results.values():
            for cur_model_name, list_of_results in model_name_to_results.items():
                if cur_model_name == model_name:
                    assert len(list_of_results)
                    new_layer_name = list_of_results[0]["ref_node_name"]
        if new_layer_name is not None:
            new_results[new_layer_name] = result_type_to_results
        else:
            new_results[old_layer_name] = result_type_to_results
    return new_results
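

# Illustrative sketch (not part of the original module): a hand-built results
# dict in the layout shown in the docstring above, rekeyed on model_a's node
# names.
def _demo_rekey_logger_info() -> None:
    results = {
        "base_op_1_0": {
            "node_output": {
                "model_a": [{"ref_node_name": "linear1", "values": [], "fqn": None}],
            },
        },
    }
    rekeyed = rekey_logger_info_on_node_name_of_model(results, "model_a")
    print(list(rekeyed.keys()))  # ['linear1']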


def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model.  In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
    """
    # Check in the first result to find any model with fqn entries defined.
    model_name_with_fqns = None
    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            for model_name, model_results in model_name_to_results.items():
                if len(model_results) > 0:
                    if model_results[0]["fqn"] is not None:
                        model_name_with_fqns = model_name
                        break
            break
        break

    if model_name_with_fqns:
        for result_type_to_results in results.values():
            for model_name_to_results in result_type_to_results.values():
                ref_model_results = model_name_to_results[model_name_with_fqns]
                for model_name, model_results in model_name_to_results.items():
                    if model_name == model_name_with_fqns:
                        continue
                    for i in range(len(model_results)):
                        fqn = ref_model_results[i]["fqn"]
                        model_results[i]["fqn"] = fqn
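

# Illustrative sketch (not part of the original module): model_a has `fqn`
# entries filled in and model_b does not; after the call, model_b's entries
# are copied over from model_a.
def _demo_maybe_add_missing_fqns() -> None:
    results = {
        "layer0": {
            "node_output": {
                "model_a": [{"fqn": "fc", "values": []}],
                "model_b": [{"fqn": None, "values": []}],
            },
        },
    }
    maybe_add_missing_fqns(results)
    print(results["layer0"]["node_output"]["model_b"][0]["fqn"])  # fc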


def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            results = []
            for el0, el1 in zip(a0, a1):
                new_args = (el0, el1, *a_other)
                results.append(inner(*new_args, **kwargs))
            return results

        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

        # for the purposes of this util, only handle floats
        if a0.dtype != torch.float or a1.dtype != torch.float:
            return None

        new_args = (a0, a1, *a_other)
        return f(*new_args, **kwargs)

    return inner


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)
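

# Illustrative sketch (not part of the original module): compares a tensor
# against a slightly perturbed copy with the three metrics above; the
# decorator also lets the metrics accept matching tuples of tensors.
def _demo_comparison_metrics() -> None:
    x = torch.randn(16, 16)
    y = x + 0.01 * torch.randn(16, 16)
    print("sqnr:", compute_sqnr(x, y))
    print("normalized L2 error:", compute_normalized_l2_error(x, y))
    print("cosine similarity:", compute_cosine_similarity(x, y))
    print("sqnr on tuples:", compute_sqnr((x, x), (y, y)))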


def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns True if Numeric Suite shadowing supports the op called by `node`.
    """
    if node.op == "call_function":
        if node.target in (
            torch.add,
            torch.mul,
            operator.add,
            operator.mul,
            torch.cat,
            torch.stack,
        ):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True
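

# Illustrative sketch (not part of the original module): multi-tensor-input
# ops such as torch.add cannot be shadowed yet, single-tensor-input ops can.
def _demo_op_type_supports_shadowing() -> None:
    class M(nn.Module):
        def forward(self, x):
            return torch.relu(torch.add(x, x))

    gm = torch.fx.symbolic_trace(M())
    for n in gm.graph.nodes:
        if n.op == "call_function":
            print(n.target, op_type_supports_shadowing(n))  # add: False, relu: True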


def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True
        )
        if norm_args_and_kwargs is not None:
            norm_args, norm_kwargs = norm_args_and_kwargs
            assert len(norm_args) + len(norm_kwargs) > idx
            if idx < len(norm_args):
                return norm_args[idx]
            else:
                # note: in Python 3.7+ dicts are ordered
                return list(norm_kwargs.values())[idx]
        else:
            assert len(node.args) + len(node.kwargs) > idx
            if idx < len(node.args):
                return node.args[idx]  # type: ignore[return-value]
            else:
                kwargs_idx = idx - len(node.args)
                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        assert len(node.args) + len(node.kwargs) > idx
        if idx < len(node.args):
            return node.args[idx]  # type: ignore[return-value]
        else:
            kwargs_idx = idx - len(node.args)
            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
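

# Illustrative sketch (not part of the original module): for a traced
# `torch.add(x, 1.0)`, argument normalization needs typehints to disambiguate
# overloads, so this exercises the positional-args fallback path and returns
# the placeholder node `x`.
def _demo_get_normalized_nth_input() -> None:
    class M(nn.Module):
        def forward(self, x):
            return torch.add(x, 1.0)

    gm = torch.fx.symbolic_trace(M())
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(get_normalized_nth_input(add_node, gm, 0))  # the placeholder `x`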