import array
import enum
import functools
import logging
import operator
import struct
import sys
from typing import List, NamedTuple, Optional, Tuple

import torch

# TODO: Add type annotations
# TODO: Check tensor types for ops

LOG = logging.getLogger("nnapi_serialize")
class NNAPI_OperandCode:
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12
class NNAPI_OperationCode:
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94
class NNAPI_FuseCode:
    FUSED_NONE = 0
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3


class OperandValueSourceType:
    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3
# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    QUINT8 = 13
def approx_equal(lhs, rhs, tolerance=1e-6):
    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
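# For example, with the default tolerance the difference must be within one
# millionth of the smaller operand:
#   >>> approx_equal(1000.0, 1000.0005)
#   True
#   >>> approx_equal(1000.0, 1000.002)
#   False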
def tensor_size(op_type, dims):
    ITEM_SIZES = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
    }
    size = ITEM_SIZES[op_type]
    for d in dims:
        size *= d
    return size
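# For example, a float32 tensor of shape (1, 3, 4) occupies 4 * 1 * 3 * 4 bytes:
#   >>> tensor_size(NNAPI_OperandCode.TENSOR_FLOAT32, (1, 3, 4))
#   48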
def change_element(tup, index, value):
    ls = list(tup)
    ls[index] = value
    return tuple(ls)
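# For example:
#   >>> change_element((1, 2, 3), 1, 9)
#   (1, 9, 3)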
class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a convolution."""

    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    dilation_h: int
    dilation_w: int
    group: int
class DimOrder(enum.Enum):
    PRESUMED_CONTIGUOUS = 0
    CHANNELS_LAST = 1
    SCALAR_OR_VECTOR = 2
    UNKNOWN_CONSTANT = 999
class Operand(NamedTuple):
    """Representation of an NNAPI operand."""

    # NNAPI operand type.  One of NNAPI_OperandCode.
    # TODO: Make this an enum.
    op_type: int

    # This is always the PyTorch shape, which is NCHW for feature maps.
    # The actual NNAPI operand might have a transposed shape.
    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
    shape: Tuple[int, ...]

    # Specifies how the shape of the operand that we define in NNAPI
    # relates to the shape we track above.
    # - PRESUMED_CONTIGUOUS: the physical NNAPI operand will exactly match
    #   the shape of the PyTorch tensor.
    # - CHANNELS_LAST: the PyTorch tensor is expected to be NCHW, and
    #   the NNAPI operand will be represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization params
    scale: float
    zero_point: int

    def use_nchw(self):
        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
            return True
        if self.dim_order is DimOrder.CHANNELS_LAST:
            return False
        raise Exception("Unknown dim order")
def broadcast_shapes(shape1, shape2):
    assert len(shape1) > 0
    assert len(shape2) > 0
    s1 = list(shape1)
    s2 = list(shape2)
    # TODO: Support non-equal-rank broadcast where semantics match.
    # This can be tricky for NHWC tensors because dimension orders
    # don't match between PT and NNAPI, even though semantics match.
    if len(s1) > len(s2):
        # s2 = [1] * (len(s1) - len(s2)) + s2
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    if len(s2) > len(s1):
        # s1 = [1] * (len(s2) - len(s1)) + s1
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    ret = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            ret.append(d2)
        elif d2 == 1:
            ret.append(d1)
        elif d1 == d2:
            ret.append(d1)
        else:
            raise Exception(f"Cannot broadcast shapes: {shape1} and {shape2}")
    return tuple(ret)
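# For example, shapes of equal rank broadcast elementwise, with 1 expanding to
# match the other dimension:
#   >>> broadcast_shapes((1, 4), (3, 1))
#   (3, 4)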
def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    batch, in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    out_shape = (batch, out_ch, out_h, out_w)
    return out_shape
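# For example, a 3x3 kernel with stride 2 and padding 1 on a 224x224 input
# downsamples to 112x112:
#   >>> args = ConvPoolArgs2d(3, 3, 2, 2, 1, 1, 1, 1, 1, 1, 1)
#   >>> get_conv_pool_shape((1, 32, 224, 224), args, 64, False)
#   (1, 64, 112, 112)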
def fix_shape(shape, dim_order):
    # Return the actual shape that an operand should have in NNAPI,
    # given a PyTorch shape and dimension order.  This is where we
    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
        return shape
    if dim_order is DimOrder.CHANNELS_LAST:
        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) == 0 or len(shape) == 1
        return shape
    if dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX think this through
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")
def reverse_map_dim(dim_order, d):
    # Return the original PyTorch dimension position for a given dimension.
    # d should be the dimension that NNAPI will see.
    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
    # reverse_map_dim(CHANNELS_LAST, 3) == 1
    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
        return d
    assert dim_order is DimOrder.CHANNELS_LAST
    return [0, 2, 3, 1][d]
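# For example, NNAPI's NHWC dim 1 (height) corresponds to PyTorch's NCHW dim 2:
#   >>> reverse_map_dim(DimOrder.CHANNELS_LAST, 1)
#   2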
def flex_name(op_id, dim):
    # Return the local variable name for the computed flexible size
    # for a given op and dimension.
    return f"s_{op_id}_{dim}"
class _NnapiSerializer:
    def __init__(self, config, use_int16_for_qint16=False):
        self.operands = []
        self.values = []
        self.operations = []
        self.value_data = []
        self.operation_args = []
        self.inputs = []
        self.outputs = []
        self.flexible_shape_computation_lines = []

        self.modules = {}
        self.constants = {}
        self.tensor_sequences = {}
        self.jitval_operand_map = {}
        self.cached_immediates = {}
        self.used_weights = []
        self.weight_offset = 0
        self.use_int16_for_qint16 = use_int16_for_qint16

        if config is None:
            config = {}
def get_next_operand_id(self): | |
return len(self.operands) | |
# Add a tensor operand corresponding to a JIT Value. | |
# Returns the NNAPI operand ID. Can be looked up later with | |
# get_tensor_operand_by_jitval. | |
def add_tensor_operand(self, jitval, oper): | |
assert isinstance(oper, Operand) | |
if jitval in self.jitval_operand_map: | |
raise Exception(f"Duplicate tensor: {jitval!r}") | |
operand_id = self.get_next_operand_id() | |
self.operands.append(oper) | |
self.jitval_operand_map[jitval] = operand_id | |
return operand_id | |
# Add a tensor operand that does not correspond to a JIT Value. | |
# Useful for cases where multiple NNAPI operands are required | |
# to implement one JIT IR node. Returns the NNAPI operand ID. | |
def add_anonymous_tensor_operand(self, oper): | |
assert isinstance(oper, Operand) | |
operand_id = self.get_next_operand_id() | |
self.operands.append(oper) | |
return operand_id | |
def torch_tensor_to_operand(self, tensor, dim_order): | |
dtype = str(tensor.dtype).replace("torch.", "") | |
scale = 0.0 | |
zero_point = 0 | |
if dtype == "float32": | |
op_type = NNAPI_OperandCode.TENSOR_FLOAT32 | |
elif dtype == "int32": | |
op_type = NNAPI_OperandCode.TENSOR_INT32 | |
elif dtype == "quint8": | |
op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM | |
scale = tensor.q_scale() | |
zero_point = tensor.q_zero_point() | |
elif dtype == "qint32": | |
op_type = NNAPI_OperandCode.TENSOR_INT32 | |
scale = tensor.q_scale() | |
zero_point = tensor.q_zero_point() | |
assert zero_point == 0 | |
elif dtype == "int16": | |
if self.use_int16_for_qint16: | |
nnapi_dtype = getattr(tensor, "nnapi_dtype", None) | |
op_codes = ( | |
NNAPI_OperandCode.TENSOR_QUANT16_SYMM, | |
NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, | |
) | |
if nnapi_dtype in op_codes: | |
op_type = nnapi_dtype | |
scale = tensor.nnapi_scale | |
zero_point = tensor.nnapi_zero_point | |
                else:
                    raise Exception(
                        f"`nnapi_dtype` needs to be one of {op_codes} for `int16`"
                    )
else: | |
raise Exception( | |
"`int16` isn't supported. If you're trying to represent NNAPI" | |
" qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" | |
) | |
else: | |
raise Exception(f"Can't handle input with dtype '{tensor.dtype}'") | |
return Operand( | |
shape=tuple(tensor.shape), | |
op_type=op_type, | |
dim_order=dim_order, | |
scale=scale, | |
zero_point=zero_point, | |
) | |
def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): | |
dim_order = ( | |
DimOrder.CHANNELS_LAST | |
if getattr(tensor, "nnapi_nhwc", False) | |
else DimOrder.PRESUMED_CONTIGUOUS | |
) | |
toper = self.torch_tensor_to_operand(tensor, dim_order) | |
operand_id = self.add_tensor_operand(jitval, toper) | |
self.inputs.append(operand_id) | |
for dim, size in enumerate(tensor.shape): | |
if size == 0: | |
self.compute_operand_shape( | |
operand_id, dim, f"args[{arg_idx}].shape[{dim}]" | |
) | |
return operand_id | |
def add_tensor_operand_for_weight( | |
self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT | |
): | |
toper = self.torch_tensor_to_operand(tensor, dim_order) | |
operand_id = len(self.operands) | |
self.operands.append(toper) | |
tsize = tensor_size(toper.op_type, toper.shape) | |
psize = ((tsize - 1) | 0x3) + 1 | |
self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) | |
buf_num = len(self.used_weights) | |
offset = 0 | |
self.value_data.append(struct.pack("iii", buf_num, offset, tsize)) | |
# For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor | |
if dim_order == DimOrder.CHANNELS_LAST: | |
tensor = tensor.permute(0, 2, 3, 1) | |
self.used_weights.append(tensor) | |
return operand_id | |
def add_immediate_operand(self, code, value, dims): | |
assert isinstance(dims, tuple) | |
cache_key = (code, value) | |
if cache_key not in self.cached_immediates: | |
operand_id = len(self.operands) | |
self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) | |
self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) | |
self.value_data.append(value) | |
self.cached_immediates[cache_key] = operand_id | |
return self.cached_immediates[cache_key] | |
def add_immediate_int_scalar(self, value): | |
return self.add_immediate_operand( | |
NNAPI_OperandCode.INT32, struct.pack("i", value), () | |
) | |
def add_immediate_float_scalar(self, value): | |
return self.add_immediate_operand( | |
NNAPI_OperandCode.FLOAT32, struct.pack("f", value), () | |
) | |
def add_immediate_bool_scalar(self, value): | |
return self.add_immediate_operand( | |
NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", () | |
) | |
def add_immediate_int_vector(self, value): | |
return self.add_immediate_operand( | |
NNAPI_OperandCode.TENSOR_INT32, | |
array.array("i", value).tobytes(), | |
(len(value),), | |
) | |
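    # For example, add_immediate_int_vector([0, 2, 3, 1]) registers a
    # TENSOR_INT32 operand of shape (4,) whose value is the four ints packed
    # in native byte order; identical immediates are deduplicated through
    # self.cached_immediates, so repeated calls return the same operand ID.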
def has_operand_for_jitval(self, jitval): | |
return jitval in self.jitval_operand_map | |
def get_tensor_operand_by_jitval(self, jitval): | |
operand_id = self.jitval_operand_map[jitval] | |
return (operand_id, self.operands[operand_id]) | |
def get_tensor_operand_by_jitval_fixed_size(self, jitval): | |
op_id, oper = self.get_tensor_operand_by_jitval(jitval) | |
for s in oper.shape: | |
if s == 0: | |
# TODO: Improve this error message, possibly after converting | |
# many callsites to support flexible size. | |
raise Exception("Flexible size is not supported for this operand.") | |
if s < 0: | |
# runtime flex | |
LOG.warning("Operand %s has runtime flex shape", oper) | |
return op_id, oper | |
def get_tensor_operand_or_constant( | |
self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS | |
): | |
operand_id = self.jitval_operand_map.get(jitval) | |
if operand_id is None: | |
_, value = self.get_constant_value(jitval, "TensorType") | |
operand_id = self.add_tensor_operand_for_weight(value, dim_order) | |
return (operand_id, self.operands[operand_id]) | |
def get_tensor_operand_for_weight(self, jitval): | |
_, value = self.get_constant_value(jitval, "TensorType") | |
operand_id = self.add_tensor_operand_for_weight(value) | |
return (operand_id, self.operands[operand_id]) | |
def add_operation(self, opcode, inputs, outputs): | |
self.operations.append((opcode, len(inputs), len(outputs))) | |
self.operation_args.extend(inputs + outputs) | |
def add_tensor_sequence(self, jitval, values): | |
assert jitval not in self.tensor_sequences | |
self.tensor_sequences[jitval] = values | |
def add_constant_value(self, jitval, ctype, value): | |
assert jitval not in self.constants | |
self.constants[jitval] = (ctype, value) | |
def get_constant_value(self, jitval, typekind=None): | |
record = self.constants.get(jitval) | |
if record is None: | |
raise Exception(f"Could not find constant value for '{jitval!r}'.") | |
ctype, _ = record | |
if typekind is not None and ctype.kind() != typekind: | |
raise Exception( | |
f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'" | |
) | |
return record | |
def operand_to_template_torchscript(self, op_id, oper, shape=None): | |
"""Return a TorchScript expression to build a template for a given operand.""" | |
if shape is None: | |
shape = oper.shape | |
else: | |
assert len(shape) == len(oper.shape) | |
shape_parts = ["("] | |
for d, s in enumerate(shape): | |
if s > 0: | |
# Fixed shape dimension: just add the value. | |
shape_parts.append(str(s)) | |
elif s == 0: | |
# Load time flexible shape dimension: it should have been computed in a variable. | |
shape_parts.append(flex_name(op_id, d)) | |
elif s == -1: | |
# Runtime flexible shape | |
shape_parts.append("0") | |
else: | |
raise Exception("Unknown dim value, dimensions should be >= -1") | |
shape_parts.append(",") | |
shape_parts.append(")") | |
shape_code = "".join(shape_parts) | |
if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: | |
return f"torch.zeros({shape_code}, dtype=torch.float32)" | |
elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32: | |
return f"torch.zeros({shape_code}, dtype=torch.int32)" | |
elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: | |
return ( | |
f"torch.quantize_per_tensor(" | |
f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)" | |
f".expand({shape_code}).contiguous()" | |
) | |
elif oper.op_type in ( | |
NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, | |
NNAPI_OperandCode.TENSOR_QUANT16_SYMM, | |
): | |
if self.use_int16_for_qint16: | |
return f"torch.zeros({shape_code}, dtype=torch.int16)" | |
else: | |
raise Exception( | |
"`int16` isn't supported. If you're trying to represent NNAPI" | |
" qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" | |
) | |
raise Exception(f"Unsupported output operand type: {oper.op_type}") | |
def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim): | |
self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) | |
def compute_operand_shape(self, op_id, dim, expr): | |
self.flexible_shape_computation_lines.append( | |
f"{flex_name(op_id, dim)} = {expr}" | |
) | |
def transpose_to_nhwc(self, in_id, oper): | |
if oper.shape[2:] != (1, 1): | |
raise Exception("Automatic transpose only supported for H,W == 1,1") | |
out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) | |
inputs = [None] * 2 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1]) | |
outputs = [None] * 1 | |
outputs[0] = self.add_anonymous_tensor_operand(out_oper) | |
self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs) | |
return outputs[0], out_oper | |
# Transpose inputs as necessary to allow broadcasting. | |
def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper): | |
if in0_oper.dim_order == in1_oper.dim_order: | |
return in0_id, in0_oper, in1_id, in1_oper | |
# Assume NHWC is preferred if there is a mismatch. | |
orders = (in0_oper.dim_order, in1_oper.dim_order) | |
if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST): | |
return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) | |
if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS): | |
return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) | |
raise Exception( | |
f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}" | |
) | |
def get_size_arg(self, jitval): | |
ctype, value = self.get_constant_value(jitval) | |
if ctype.kind() == "ListType": | |
assert ctype.getElementType().kind() == "IntType" | |
return value | |
raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'") | |
def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): | |
pc = [i.item() for i in packed_config] | |
assert pc[0] == 2 | |
strides = [pc[1], pc[2]] | |
paddings = [pc[3], pc[4]] | |
dilations = [pc[5], pc[6]] | |
output_padding = [pc[7], pc[8]] | |
group_num = pc[9] | |
assert len(pc) == 11 | |
assert output_padding == [0, 0] | |
return self.get_conv_pool_args_2d_common( | |
kernel_size, strides, paddings, dilations, group_num | |
) | |
def get_conv_pool_args_2d_from_jit( | |
self, kernel_size, stride, padding, dilation=None, group=None | |
): | |
strides = self.get_size_arg(stride) | |
paddings = self.get_size_arg(padding) | |
if dilation is None: | |
dilations = [1, 1] | |
else: | |
dilations = self.get_size_arg(dilation) | |
if group is not None: | |
_, group_num = self.get_constant_value(group, "IntType") | |
else: | |
group_num = None | |
return self.get_conv_pool_args_2d_common( | |
kernel_size, strides, paddings, dilations, group_num | |
) | |
def get_conv_pool_args_2d_common( | |
self, kernel_size, strides, paddings, dilations, group_num | |
): | |
kernels = list(kernel_size) | |
assert len(kernels) == 2 | |
assert len(strides) == 2 | |
assert len(paddings) == 2 | |
assert len(dilations) == 2 | |
# NNAPI uses 4 values for padding. | |
ph, pw = paddings | |
real_paddings = [ph, ph, pw, pw] | |
return ConvPoolArgs2d( | |
*(kernels + strides + real_paddings + dilations + [group_num]) | |
) | |
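    # For example, kernel_size [3, 3], strides [2, 2], paddings [1, 2],
    # dilations [1, 1], and group 1 yield
    # ConvPoolArgs2d(kernel_h=3, kernel_w=3, stride_h=2, stride_w=2,
    #                pad_t=1, pad_b=1, pad_l=2, pad_r=2,
    #                dilation_h=1, dilation_w=1, group=1),
    # i.e. the PyTorch (pad_h, pad_w) pair expands to NNAPI's four paddings.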
def serialize_model(self, model, inputs, return_shapes=None): | |
self.add_immediate_bool_scalar(False) | |
self.add_immediate_bool_scalar(True) | |
inp_dim_orders = [] | |
out_dim_orders = [] | |
self_jitval = next(model.graph.inputs()) | |
self.add_constant_value(self_jitval, self_jitval.type(), model) | |
for arg_idx, (input_value, input_tensor) in enumerate( | |
zip(list(model.graph.inputs())[1:], inputs) | |
): | |
op_id = self.add_tensor_operand_for_input( | |
arg_idx, input_value, input_tensor | |
) | |
inp_dim_orders.append(self.operands[op_id].dim_order.value) | |
for idx, node in enumerate(model.graph.nodes()): | |
LOG.debug("Processing node #%d: %r", idx, node) | |
self.add_node(node) | |
retn = model.graph.return_node() | |
assert retn.inputsSize() == 1 | |
assert retn.outputsSize() == 0 | |
retn_input = retn.inputsAt(0) | |
template_return_lines = ["return ["] | |
if retn_input.type().kind() == "TensorType": | |
return_values = [retn_input] | |
retval_count = -1 | |
elif retn_input.type().kind() == "TupleType": | |
return_values = self.tensor_sequences[retn_input] | |
retval_count = len(return_values) | |
else: | |
raise Exception(f"Unsupported return type: {retn_input.type()}") | |
if return_shapes is not None: | |
assert len(return_shapes) == len(return_values) | |
for i, v in enumerate(return_values): | |
op_id = self.jitval_operand_map[v] | |
self.outputs.append(op_id) | |
out_dim_orders.append(self.operands[op_id].dim_order.value) | |
shape = return_shapes[i] if return_shapes else None | |
template_return_lines.append( | |
self.operand_to_template_torchscript(op_id, self.operands[op_id], shape) | |
+ "," | |
) | |
template_return_lines.append("]") | |
model = [] | |
version = 1 | |
header = struct.pack( | |
"iiiiii", | |
version, | |
len(self.operands), | |
len(self.values), | |
len(self.operations), | |
len(self.inputs), | |
len(self.outputs), | |
) | |
model.append(header) | |
serialized_values, serialized_value_data = self.serialize_values() | |
model.extend( | |
struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands | |
) | |
model.extend(serialized_values) | |
model.extend(struct.pack("iii", *x) for x in self.operations) | |
# Compact the model so we can get its length so far. | |
model = [b"".join(model)] | |
model_offset = len(model[0]) | |
# Model offset is the index into the model (in 32-bit words, not bytes) | |
# of the next dimension we're about to serialize. If it's 0, | |
# generate code to mutate it before passing to NNAPI. | |
assert model_offset % 4 == 0 | |
model_offset = int(model_offset / 4) | |
for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands): | |
shape = fix_shape(dims, dim_order) | |
for d, s in enumerate(shape): | |
if s == 0: | |
pt_d = reverse_map_dim(dim_order, d) | |
self.flexible_shape_computation_lines.append( | |
f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}" | |
) | |
model_offset += 1 | |
# convert runtime flex shape from -1 to 0 | |
shape = tuple(d if d != -1 else 0 for d in shape) | |
model.append(self.serialize_ints(shape)) | |
model.extend(serialized_value_data) | |
model.append(self.serialize_ints(self.operation_args)) | |
model.append(self.serialize_ints(self.inputs)) | |
model.append(self.serialize_ints(self.outputs)) | |
self.flexible_shape_computation_lines.extend(template_return_lines) | |
return ( | |
array.array("i", b"".join(model)), | |
self.used_weights, | |
inp_dim_orders, | |
out_dim_orders, | |
self.flexible_shape_computation_lines, | |
retval_count, | |
) | |
def serialize_values(self): | |
serialized_values = [] | |
serialized_value_data = [] | |
assert len(self.values) == len(self.value_data) | |
for (op_index, source_type), data in zip(self.values, self.value_data): | |
source_length = len(data) | |
# Pad with 0 bytes out to a multiple of 4 for alignment. | |
physical_length = ((source_length - 1) | 0x3) + 1 | |
padded_data = data + (b"\0" * (physical_length - source_length)) | |
serialized_values.append( | |
struct.pack("iii", op_index, source_type, source_length) | |
) | |
serialized_value_data.append(padded_data) | |
return serialized_values, serialized_value_data | |
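    # For example, a 6-byte immediate gets a physical length of
    # ((6 - 1) | 0x3) + 1 == 8 and is followed by two zero bytes of padding,
    # keeping every serialized value 4-byte aligned.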
    @staticmethod
    def serialize_ints(ints):
        return array.array("i", ints).tobytes()
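    # For example, on a little-endian platform serialize_ints([3, 7]) returns
    # b"\x03\x00\x00\x00\x07\x00\x00\x00" (two 4-byte native ints).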
ADDER_MAP = { | |
"prim::GetAttr": lambda self, node: self.add_getattr(node), | |
"prim::Constant": lambda self, node: self.add_constant_node(node), | |
"prim::ListConstruct": lambda self, node: self.add_list_construct(node), | |
"prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node), | |
"aten::unsqueeze": lambda self, node: self.add_unsqueeze(node), | |
"aten::to": lambda self, node: self.add_to(node), | |
"aten::detach": lambda self, node: self._identity(node), | |
"aten::reshape": lambda self, node: self.add_reshape(node), | |
"aten::flatten": lambda self, node: self.add_flatten(node), | |
"aten::slice": lambda self, node: self.add_slice(node), | |
"aten::size": lambda self, node: self.add_size(node), | |
"aten::cat": lambda self, node: self.add_cat(node), | |
"aten::mean": lambda self, node: self.add_mean(node), | |
"aten::quantize_per_tensor": lambda self, node: self.add_quantize(node), | |
"aten::dequantize": lambda self, node: self.add_dequantize(node), | |
"aten::add": lambda self, node: self.add_add_sub_op( | |
node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE | |
), | |
"aten::sub": lambda self, node: self.add_add_sub_op( | |
node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE | |
), | |
"aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( | |
node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE | |
), | |
"aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( | |
node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE | |
), | |
"aten::relu": lambda self, node: self.add_pointwise_simple_unary_op( | |
node, NNAPI_OperationCode.RELU | |
), | |
"aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op( | |
node, NNAPI_OperationCode.LOGISTIC | |
), | |
"aten::softmax": lambda self, node: self.add_softmax(node), | |
"aten::hardtanh": lambda self, node: self.add_hardtanh(node), | |
"aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node), | |
"aten::max_pool2d": lambda self, node: self.add_pool2d_node( | |
node, NNAPI_OperationCode.MAX_POOL_2D | |
), | |
"aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d( | |
node | |
), | |
"aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d( | |
node | |
), | |
"aten::prelu": lambda self, node: self.add_prelu_op(node), | |
"aten::addmm": lambda self, node: self.add_addmm(node), | |
"aten::linear": lambda self, node: self.add_linear(node), | |
"aten::_convolution": lambda self, node: self.add_conv_underscore(node), | |
"aten::conv2d": lambda self, node: self.add_conv2d(node), | |
"aten::log_softmax": lambda self, node: self.add_log_softmax(node), | |
"quantized::linear": lambda self, node: self.add_qlinear(node), | |
"quantized::conv2d": lambda self, node: self.add_qconv2d( | |
node, NNAPI_FuseCode.FUSED_NONE | |
), | |
"quantized::conv2d_relu": lambda self, node: self.add_qconv2d( | |
node, NNAPI_FuseCode.FUSED_RELU | |
), | |
"quantized::conv_transpose2d": lambda self, node: self.add_qconv2d( | |
node, NNAPI_FuseCode.FUSED_NONE, transpose=True | |
), | |
"quantized::add": lambda self, node: self.add_qadd( | |
node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE | |
), | |
"quantized::add_relu": lambda self, node: self.add_qadd( | |
node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU | |
), | |
"quantized::mul": lambda self, node: self.add_qadd( | |
node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE | |
), | |
} | |
def add_node(self, node): | |
adder = self.ADDER_MAP.get(node.kind()) | |
if not adder: | |
raise Exception(f"Unsupported node kind ({node.kind()!r}) in node {node!r}") | |
adder(self, node) | |
def _identity(self, node): | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
jitval = node.outputsAt(0) | |
self.jitval_operand_map[jitval] = in_id | |
def add_getattr(self, node): | |
assert node.inputsSize() == 1 | |
assert node.outputsSize() == 1 | |
obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) | |
assert str(obj_ctype).startswith("__torch__.") | |
name = node.s("name") | |
value = getattr(obj, name) | |
output = node.outputsAt(0) | |
ctype = output.type() | |
self.add_constant_value(output, ctype, value) | |
def add_constant_node(self, node): | |
assert node.inputsSize() == 0 | |
assert node.outputsSize() == 1 | |
output = node.outputsAt(0) | |
ctype = output.type() | |
value = output.toIValue() | |
self.add_constant_value(output, ctype, value) | |
def add_list_construct(self, node): | |
assert node.outputsSize() == 1 | |
output = node.outputsAt(0) | |
ctype = output.type() | |
const_vals: Optional[List] = [] | |
tensors: Optional[List] = [] | |
for inp in node.inputs(): | |
if const_vals is not None and inp in self.constants: | |
_, val = self.get_constant_value(inp) | |
const_vals.append(val) | |
else: | |
const_vals = None | |
if tensors is not None and inp.type().kind() == "TensorType": | |
tensors.append(inp) | |
else: | |
tensors = None | |
if const_vals is not None: | |
# NOTE: Now that TorchScript supports list constants, | |
# this code path might not be used anymore. | |
self.add_constant_value(output, ctype, const_vals) | |
if tensors is not None: | |
self.add_tensor_sequence(output, tensors) | |
if const_vals is None and tensors is None: | |
raise Exception( | |
f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}" | |
) | |
def add_tuple_construct(self, node): | |
assert node.outputsSize() == 1 | |
output = node.outputsAt(0) | |
values = list(node.inputs()) | |
self.add_tensor_sequence(output, values) | |
def add_unsqueeze(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
_, dim = self.get_constant_value(node.inputsAt(1), "IntType") | |
assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS | |
real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1 | |
out_shape_list = list(in_oper.shape) | |
out_shape_list.insert(real_dim, 1) | |
out_shape = tuple(out_shape_list) | |
out_oper = in_oper._replace(shape=out_shape) | |
inputs = [None] * 2 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_scalar(dim) | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs) | |
def add_to(self, node): | |
# Handle to("cpu") / to("gpu") case | |
self._identity(node) | |
def add_reshape(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
shape_ctype, shape = self.get_constant_value(node.inputsAt(1)) | |
assert shape_ctype.kind() == "ListType" | |
assert shape_ctype.getElementType().kind() == "IntType" | |
is_trivial_reshape = len(shape) == 2 and shape[1] == -1 | |
if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: | |
raise Exception( | |
"Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]." | |
) | |
# Bit of a hack here. Use a real tensor to infer the output shape. | |
out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape | |
out_oper = in_oper._replace( | |
shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS | |
) | |
inputs = [None] * 2 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_vector(shape) | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) | |
def add_flatten(self, node): | |
assert node.inputsSize() == 3 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") | |
end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") | |
# channels last with channels == 1 or (height & width both 1) | |
is_trivial_flatten = len(in_oper.shape) == 4 and ( | |
in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1) | |
) | |
if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten: | |
raise Exception( | |
"Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1" | |
) | |
if start_dim < 0: | |
start_dim += len(in_oper.shape) | |
if end_dim < 0: | |
end_dim += len(in_oper.shape) | |
out_shape = ( | |
in_oper.shape[:start_dim] | |
+ (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),) | |
+ in_oper.shape[end_dim + 1 :] | |
) | |
if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): | |
raise Exception("Flattening flexible dims is not supported yet") | |
non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] | |
if non_flattened_dims.count(0) > 1: | |
raise Exception("Only 1 dim can be flexible") | |
out_oper = in_oper._replace( | |
shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS | |
) | |
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
for idx, dim in enumerate(out_shape): | |
if dim == 0: | |
self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0)) | |
inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape) | |
inputs = [None] * 2 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_vector(inputs_1) | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) | |
def add_slice(self, node): | |
assert node.inputsSize() == 5 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
_, dim_value = self.get_constant_value(node.inputsAt(1)) | |
_, start_value = self.get_constant_value(node.inputsAt(2)) | |
_, stop_value = self.get_constant_value(node.inputsAt(3)) | |
_, step_value = self.get_constant_value(node.inputsAt(4)) | |
if start_value is None: | |
start_value = 0 | |
if stop_value is None: | |
stop_value = sys.maxsize | |
if start_value < 0: | |
start_value += in_oper.shape[dim_value] | |
elif start_value == sys.maxsize: | |
start_value = 0 | |
if start_value == 0 and stop_value == sys.maxsize: | |
self._identity(node) | |
return | |
if in_oper.shape[dim_value] == 0: | |
raise Exception("Unable to slice with flexible shape") | |
if stop_value < 0: | |
stop_value += in_oper.shape[dim_value] | |
elif stop_value == sys.maxsize: | |
stop_value = in_oper.shape[dim_value] | |
if start_value >= stop_value: | |
raise Exception("Slice start value should be less than stop value") | |
out_len = (stop_value - start_value) // step_value | |
out_shape = tuple( | |
out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape) | |
) | |
out_id = self.add_tensor_operand( | |
node.outputsAt(0), in_oper._replace(shape=out_shape) | |
) | |
# flex inputs | |
end_mask = 0 | |
for idx, dim in enumerate(out_shape): | |
if dim == 0: | |
self.forward_operand_shape(out_id, idx, in_id, idx) | |
end_mask |= 1 << idx | |
inputs = [None] * 7 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_vector( | |
[start_value if i == dim_value else 0 for i in range(len(in_oper.shape))] | |
) | |
inputs[2] = self.add_immediate_int_vector( | |
[ | |
stop_value if i == dim_value else dim | |
for i, dim in enumerate(in_oper.shape) | |
] | |
) | |
inputs[3] = self.add_immediate_int_vector( | |
[step_value if i == dim_value else 1 for i in range(len(in_oper.shape))] | |
) | |
        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
        inputs[5] = self.add_immediate_int_scalar(end_mask)
        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs) | |
def add_size(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
_, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
_, value = self.constants[node.inputsAt(1)] | |
res = in_oper.shape[value] | |
output = node.outputsAt(0) | |
self.add_constant_value(output, output.type(), res) | |
def add_cat(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
tensors = self.tensor_sequences[node.inputsAt(0)] | |
_, dim = self.get_constant_value(node.inputsAt(1), "IntType") | |
assert len(tensors) > 0 | |
in_ids = [] | |
out_oper = None | |
out_dim_size = 0 | |
for inp in tensors: | |
in_id, in_oper = self.get_tensor_operand_by_jitval(inp) | |
if out_oper is None: | |
out_shape = change_element(in_oper.shape, dim, -1) | |
out_oper = in_oper._replace(shape=out_shape) | |
assert in_oper.op_type == out_oper.op_type | |
assert in_oper.dim_order == out_oper.dim_order | |
assert change_element(in_oper.shape, dim, -1) == change_element( | |
out_oper.shape, dim, -1 | |
) | |
# TODO: Possibly check scale and zero point. | |
in_ids.append(in_id) | |
# TODO: Possibly support variable-sized inputs. | |
out_dim_size += in_oper.shape[dim] | |
assert out_oper is not None | |
out_oper = out_oper._replace( | |
shape=change_element(out_oper.shape, dim, out_dim_size) | |
) | |
if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined] | |
assert len(out_oper.shape) == 4 | |
nnapi_dim = [0, 3, 1, 2][dim] | |
else: | |
nnapi_dim = dim | |
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
for idx, d in enumerate(out_oper.shape): | |
if d == 0: | |
if idx == dim: | |
shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids) | |
self.compute_operand_shape(out_id, idx, shape) | |
else: | |
self.forward_operand_shape(out_id, idx, in_ids[0], idx) | |
inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)] | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs) | |
def add_mean(self, node): | |
assert node.inputsSize() == 4 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
dim_ctype, dim = self.get_constant_value(node.inputsAt(1)) | |
assert dim_ctype.kind() == "ListType" | |
assert dim_ctype.getElementType().kind() == "IntType" | |
_, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType") | |
# Expect None for dtype | |
self.get_constant_value(node.inputsAt(3), "NoneType") | |
if in_oper.dim_order == DimOrder.CHANNELS_LAST: | |
assert len(in_oper.shape) == 4 | |
nnapi_dim = [[0, 3, 1, 2][d] for d in dim] | |
else: | |
nnapi_dim = dim | |
collapsed_dims = set() | |
for d in dim: | |
if d < 0: | |
d += len(in_oper.shape) | |
collapsed_dims.add(d) | |
if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim: | |
assert collapsed_dims.issuperset({2, 3}) | |
out_dim_order = DimOrder.PRESUMED_CONTIGUOUS | |
else: | |
out_dim_order = in_oper.dim_order | |
out_shape = [] | |
for i, s in enumerate(in_oper.shape): | |
if i not in collapsed_dims: | |
out_shape.append(s) | |
elif keep_dim: | |
out_shape.append(1) | |
out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order) | |
inputs = [None] * 3 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_int_vector(nnapi_dim) | |
inputs[2] = self.add_immediate_int_scalar(keep_dim) | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs) | |
def add_quantize(self, node): | |
assert node.inputsSize() == 4 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
if in_oper.dim_order != DimOrder.CHANNELS_LAST: | |
raise Exception( | |
"Most hardware backends prefer NHWC quantized tensors. " | |
"Try setting `t.nnapi_nhwc = True` on your tensor inputs. " | |
) | |
_, scale = self.get_constant_value(node.inputsAt(1), "FloatType") | |
_, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") | |
_, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") | |
if scalar_type != TorchScalarTypes.QUINT8.value: | |
raise Exception( | |
"PyTorch NNAPI export only supports quantized tensors " | |
"with the quint8 dtype." | |
) | |
op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM | |
out_oper = in_oper._replace( | |
op_type=op_type, | |
scale=scale, | |
zero_point=zero_point, | |
) | |
inputs = [None] * 1 | |
inputs[0] = in_id | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) | |
def add_dequantize(self, node): | |
assert node.inputsSize() == 1 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
out_oper = in_oper._replace( | |
op_type=NNAPI_OperandCode.TENSOR_FLOAT32, | |
scale=0.0, | |
zero_point=0, | |
) | |
inputs = [None] * 1 | |
inputs[0] = in_id | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) | |
def add_pointwise_simple_unary_op(self, node, opcode): | |
assert node.inputsSize() == 1 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
out_oper = in_oper | |
if opcode == NNAPI_OperationCode.LOGISTIC: | |
# NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale | |
# must be 1.f / 256 and the zeroPoint must be 0. | |
# https://fburl.com/h52stoog | |
if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: | |
out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256) | |
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
for idx, dim in enumerate(in_oper.shape): | |
if dim == 0: | |
self.forward_operand_shape(out_id, idx, in_id, idx) | |
inputs = [None] * 1 | |
inputs[0] = in_id | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(opcode, inputs, outputs) | |
def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D401 | |
"""Helper for pointwise binary broadcast ops with superfluous extra args.""" | |
assert node.outputsSize() == 1 | |
assert node.inputsAt(0).type().kind() == "TensorType" | |
assert node.inputsAt(1).type().kind() == "TensorType" | |
if self.has_operand_for_jitval(node.inputsAt(0)): | |
in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
in1_id, in1_oper = self.get_tensor_operand_or_constant( | |
node.inputsAt(1), in0_oper.dim_order | |
) | |
elif self.has_operand_for_jitval(node.inputsAt(1)): | |
in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) | |
in0_id, in0_oper = self.get_tensor_operand_or_constant( | |
node.inputsAt(0), in1_oper.dim_order | |
) | |
else: | |
raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants") | |
assert in0_oper.op_type == in1_oper.op_type | |
in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( | |
in0_id, in0_oper, in1_id, in1_oper | |
) | |
# NOTE: PyTorch and NNAPI have the same broadcast semantics. | |
out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) | |
out_oper = in0_oper._replace(shape=out_shape) | |
if qparams is not None: | |
scale, zp = qparams | |
out_oper = out_oper._replace(scale=scale, zero_point=zp) | |
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) | |
for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)): | |
if d0 == 1 and d1 == 0: | |
self.forward_operand_shape(out_id, idx, in1_id, idx) | |
elif d0 == 0 and d1 == 1: | |
self.forward_operand_shape(out_id, idx, in0_id, idx) | |
elif d0 == 0 and d1 == 0: | |
self.flexible_shape_computation_lines.append( | |
f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}" | |
) | |
self.forward_operand_shape(out_id, idx, in0_id, idx) | |
inputs = [None] * 3 | |
inputs[0] = in0_id | |
inputs[1] = in1_id | |
inputs[2] = self.add_immediate_int_scalar(fuse_code) | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(opcode, inputs, outputs) | |
def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code): | |
assert node.inputsSize() == 2 | |
self._do_add_binary(node, opcode, fuse_code) | |
def add_add_sub_op(self, node, opcode, fuse_code): | |
assert node.inputsSize() == 3 | |
_, alpha = self.get_constant_value(node.inputsAt(2), "IntType") | |
if alpha != 1: | |
raise Exception("NNAPI does not support add/sub with alpha.") | |
self._do_add_binary(node, opcode, fuse_code) | |
def add_qadd(self, node, opcode, fuse_code): | |
assert node.inputsSize() == 4 | |
_, scale = self.get_constant_value(node.inputsAt(2), "FloatType") | |
_, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") | |
self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) | |
def add_softmax(self, node): | |
assert node.inputsSize() == 3 | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
_, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType") | |
out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) | |
for dim, size in enumerate(in_oper.shape): | |
if size == 0: | |
self.forward_operand_shape(out_id, dim, in_id, dim) | |
inputs = [None] * 3 | |
inputs[0] = in_id | |
inputs[1] = self.add_immediate_float_scalar( | |
1.0 | |
) # positive scaling factor of exponent, beta | |
inputs[2] = self.add_immediate_int_scalar(softmax_dim) | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs) | |
def add_hardtanh(self, node): | |
assert node.inputsSize() == 3 | |
assert node.outputsSize() == 1 | |
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
_, min_val = self.get_constant_value(node.inputsAt(1), "FloatType") | |
_, max_val = self.get_constant_value(node.inputsAt(2), "FloatType") | |
op_map = { | |
(-1, 1): NNAPI_OperationCode.RELU1, | |
(0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 | |
} | |
opcode = op_map.get((min_val, max_val)) | |
if opcode is None: | |
raise Exception("NNAPI only supports hardtanh with args (-1, 1) or (0, 6).") | |
inputs = [None] * 1 | |
inputs[0] = in_id | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) | |
self.add_operation(opcode, inputs, outputs) | |
def add_prelu_op(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
assert node.inputsAt(0).type().kind() == "TensorType" | |
assert node.inputsAt(1).type().kind() == "TensorType" | |
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) | |
w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) | |
assert len(w_oper.shape) == 1 | |
assert w_oper.shape[0] > 0 | |
if w_oper.shape[0] > 1: | |
if in_oper.use_nchw(): | |
# TODO: Support this by adding trailing 1 dims. | |
raise Exception( | |
"Per-channel PReLU only supports channels_last right now." | |
) | |
out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) | |
for dim, size in enumerate(in_oper.shape): | |
if size > 0: | |
pass | |
elif dim <= 1: | |
raise Exception("PReLU requires fixed size for dim 0 and dim 1.") | |
else: | |
self.forward_operand_shape(out_id, dim, in_id, dim) | |
inputs = [None] * 2 | |
inputs[0] = in_id | |
inputs[1] = w_id | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) | |
def add_pool2d_node(self, node, opcode): | |
assert node.inputsSize() == 6 | |
assert node.outputsSize() == 1 | |
image, kernel, stride, padding, dilation, ceil_mode = node.inputs() | |
stride = stride or kernel | |
# TODO: Validate ceil_mode semantics. | |
args = self.get_conv_pool_args_2d_from_jit( | |
self.get_size_arg(kernel), stride, padding, dilation | |
) | |
if args.dilation_h != 1 or args.dilation_w != 1: | |
raise Exception("NNAPI does not support dilated pooling.") | |
image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image) | |
assert len(image_oper.shape) == 4 | |
out_shape = get_conv_pool_shape( | |
image_oper.shape, args, image_oper.shape[1], False | |
) | |
use_nchw = image_oper.use_nchw() | |
inputs = [None] * 11 | |
inputs[0] = image_id | |
inputs[1] = self.add_immediate_int_scalar(args.pad_l) | |
inputs[2] = self.add_immediate_int_scalar(args.pad_r) | |
inputs[3] = self.add_immediate_int_scalar(args.pad_t) | |
inputs[4] = self.add_immediate_int_scalar(args.pad_b) | |
inputs[5] = self.add_immediate_int_scalar(args.stride_w) | |
inputs[6] = self.add_immediate_int_scalar(args.stride_h) | |
inputs[7] = self.add_immediate_int_scalar(args.kernel_w) | |
inputs[8] = self.add_immediate_int_scalar(args.kernel_h) | |
inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) | |
inputs[10] = self.add_immediate_bool_scalar(use_nchw) | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand( | |
node.outputsAt(0), image_oper._replace(shape=out_shape) | |
) | |
self.add_operation(opcode, inputs, outputs) | |
def add_avg_pool2d(self, node): | |
assert node.inputsSize() == 7 | |
assert node.outputsSize() == 1 | |
( | |
image, | |
kernel, | |
stride, | |
padding, | |
ceil_mode, | |
count_include_pad, | |
divisor_override, | |
) = node.inputs() | |
_, count_include_pad_value = self.get_constant_value(count_include_pad) | |
_, divisor_override_value = self.get_constant_value(divisor_override) | |
if not count_include_pad_value or divisor_override_value: | |
raise Exception( | |
"NNAPI doesn't support count_include_pad=False or divisor_override" | |
) | |
args = self.get_conv_pool_args_2d_from_jit( | |
self.get_size_arg(kernel), stride, padding | |
) | |
image_id, image_oper = self.get_tensor_operand_by_jitval(image) | |
assert len(image_oper.shape) == 4 | |
out_shape = get_conv_pool_shape( | |
image_oper.shape, args, image_oper.shape[1], False | |
) | |
use_nchw = image_oper.use_nchw() | |
inputs = [None] * 11 | |
inputs[0] = image_id | |
inputs[1] = self.add_immediate_int_scalar(args.pad_l) | |
inputs[2] = self.add_immediate_int_scalar(args.pad_r) | |
inputs[3] = self.add_immediate_int_scalar(args.pad_t) | |
inputs[4] = self.add_immediate_int_scalar(args.pad_b) | |
inputs[5] = self.add_immediate_int_scalar(args.stride_w) | |
inputs[6] = self.add_immediate_int_scalar(args.stride_h) | |
inputs[7] = self.add_immediate_int_scalar(args.kernel_w) | |
inputs[8] = self.add_immediate_int_scalar(args.kernel_h) | |
inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) | |
inputs[10] = self.add_immediate_bool_scalar(use_nchw) | |
outputs = [None] * 1 | |
out_id = self.add_tensor_operand( | |
node.outputsAt(0), image_oper._replace(shape=out_shape) | |
) | |
self._handle_conv_pool_flexible_input(out_id, image, args, False) | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) | |
def add_adaptive_avg_pool2d(self, node): | |
assert node.inputsSize() == 2 | |
assert node.outputsSize() == 1 | |
image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size( | |
node.inputsAt(0) | |
) | |
assert len(image_oper.shape) == 4 | |
size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) | |
assert size_ctype.kind() == "ListType" | |
assert size_ctype.getElementType().kind() == "IntType" | |
if size_arg != [1, 1]: | |
raise Exception( | |
"NNAPI only supports adaptive_avg_pool2d with output size (1, 1)." | |
) | |
out_shape = image_oper.shape[0:2] + tuple(size_arg) | |
use_nchw = image_oper.use_nchw() | |
inputs = [None] * 11 | |
inputs[0] = image_id | |
inputs[1] = self.add_immediate_int_scalar(0) | |
inputs[2] = self.add_immediate_int_scalar(0) | |
inputs[3] = self.add_immediate_int_scalar(0) | |
inputs[4] = self.add_immediate_int_scalar(0) | |
inputs[5] = self.add_immediate_int_scalar(1) | |
inputs[6] = self.add_immediate_int_scalar(1) | |
inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3]) | |
inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2]) | |
inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) | |
inputs[10] = self.add_immediate_bool_scalar(use_nchw) | |
outputs = [None] * 1 | |
outputs[0] = self.add_tensor_operand( | |
node.outputsAt(0), image_oper._replace(shape=out_shape) | |
) | |
self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) | |
def add_upsample_nearest2d(self, node): | |
assert node.inputsSize() == 3 or node.inputsSize() == 4 | |
assert node.outputsSize() == 1 | |
if node.inputsSize() == 3: | |
image, size_jit, scale_jit = node.inputs() | |
else: | |
image, size_jit, scale_h_jit, scale_w_jit = node.inputs() | |
size_ctype, size_arg = self.get_constant_value(size_jit) | |
if node.inputsSize() == 3: | |
scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined] | |
else: | |
scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined] | |
scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] | |
# The only way for the 4-argument overload of upsample_nearest2d to | |
# have been added to the graph without error is if the scale_h and | |
# scale_w arguments are None | |
assert scale_h_ctype.kind() == "NoneType" | |
assert scale_w_ctype.kind() == "NoneType" | |
scale_ctype = scale_h_ctype | |
scale_arg = scale_h_arg | |
image_id, image_oper = self.get_tensor_operand_by_jitval(image) | |
assert len(image_oper.shape) == 4 | |
if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": | |
raise Exception("Size and scale cannot both be non-None.") | |
elif size_ctype.kind() != "NoneType": | |
assert size_ctype.kind() == "ListType" | |
assert size_ctype.getElementType().kind() == "IntType" | |
assert scale_ctype.kind() == "NoneType" | |
assert scale_arg is None | |
assert isinstance(size_arg, list) | |
assert size_arg | |
assert all(isinstance(val, int) for val in size_arg) | |
if len(size_arg) == 1: | |
size_arg = size_arg * 2 | |
assert len(size_arg) == 2 | |
out_h = size_arg[0] | |
out_w = size_arg[1] | |
arg_h = self.add_immediate_int_scalar(out_h) | |
arg_w = self.add_immediate_int_scalar(out_w) | |
elif scale_ctype.kind() != "NoneType": | |
assert scale_ctype.kind() == "ListType" | |
assert scale_ctype.getElementType().kind() == "FloatType" | |
assert size_ctype.kind() == "NoneType" | |
assert size_arg is None | |
assert isinstance(scale_arg, list) | |
assert scale_arg | |
assert all(isinstance(val, float) for val in scale_arg) | |
if len(scale_arg) == 1: | |
scale_arg = scale_arg * 2 | |
assert len(scale_arg) == 2 | |
out_h = int(scale_arg[0] * image_oper.shape[2]) | |
out_w = int(scale_arg[1] * image_oper.shape[3]) | |
arg_h = self.add_immediate_float_scalar(scale_arg[0]) | |
arg_w = self.add_immediate_float_scalar(scale_arg[1]) | |
else: | |
raise Exception("Size and scale cannot both be None.") | |
out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) | |
use_nchw = image_oper.use_nchw() | |
out_id = self.add_tensor_operand( | |
node.outputsAt(0), image_oper._replace(shape=out_shape) | |
) | |
if image_oper.shape[0] == 0 or image_oper.shape[1] == 0: | |
raise Exception("Flexible batch or channels not supported") | |
# Handle variable input size | |
for dim in (2, 3): # h, w indices | |
if image_oper.shape[dim] == 0: | |
if size_ctype.kind() != "NoneType": | |
self.compute_operand_shape(out_id, dim, size_arg[dim - 2]) | |
elif scale_ctype.kind() != "NoneType": | |
self.compute_operand_shape( | |
out_id, | |
dim, | |
f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})", | |
) | |
else: | |
raise Exception("Size and scale cannot both be None.") | |
inputs = [None] * 4 | |
inputs[0] = image_id | |
inputs[1] = arg_w | |
inputs[2] = arg_h | |
inputs[3] = self.add_immediate_bool_scalar(use_nchw) | |
outputs = [None] * 1 | |
outputs[0] = out_id | |
self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) | |

    def add_addmm(self, node):
        assert node.inputsSize() == 5
        assert node.outputsSize() == 1
        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()

        for jitval in (jit_beta, jit_alpha):
            scale_ctype, scale_value = self.get_constant_value(jitval)
            assert scale_ctype.kind() in ("IntType", "FloatType")
            if scale_value != 1:
                raise Exception(
                    "NNAPI Fully-Connected does not support alpha and beta other than 1."
                )

        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)

    def add_linear(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1
        jit_input, jit_weight, jit_bias = node.inputs()

        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)

    def add_addmm_or_linear(
        self, node, transpose_weight, jit_input, jit_weight, jit_bias
    ):
        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)

        assert len(input_oper.shape) == 2
        assert len(bias_oper.shape) == 1

        # TODO: Transform at load time to share weights with CPU model.
        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        assert len(weight_tensor.shape) == 2
        if transpose_weight:
            nnapi_weight_tensor = weight_tensor.t().contiguous()
        else:
            nnapi_weight_tensor = weight_tensor.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_id = self.add_tensor_operand(
            node.outputsAt(0), input_oper._replace(shape=out_shape)
        )
        if input_oper.shape[0] == 0:
            self.forward_operand_shape(out_id, 0, input_id, 0)

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
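
    # Worked example (illustrative only): torch.nn.Linear(4, 8) stores its
    # weight as an (8, 4) tensor, which is already the (out_features,
    # in_features) layout that NNAPI FULLY_CONNECTED expects, so add_linear
    # passes transpose_weight=False. aten::addmm receives its weight as
    # (in_features, out_features) = (4, 8), so add_addmm passes
    # transpose_weight=True and the tensor is transposed to (8, 4) above.
    # Either way, an input of shape (batch, 4) produces an output of shape
    # (batch, 8).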

    def add_qlinear(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1
        (
            jit_input,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        # TODO: Support automatic reshape
        assert len(input_oper.shape) == 2

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "LinearPackedParamsBase"
        raw_weight, raw_bias = packed_weight.__getstate__()[0]
        assert raw_bias is not None

        assert len(raw_weight.shape) == 2
        assert len(raw_bias.shape) == 1
        assert raw_bias.shape[0] == raw_weight.shape[0]
        assert raw_weight.shape[1] == input_oper.shape[1]

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128,
            )
        weight_scale = unsigned_weight.q_scale()
        bias_scale = input_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = input_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(
                "Quantized linear multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training."
            )

        # TODO: Transform at load time to share weights with CPU model.
        nnapi_weight_tensor = unsigned_weight.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_oper = input_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
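
    # Numeric example (illustrative only): with input_oper.scale = 0.02,
    # weight_scale = 0.05, and out_scale = 0.1, the bias is requantized with
    # bias_scale = 0.02 * 0.05 = 0.001 and the requantization multiplier is
    # 0.02 * 0.05 / 0.1 = 0.01, which satisfies the multiplier < 1 check above.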

    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
        ctype, value = self.get_constant_value(jit_bias)
        if ctype.kind() == "NoneType":
            bias_idx = 1 if transpose else 0
            nnapi_bias_tensor = torch.zeros(
                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
            )
            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
            bias_oper = self.operands[bias_id]
            return bias_id, bias_oper
        else:
            return self.get_tensor_operand_for_weight(jit_bias)
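
    # Shape note (illustrative only): when the bias is None, the zero bias is
    # sized from the weight's output-channel dimension. A regular conv weight
    # of shape (out_c, in_c, kh, kw) yields a zero bias of length out_c
    # (dimension 0), while a transposed-conv weight of shape
    # (in_c, out_c, kh, kw) uses dimension 1, matching bias_idx above.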

    def add_conv2d(self, node):
        assert node.inputsSize() == 7
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_groups,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
        )

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            False,  # transpose
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_conv_underscore(self, node):
        assert node.inputsSize() == 13
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_transpose,
            _,
            jit_groups,
            _,
            _,
            _,
            _,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        _, transpose = self.get_constant_value(jit_transpose)
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
        )

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            transpose,
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_log_softmax(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        (jit_input, jit_dim, jit_half_to_float) = node.inputs()
        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        _, dim = self.get_constant_value(jit_dim, "IntType")

        out_shape = input_oper.shape

        inputs = [None] * 3
        inputs[0] = input_id
        # Specifying 1 as the scaling factor for the exponent, beta
        inputs[1] = self.add_immediate_float_scalar(1)
        inputs[2] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(
            node.outputsAt(0), input_oper._replace(shape=out_shape)
        )
        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)

    def add_qconv2d(self, node, fuse_code, transpose=False):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "Conv2dPackedParamsBase"
        (
            pack_version,
            tensors,
            opt_tensors,
        ) = packed_weight.__getstate__()[0]
        assert pack_version == "2"
        packed_config, raw_weight = tensors
        (raw_bias,) = opt_tensors
        assert raw_bias is not None
        args = self.get_conv_pool_args_2d_from_pack(
            raw_weight.shape[2:4], packed_config
        )

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128,
            )
        weight_scale = unsigned_weight.q_scale()
        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        bias_scale = image_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = image_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(
                "Quantized convolution multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training."
            )

        return self.add_conv2d_common(
            node.outputsAt(0),
            out_scale,
            out_zero_point,
            jit_image,
            unsigned_weight,
            bias_id,
            args,
            transpose,
            fuse_code,
        )
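
    # Requantization note (illustrative only): NNAPI TENSOR_QUANT8_ASYMM
    # operands are unsigned, so a qint8 weight is shifted into quint8 above by
    # adding 128 to both the stored integers and the zero point. The real
    # values are unchanged: with scale 0.05 and zero_point 0, a stored qint8
    # value of -3 (real value -0.15) becomes uint8 value 125 with zero_point
    # 128, i.e. (125 - 128) * 0.05 = -0.15.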

    def add_conv2d_common(
        self,
        jit_out,
        out_scale,
        out_zero_point,
        jit_image,
        weight_tensor,
        bias_id,
        args,
        transpose,
        fuse_code,
    ):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        in_c = image_oper.shape[1]

        if args.group == 1:
            # Full convolution
            depthwise = False
            if transpose:
                weight_permutation = (1, 2, 3, 0)
            else:
                weight_permutation = (0, 2, 3, 1)
        elif args.group == in_c:
            # Depthwise convolution
            depthwise = True
            weight_permutation = (1, 2, 3, 0)
        else:
            raise Exception("Group convolution not supported yet.")

        # TODO: Transform at load time to share weights with CPU model.
        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        bias_oper = self.operands[bias_id]

        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
            assert bias_oper.zero_point == 0
        else:
            raise Exception(f"Unsupported input type for conv2d: {image_oper.op_type}")

        assert len(image_oper.shape) == 4
        assert len(weight_oper.shape) == 4
        assert len(bias_oper.shape) == 1

        if depthwise:
            # Depthwise convolution
            one, kern_h, kern_w, out_c = weight_oper.shape
            assert one == 1
            assert out_c % in_c == 0
            channel_multiplier = out_c // in_c
            assert channel_multiplier == 1  # Don't support multiplier
            assert out_c == in_c
        else:
            # Full convolution
            out_c, kern_h, kern_w, kern_d = weight_oper.shape
            assert kern_d == in_c

        assert out_c == bias_oper.shape[0]

        use_nchw = image_oper.use_nchw()

        if depthwise:
            num_args = 12
            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
        else:
            num_args = 11
            if transpose:
                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
            else:
                opcode = NNAPI_OperationCode.CONV_2D

        inputs = [None] * num_args
        inputs[0] = image_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
        if depthwise:
            # Depthwise convolution takes an extra channel-multiplier argument.
            inputs[9] = self.add_immediate_int_scalar(1)
            inputs[10] = self.add_immediate_int_scalar(fuse_code)
            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
        else:
            inputs[9] = self.add_immediate_int_scalar(fuse_code)
            inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
        out_oper = image_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )
        out_id = self.add_tensor_operand(jit_out, out_oper)
        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)

        outputs[0] = out_id
        self.add_operation(opcode, inputs, outputs)
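
    # Layout example (illustrative only): PyTorch stores a regular conv weight
    # as (out_c, in_c, kh, kw), while NNAPI CONV_2D expects
    # (out_c, kh, kw, in_c); hence weight_permutation = (0, 2, 3, 1) above.
    # For depthwise convolution PyTorch uses (out_c, 1, kh, kw) and NNAPI
    # DEPTHWISE_CONV_2D expects (1, kh, kw, out_c); for transposed convolution
    # PyTorch uses (in_c, out_c, kh, kw) while NNAPI TRANSPOSE_CONV_2D expects
    # (out_c, kh, kw, in_c). Both of those cases reduce to
    # weight_permutation = (1, 2, 3, 0).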

    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        batch, in_ch, in_h, in_w = image_oper.shape

        if batch == 0:
            self.forward_operand_shape(out_id, 0, image_id, 0)
        if in_ch == 0:
            raise Exception("Input channels can't be flexible")
        # H & W
        if transpose:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
                )
        else:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
                )
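
    # Numeric example (illustrative only): for a flexible height with
    # kernel_h=3, pad_t=pad_b=1, and stride_h=1, a runtime input height of 32
    # gives (32 - 3 + 1 + 1) // 1 + 1 = 32; the transposed-convolution formula
    # with the same parameters gives (32 - 1) * 1 + 3 - 1 - 1 = 32.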


def serialize_model(
    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
):
    """Convert a TorchScript module to NNAPI and serialize it.

    Parameters:
        module: TorchScript module to convert
        inputs: Tensors used to specify input details for NNAPI
        config (optional): Optional config to attach to the module
        return_shapes (optional): Specify output shapes so the NNAPI output
            buffers can be sized correctly when the module uses runtime
            flexible shapes
        use_int16_for_qint16 (optional): Use PyTorch int16 to represent NNAPI
            qint16 values
    """
    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
        module, inputs, return_shapes
    )
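
# Usage sketch (illustrative only; the example module and shapes below are
# hypothetical, and the exact structure of the returned serialization is
# defined by _NnapiSerializer.serialize_model):
#
#   import torch
#
#   class TinyModel(torch.nn.Module):
#       def forward(self, x):
#           return torch.nn.functional.relu(x)
#
#   example = torch.zeros(1, 3, 8, 8)
#   traced = torch.jit.trace(TinyModel(), example)
#   serialized = serialize_model(traced, [example])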