from __future__ import annotations

import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12
__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    # The optional `path` argument (a contraction-path hint) is accepted for
    # signature compatibility but is not used by the ONNX lowering.
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)


def outer(g: jit_utils.GraphContext, input, other):
    # make sure to cast other to self's type
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])
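# The outer product is expressed as an Einsum with equation "i,j->ij"
# (result[i, j] == input[i] * other[j]); `other` is first cast to `input`'s
# scalar type because ONNX Einsum requires all inputs to share one dtype.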


def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to
    # False, dropout simply returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)
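# `dropout` keeps only the masked output, while `native_dropout` also returns the
# Dropout mask; in eval mode (train=False) the input is returned unchanged and
# the mask is None.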


def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional
    # and has no default value, so the ignore_index attribute is always set
    # (e.g. to the PyTorch default ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss


def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)
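# nll_loss2d and nll_loss_nd reuse the nll_loss lowering: ONNX
# NegativeLogLikelihoodLoss already accepts (N, C, d1, ..., dk) inputs, so no
# extra reshaping is needed for the 2d / n-d variants.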


def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional
    # and has no default value, so the ignore_index attribute is always set
    # (e.g. to the PyTorch default ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss
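# label_smoothing > 0 is rejected because SoftmaxCrossEntropyLoss (as of opset 12)
# has no label-smoothing attribute; the reduction and ignore_index handling
# mirrors nll_loss above.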


def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )
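# The decomposition above computes, per element,
#     -(pos_weight * target * log(sigmoid(x)) + (1 - target) * log(1 - sigmoid(x)))
# (pos_weight omitted when it is None), optionally scales by `weight`, and then
# applies the requested reduction (none / mean / sum).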


def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # if the input is of type double cast it to float
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)

        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # FIXME(justinchuby): cond is unused?
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)
        loop_output = loop.node().output()

        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze
    return symbolic_helper._unimplemented("Unfold", "input size not accessible")
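# When `size`/`step` are not compile-time constants but the length of `dimension`
# is statically known, the path above builds an ONNX Loop that emits one window
# per iteration (a Slice between the gathered start/end indices); each window is
# transposed so the unfolded elements land in the last dimension, and the Loop's
# stacked scan output is finally transposed and squeezed to match the output
# layout of aten::unfold.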


def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)
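# tensordot is decomposed into primitive ops: the contracted dims are permuted to
# the end of input_a and the front of input_b, both operands are flattened to
# 2-D, the contraction itself is an Einsum "ij,jk->ik" (a matrix multiply), and
# the result is reshaped back to the kept dimensions of input_a followed by those
# of input_b.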