Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
import unittest | |
from typing import Sequence | |
from shape_inference_test import TestShapeInferenceHelper | |
import onnx | |
import onnx.helper | |
import onnx.parser | |
import onnx.shape_inference | |
from onnx import AttributeProto, TypeProto | |
# Tensor TypeProtos with no shape set (shape=None), built from the symbolic
# TensorProto element-type enum rather than bare integers.
float_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
uint8_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.UINT8, None)
int8_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT8, None)
int32_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, None)
float16_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT16, None)
# An empty TypeProto; the tests pass it in place of an omitted optional input.
no_type_ = TypeProto()
class TestFunctionInference(TestShapeInferenceHelper):
    """Tests for ``onnx.shape_inference.infer_function_output_types``.

    Each test parses a function written in ONNX textual syntax with
    ``onnx.parser.parse_function`` and checks the output types inferred for
    given input types and attribute values.
    """

    def _check(
        self,
        function_text: str,
        input_types: Sequence[TypeProto],
        attributes: Sequence[AttributeProto],
        expected_output_types: Sequence[TypeProto],
    ):
        """Assert that inference on *function_text* yields *expected_output_types*.

        Args:
            function_text: Function source in ONNX textual syntax.
            input_types: Types of the function inputs, in order. An empty
                ``TypeProto()`` may be passed for an omitted optional input.
            attributes: Attribute values bound to the function's attributes.
            expected_output_types: Expected type of each output, in order.
        """
        function = onnx.parser.parse_function(function_text)
        result = onnx.shape_inference.infer_function_output_types(
            function, input_types, attributes
        )
        # Compare counts first: zip() below would otherwise stop silently at
        # the shorter sequence and hide missing or extra outputs.
        self.assertEqual(len(expected_output_types), len(result))
        for expected, actual in zip(expected_output_types, result):
            self._compare_value_infos(expected, actual)

    def _check_fails(
        self,
        function_text: str,
        input_types: Sequence[TypeProto],
        attributes: Sequence[AttributeProto],
    ):
        """Assert that inference on *function_text* raises ``InferenceError``.

        Parsing is done before entering the ``assertRaises`` context, so only
        the inference call itself is expected to raise.

        Args:
            function_text: Function source in ONNX textual syntax.
            input_types: Types of the function inputs, in order.
            attributes: Attribute values bound to the function's attributes.
        """
        function = onnx.parser.parse_function(function_text)
        with self.assertRaises(onnx.shape_inference.InferenceError):
            onnx.shape_inference.infer_function_output_types(
                function, input_types, attributes
            )

    def test_fi_basic(self):
        """Matching input types propagate through Add/Mul; mixed types fail."""
        code = """
            <opset_import: [ "" : 18 ], domain: "local">
            f (y, z) => (w) {
                x = Add(y, z)
                w = Mul(x, y)
            }
        """
        self._check(code, [float_type_, float_type_], [], [float_type_])
        self._check(code, [int32_type_, int32_type_], [], [int32_type_])
        self._check_fails(code, [float_type_, int32_type_], [])

    def test_fi_attribute(self):
        """A function attribute forwarded to Cast's ``to`` sets the output type."""
        code = """
            <opset_import: [ "" : 18 ], domain: "local">
            CastTo <dtype> (x) => (y) {
                y = Cast <to : int = @dtype> (x)
            }
        """
        dtype_int32 = onnx.helper.make_attribute("dtype", onnx.TensorProto.INT32)
        self._check(code, [float_type_], [dtype_int32], [int32_type_])
        dtype_float16 = onnx.helper.make_attribute("dtype", onnx.TensorProto.FLOAT16)
        self._check(code, [float_type_], [dtype_float16], [float16_type_])

    def test_fi_optional_input(self):
        """Omitted optional inputs, both trailing and non-trailing, are handled."""
        code = """
            <opset_import: [ "" : 18 ], domain: "local">
            DoReduce (x, axes) => (y) {
                y = ReduceMax (x, axes)
            }
        """
        # We can omit the type for a missing trailing optional parameter
        self._check(code, [float_type_], [], [float_type_])
        # Or, we can pass in a default-value of TypeProto() for a missing optional parameter
        self._check(code, [float_type_, no_type_], [], [float_type_])

        code = """
            <opset_import: [ "" : 18 ], domain: "local">
            Quantize (x, scale, zero_point) => (y) {
                y = QuantizeLinear (x, scale, zero_point)
            }
        """
        # If the optional third parameter is specified, it determines the output type.
        self._check(code, [float_type_, float_type_, int8_type_], [], [int8_type_])
        self._check(code, [float_type_, float_type_, uint8_type_], [], [uint8_type_])
        # If the optional third parameter is omitted, the output type is uint8 (default).
        self._check(code, [float_type_, float_type_, no_type_], [], [uint8_type_])

        code = """
            <opset_import: [ "" : 18 ], domain: "local">
            DoClip (x, min, max) => (y) {
                y = Clip (x, min, max)
            }
        """
        # A test-case with a non-trailing missing optional parameter
        self._check(code, [float_type_, no_type_, float_type_], [], [float_type_])
        # A failing test-case with a non-trailing missing optional parameter
        self._check_fails(code, [float_type_, no_type_, int8_type_], [])
if __name__ == "__main__":
    # Run the tests in this module with the standard unittest runner when the
    # file is executed directly; `module="__main__"` is unittest.main's default.
    unittest.main(module="__main__")