Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import functools | |
import sys | |
import warnings | |
from typing import List, Optional, Sequence, Tuple, Union | |
import torch | |
import torch._C._onnx as _C_onnx | |
import torch.onnx | |
from torch import _C | |
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics | |
from torch.onnx import ( | |
_constants, | |
_type_utils, | |
errors, | |
symbolic_helper, | |
symbolic_opset9 as opset9, | |
) | |
from torch.onnx._globals import GLOBALS | |
from torch.onnx._internal import _beartype, jit_utils, registration | |
# EDITING THIS FILE? READ THIS FIRST! | |
# see Note [Edit Symbolic Files] in README.md | |
# This file exports ONNX ops for opset 10 | |
# Opset 10 is supported by ONNX release 1.5.0 | |
# release on 04/24/19 | |
# Public names of this module (what `from <module> import *` exposes).
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]

# `registration.onnx_symbolic` with the `opset` keyword pre-bound to 10.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
def _apply_params(*args, **kwargs): | |
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" | |
def _apply(fn): | |
return fn(*args, **kwargs) | |
return _apply | |
def div(g: jit_utils.GraphContext, self, other, *args):
    """Emit a division of ``self`` by ``other``.

    With no extra positional arguments this is ``opset9.true_divide``;
    otherwise the extras are forwarded to ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch a division on ``rounding_mode``.

    ``"floor"`` is handled by this module's ``_floor_divide``; every other
    value is delegated to ``opset9._div_rounding_mode``.
    """
    wants_floor = rounding_mode == "floor"
    if not wants_floor:
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Emit a floor division of ``self`` by ``other``.

    If either operand is floating point (per ``symbolic_helper._is_fp``) the
    result is ``Floor(true_divide(self, other))``. Otherwise a ``Div`` is
    emitted and corrected by subtracting one wherever the operand signs differ
    and the ``Mod`` remainder is non-zero.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        quotient = opset9.true_divide(g, self, other)
        return g.op("Floor", quotient)

    # Integer division does truncation rounding
    truncated = g.op("Div", self, other)

    # Division is negative if: self < 0 != other < 0
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_is_negative = g.op("Less", self, zero)
    other_is_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_is_negative, other_is_negative)

    # For negative numbers with self % other != 0, subtract 1 to round down instead of up
    remainder = g.op("Mod", self, other, fmod_i=0)
    remainder_is_zero = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", remainder_is_zero)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    rounded_down = g.op("Sub", truncated, one)
    return g.op("Where", needs_fixup, rounded_down, truncated)
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Forward to ``symbolic_helper._sort_helper`` with the given arguments."""
    sorted_result = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return sorted_result
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Forward to ``symbolic_helper._topk_helper`` with the given arguments."""
    helper_kwargs = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **helper_kwargs)
def _aten_max_pool_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
) -> _C.Value:
    """Emit an ONNX ``MaxPool`` node and return its first (pooled) output.

    Args:
        g: Graph context used to emit nodes.
        self: Input value to pool.
        kernel_shape: Passed as the ``kernel_shape`` attribute.
        strides: Passed as the ``strides`` attribute.
        pads: Passed as the ``pads`` attribute.
        dilations: Passed as the ``dilations`` attribute.
        ceil_mode: Passed as the ``ceil_mode`` attribute.
        unbatched_rank: Rank at which the input is treated as unbatched; such
            an input gets a leading axis added before pooling and removed
            from the result afterwards.

    Returns:
        The pooled value.
    """
    # The rank has to be read statically: comparing the Value produced by
    # g.op("Size", ...) with a Python int is never true, which left the
    # unbatched path unreachable. An unknown rank (None) keeps the batched path.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank

    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # The helper picks the axes encoding that matches the export opset.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, _ = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )

    if is_unbatched:
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
    return pool_result
# For MaxPool | |
def _adjust_attributes_of_max_pool( | |
expand_size: int, | |
kernel_size: Union[Sequence[int], int], | |
stride: Union[Sequence[int], int], | |
padding: Union[Sequence[int], int], | |
dilation: Union[Sequence[int], int], | |
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: | |
"""Adjust attributes of avg_pool to match ONNX specification.""" | |
if isinstance(dilation, int): | |
dilation = [dilation] * expand_size | |
if isinstance(kernel_size, int): | |
kernel_shape = [kernel_size] * expand_size | |
else: | |
kernel_shape = kernel_size # type: ignore[assignment] | |
if isinstance(padding, int): | |
pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] | |
elif len(padding) == 1: | |
pads = padding * expand_size * 2 # type: ignore[operator, assignment] | |
elif len(padding) == 2: | |
# 2D padding | |
pads = padding * 2 # type: ignore[operator, assignment] | |
elif len(padding) == 3: | |
# 3D padding | |
pads = padding * 2 # type: ignore[operator, assignment] | |
else: | |
# When padding is already done for all dimensions, | |
# we don't need to double it | |
# eg: (1, 1, 1, 1, 1, 1) | |
pads = padding # type: ignore[assignment] | |
if isinstance(stride, int): | |
strides = [stride] * expand_size | |
elif not stride: | |
strides = kernel_shape | |
else: | |
strides = stride # type: ignore[assignment] | |
return (kernel_shape, strides, pads, dilation) | |
def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> Tuple[_C.Value, _C.Value]:
    """Emit an ONNX ``MaxPool`` and return both the pooled values and indices.

    A second ``MaxPool`` with kernel and stride ``n_dims_one`` is emitted on
    the same input; its index output is sliced (``starts=n_dims_zero``,
    ``ends=n_dims_one``, ``axes=n_dims_axes``) and subtracted from the indices
    of the first pool.

    Args:
        g: Graph context used to emit nodes.
        self: Input value to pool.
        kernel_shape: ``kernel_shape`` attribute of the main pool.
        strides: ``strides`` attribute of the main pool.
        pads: ``pads`` attribute of the main pool.
        dilations: ``dilations`` attribute of both pools.
        ceil_mode: ``ceil_mode`` attribute of the main pool.
        unbatched_rank: Rank at which the input is treated as unbatched; such
            an input gets a leading axis added before pooling, and both
            outputs have it removed again.
        n_dims_one: Kernel/stride of the helper pool and slice ``ends``.
        n_dims_zero: Slice ``starts``.
        n_dims_axes: Slice ``axes``.

    Returns:
        ``(pool_result, indices)``.
    """
    # The rank has to be read statically: comparing the Value produced by
    # g.op("Size", ...) with a Python int is never true, which left the
    # unbatched path unreachable. An unknown rank (None) keeps the batched path.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank

    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # The helper picks the axes encoding that matches the export opset.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    # Helper pool whose indices provide the per-slice offset removed below.
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if is_unbatched:
        # Squeeze through the helper; passing the axes as a `value_t`
        # attribute is not a valid Squeeze node.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
        indices = symbolic_helper._squeeze_helper(g, indices, [0])

    return (pool_result, indices)
def _max_pool(name: str, expand_size: int, return_indices: bool): | |
def symbolic_fn( | |
g: jit_utils.GraphContext, | |
input: _C.Value, | |
kernel_size: Sequence[int], | |
stride: Sequence[int], | |
padding: Union[int, Sequence[int]], | |
dilation: Sequence[int], | |
ceil_mode: bool, | |
): | |
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( | |
expand_size, kernel_size, stride, padding, dilation | |
) | |
if return_indices: | |
return _aten_max_pool_with_indices_onnx( | |
g, | |
input, | |
kernel_shape, | |
strides, | |
pads, | |
dilations, | |
ceil_mode, | |
expand_size + 1, | |
([1] * expand_size), | |
([0] * expand_size), | |
([2 + i for i in range(expand_size)]), | |
) | |
else: | |
return _aten_max_pool_onnx( | |
g, | |
input, | |
kernel_shape, | |
strides, | |
pads, | |
dilations, | |
ceil_mode, | |
expand_size + 1, | |
) | |
return symbolic_fn | |
# For AvgPool | |
def _adjust_attributes_of_avg_pool( | |
expand_size: int, | |
kernel_size: Union[Sequence[int], int], | |
stride: Union[Sequence[int], int], | |
padding: Union[Sequence[int], int], | |
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]: | |
"""Adjust attributes of avg_pool to match ONNX specification.""" | |
if isinstance(kernel_size, int): | |
kernel_shape = [kernel_size] * expand_size | |
else: | |
kernel_shape = kernel_size # type: ignore[assignment] | |
if isinstance(padding, int): | |
pads = [padding] * expand_size * 2 | |
elif len(padding) == 1: | |
pads = padding * expand_size * 2 # type: ignore[operator, assignment] | |
elif len(padding) == 2: | |
pads = padding * expand_size # type: ignore[operator, assignment] | |
else: | |
pads = padding * 2 # type: ignore[operator, assignment] | |
if isinstance(stride, int): | |
strides = [stride] * expand_size | |
elif not stride: | |
strides = kernel_shape | |
else: | |
strides = stride # type: ignore[assignment] | |
return (kernel_shape, strides, pads) | |
def _avg_pool(name, expand_size): | |
def symbolic_fn( | |
g, | |
input: _C.Value, | |
kernel_size: Sequence[int], | |
stride: Sequence[int], | |
padding: Union[int, Sequence[int]], | |
ceil_mode: int, | |
count_include_pad: int, | |
divisor_override=None, | |
): | |
kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( | |
expand_size, kernel_size, stride, padding | |
) | |
result = g.op( | |
"AveragePool", | |
input, | |
ceil_mode_i=ceil_mode, | |
count_include_pad_i=count_include_pad, | |
kernel_shape_i=kernel_shape, | |
pads_i=pads, | |
strides_i=strides, | |
) | |
return result | |
return symbolic_fn | |
def _interpolate(name, dim, interpolate_mode): | |
def symbolic_fn(g, input, output_size, *args): | |
scales, align_corners = symbolic_helper._get_interpolate_attributes( | |
g, interpolate_mode, args | |
) | |
symbolic_helper._interpolate_warning(interpolate_mode) | |
align_corners = symbolic_helper._maybe_get_scalar(align_corners) | |
if align_corners: | |
return symbolic_helper._unimplemented(name, "align_corners == True", input) | |
if scales is None: | |
scales = symbolic_helper._interpolate_size_to_scales( | |
g, input, output_size, dim | |
) | |
return g.op("Resize", input, scales, mode_s=interpolate_mode) | |
return symbolic_fn | |
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Emit a ``Resize`` node for an interpolate call.

    Scales and the resize mode come from
    ``symbolic_helper._interpolate_get_scales_and_mode``.
    ``recompute_scale_factor`` and ``antialias`` are accepted but not used.
    """
    resolved = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    scales, resize_mode = resolved
    return g.op("Resize", input, scales, mode_s=resize_mode)
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: Union[List, torch.Tensor, torch._C.Value],
    starts: Union[List, torch.Tensor, torch._C.Value],
    ends: Union[List, torch.Tensor, torch._C.Value],
    steps: Optional[Union[List, torch.Tensor, torch._C.Value]] = None,
):
    """Emit an ONNX ``Slice`` node with tensor-valued starts/ends/axes[/steps].

    ``input`` is returned unchanged when the slice is recognizably a no-op:
    constant start 0, constant end ``INT64_MAX`` and a step that is absent or
    constant 1. List/tensor arguments become ``Constant`` nodes; graph values
    of rank 0 are unsqueezed to rank 1.

    Raises:
        errors.SymbolicValueError: If a graph-value argument has a rank other
            than 0 or 1.
    """

    def _is_missing(value):
        # True for Python None and for a prim::Constant of NoneType.
        if value is None:
            return True
        if not isinstance(value, torch._C.Value):
            return False
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    def _as_slice_input(arg, default_value=None):
        # Convert input param into a 1D torch.Value.
        if default_value is not None and _is_missing(arg):
            arg = [default_value]

        if isinstance(arg, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(arg))

        rank = symbolic_helper._get_tensor_rank(arg)
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, arg, [0])
        if rank == 1:
            return arg
        raise errors.SymbolicValueError(f"Rank must be 0 or 1, not {rank}", arg)

    def _const_of(arg):
        # Single constant carried by `arg`, when one can be determined.
        if not isinstance(arg, (list, torch.Tensor)):
            return symbolic_helper._maybe_get_const(arg, "i")
        if len(arg) == 1:
            return arg[0]
        return None

    # Check if slice is a no-op
    starts_at_zero = _const_of(starts) == 0
    if (
        starts_at_zero
        and _const_of(ends) == _constants.INT64_MAX
        and (steps is None or _const_of(steps) == 1)
    ):
        return input

    axes = _as_slice_input(axes)
    starts = _as_slice_input(starts, default_value=0)
    ends = _as_slice_input(ends, default_value=_constants.INT64_MAX)

    slice_inputs = [input, starts, ends, axes]
    if steps is not None:
        slice_inputs.append(_as_slice_input(steps, default_value=1))
    return g.op("Slice", *slice_inputs)
def slice(g: jit_utils.GraphContext, self, *args):
    """Emit a slice for either ``aten::slice`` signature.

    Four extra arguments are read as ``(dim, start, end, step)``; three as
    ``(start, end, step)`` with the axis fixed to ``[0]``.

    Raises:
        errors.SymbolicValueError: For any other number of extra arguments.
    """
    arg_count = len(args)
    if arg_count == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dims, start, end, step = args
    elif arg_count == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        dims = [0]
        start, end, step = args
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    slice_kwargs = {"axes": dims, "starts": start, "ends": end, "steps": step}
    return symbolic_helper._slice_helper(g, self, **slice_kwargs)
def flip(g: jit_utils.GraphContext, input, dims):
    """Reverse ``input`` along ``dims`` by slicing with a step of -1.

    Every listed axis is sliced from ``-1`` to ``-INT64_MAX`` with step ``-1``.
    """
    axis_count = len(dims)
    from_last = [-1] * axis_count
    past_first = [-_constants.INT64_MAX] * axis_count
    backwards = [-1] * axis_count
    return symbolic_helper._slice_helper(
        g, input, axes=dims, starts=from_last, ends=past_first, steps=backwards
    )
def fmod(g: jit_utils.GraphContext, input, other):
    """Emit an ONNX ``Mod`` node with the ``fmod`` attribute set to 1."""
    remainder = g.op("Mod", input, other, fmod_i=1)
    return remainder
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Emit an unrolled embedding_bag for a statically known number of bags.

    For each bag, the rows of ``indices`` between two consecutive offsets are
    sliced out, gathered from ``embedding_matrix``, optionally multiplied by
    the matching ``per_sample_weights`` and reduced: ``mode == 0`` uses the
    reduce-sum helper, ``mode == 1`` uses ``ReduceMean`` and any other value
    uses ``ReduceMax``. The per-bag results are concatenated along axis 0.

    Returns:
        ``(output, None, None, None)``, or the result of
        ``symbolic_helper._onnx_unsupported`` for the unsupported cases
        (``scale_grad_by_freq`` while exporting in training mode, or an
        unknown first dimension of ``offsets``).

    Raises:
        RuntimeError: If ``padding_idx`` is not ``None`` and is ``>= 0``.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    # The message used to end with a stray apostrophe; it is removed here.
    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )

    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # The final offset already closes the last bag.
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a sentinel end so the last bag runs to the end of `indices`.
        offset_len = offsets_dim_0
        sentinel_end = g.op("Constant", value_t=torch.tensor([sys.maxsize]))
        offsets_extended = g.op("Concat", offsets, sentinel_end, axis_i=0)

    has_weights = not symbolic_helper._is_none(per_sample_weights)
    bag_outputs = []
    for i in range(offset_len):
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)),
            [0],
        )
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)
        embeddings = g.op("Gather", embedding_matrix, indices_row)

        if has_weights:
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
        bag_outputs.append(embeddings)

    output = g.op("Concat", *bag_outputs, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Emit a ``QuantizeLinear`` followed by a ``DequantizeLinear``.

    Only the ranges (0, 255) and (-128, 127) are accepted. ``zero_point`` is
    cast to UINT8 when ``quant_min`` is 0 and to INT8 otherwise. A scale for
    which ``symbolic_helper._maybe_get_scalar`` yields ``None`` is reported as
    unsupported.

    Raises:
        errors.SymbolicValueError: If the quantization range is not one of the
            two accepted ranges.
    """
    quant_range = (quant_min, quant_max)

    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if quant_range == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if quant_range not in ((0, 255), (-128, 127)):
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )

    scale = symbolic_helper._maybe_get_scalar(scale)
    if scale is None:
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type

    zero_point_type = (
        _C_onnx.TensorProtoDataType.UINT8
        if quant_min == 0
        else _C_onnx.TensorProtoDataType.INT8
    )
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)

    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
def isinf(g: jit_utils.GraphContext, input):
    """Emit ``IsInf`` on ``input`` after casting it to DOUBLE."""
    as_double = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
    return g.op("IsInf", as_double)
def isfinite(g: jit_utils.GraphContext, input):
    """Emit ``not (isinf(input) or isnan(input))``."""
    is_infinite = isinf(g, input)
    is_nan = opset9.isnan(g, input)
    not_finite = opset9.__or_(g, is_infinite, is_nan)
    return opset9.__not_(g, not_finite)
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Quantize ``input`` via ``symbolic_helper.quantize_helper``.

    ``zero_point`` is cast to the ONNX type matching the constant ``dtype``
    and ``scale`` is cast to FLOAT before the helper is called.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
    zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
def dequantize(g: jit_utils.GraphContext, input):
    """Return the first element of ``symbolic_helper.dequantize_helper(g, input)``."""
    helper_outputs = symbolic_helper.dequantize_helper(g, input)
    return helper_outputs[0]
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Replace NaN, +inf and -inf in ``input`` using chained ``Where`` nodes.

    ``nan`` defaults to 0.0; ``posinf`` and ``neginf`` default to the
    ``torch.finfo`` max and min of the input dtype. Non floating point inputs
    are returned unchanged.
    """
    # Cannot create a int type tensor with inf/nan values, so we simply
    # return the original tensor
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    # Step 1: NaN -> nan
    if nan is None:
        nan = 0.0
    nan_cond = opset9.isnan(g, input)
    nan_value = g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype))
    without_nan = g.op("Where", nan_cond, nan_value, input)

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by input's dtype.
    finfo = torch.finfo(input_dtype)

    # Step 2: +inf -> posinf
    if posinf is None:
        posinf = finfo.max
    is_inf = isinf(g, without_nan)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_positive = opset9.gt(g, without_nan, zero)
    posinf_cond = opset9.logical_and(g, is_inf, is_positive)
    posinf_value = g.op(
        "Constant", value_t=torch.tensor([posinf], dtype=input_dtype)
    )
    without_posinf = g.op("Where", posinf_cond, posinf_value, without_nan)

    # Step 3: -inf -> neginf
    if neginf is None:
        neginf = finfo.min
    is_inf = isinf(g, without_posinf)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_negative = opset9.lt(g, without_posinf, zero)
    neginf_cond = opset9.logical_and(g, is_inf, is_negative)
    neginf_value = g.op(
        "Constant", value_t=torch.tensor([neginf], dtype=input_dtype)
    )
    return g.op("Where", neginf_cond, neginf_value, without_posinf)
# Quantized symbolics --------------------------------------------------------- | |
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export | |
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were | |
# introduced in opset version 10. | |
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Dequantize the operands, apply ``opset9.linear`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    result = opset9.linear(g, dq_input, dq_weight, dq_bias)
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Dequantize the operands, apply ``opset9.linear`` then ``opset9.relu``, and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    result = opset9.relu(g, opset9.linear(g, dq_input, dq_weight, dq_bias))
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Dequantize ``x`` and ``y``, add them with ``opset9.add`` and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    dq_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    total = opset9.add(g, dq_x, dq_y)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Dequantize ``x`` and ``y``, apply ``opset9.add`` then ``opset9.relu``, and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    dq_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    activated = opset9.relu(g, opset9.add(g, dq_x, dq_y))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Dequantize ``x`` and ``y``, multiply them with ``opset9.mul`` and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    dq_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    product = opset9.mul(g, dq_x, dq_y)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Dequantize ``x``, apply ``opset9.hardswish`` and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    activated = opset9.hardswish(g, dq_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Dequantize ``x``, apply ``opset9.sigmoid`` and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    activated = opset9.sigmoid(g, dq_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Dequantize ``x``, apply ``opset9.leaky_relu`` and requantize."""
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    activated = opset9.leaky_relu(g, dq_x, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Dequantize ``x``, apply ``opset9.layer_norm`` and requantize.

    The final positional argument of ``opset9.layer_norm`` is passed as
    ``False``.
    """
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    normalized = opset9.layer_norm(
        g, dq_x, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Dequantize ``x``, apply ``opset9.group_norm`` and requantize.

    The final positional argument of ``opset9.group_norm`` is passed as
    ``False``.
    """
    dq_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    normalized = opset9.group_norm(g, dq_x, num_groups, weight, bias, eps, False)
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Dequantize ``q_input``, apply ``opset9.instance_norm`` and requantize.

    ``opset9.instance_norm`` is called with the fixed positional arguments
    ``None, None, False, 0.0`` between ``bias`` and ``eps``, and a trailing
    ``False``.
    """
    dq_input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)

    normalized = opset9.instance_norm(
        g, dq_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv1d`` then ``opset9.relu``, and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv1d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv2d`` then ``opset9.relu``, and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv3d`` then ``opset9.relu``, and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv1d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv1d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv2d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv3d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv_transpose1d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use. Note the argument order expected by the opset 9
    symbolic: ``groups`` comes before ``dilation``.
    """
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    # Use the 1d symbolic so this function matches its 2d/3d siblings;
    # it previously called conv_transpose2d.
    output = opset9.conv_transpose1d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv_transpose2d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use. ``groups`` is passed before ``dilation`` to the opset 9
    symbolic.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    conv_args = (stride, padding, output_padding, groups, dilation)
    convolved = opset9.conv_transpose2d(g, dq_input, dq_weight, dq_bias, *conv_args)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Dequantize the operands, apply ``opset9.conv_transpose3d`` and requantize.

    The bias is requantized with the input and weight scales and dequantized
    again before use. ``groups`` is passed before ``dilation`` to the opset 9
    symbolic.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    conv_args = (stride, padding, output_padding, groups, dilation)
    convolved = opset9.conv_transpose3d(g, dq_input, dq_weight, dq_bias, *conv_args)
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Dequantize every tensor in ``q_inputs``, ``Concat`` along ``dim`` and requantize.

    ``q_inputs`` is unpacked with ``symbolic_helper._unpack_list``.
    """
    dequantized = []
    for q_tensor in symbolic_helper._unpack_list(q_inputs):
        dequantized.append(symbolic_helper.dequantize_helper(g, q_tensor)[0])

    joined = g.op("Concat", *dequantized, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)