Spaces:
Sleeping
Sleeping
| # Copyright (c) ONNX Project Contributors | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import typing | |
| import unittest | |
| import onnx | |
| import onnx.parser | |
| import onnx.shape_inference | |
class TestModelInference(unittest.TestCase):
    """Tests for type and shape inference over whole ONNX models.

    Each test writes a model in the ONNX textual syntax, parses it with
    ``onnx.parser.parse_model``, runs ``onnx.shape_inference.infer_shapes`` on
    it, and compares the inferred graph outputs with the expected element types
    or shapes.
    """

    def _check(self, model_text: str, *expected: int):
        """Check that model inference infers the expected types for outputs.

        Restricted to the simple case of tensor types, so expected types specify
        only the element type (ints corresponding to onnx.TensorProto.DataType).
        When no expected types are given, this only checks that inference runs
        without raising.

        Args:
            model_text: Model in ONNX textual syntax.
            *expected: One element type per graph output, in output order.
        """
        model = onnx.parser.parse_model(model_text)
        inferred = onnx.shape_inference.infer_shapes(model)
        outputs = inferred.graph.output
        if expected:
            # zip() below would silently ignore surplus outputs or surplus
            # expectations; a count mismatch means the test itself is wrong.
            self.assertEqual(len(outputs), len(expected))
        for output, expected_elem_type in zip(outputs, expected):
            inferred_type = output.type
            self.assertTrue(inferred_type.HasField("tensor_type"))
            tensor_type = inferred_type.tensor_type
            self.assertTrue(tensor_type.HasField("elem_type"))
            self.assertEqual(tensor_type.elem_type, expected_elem_type)

    def _check_shape(self, model_text: str, *expected: typing.Sequence[int]):
        """Check that model inference infers the expected shapes for outputs.

        Restricted to the simple case of tensor type outputs with completely
        known shapes: every inferred dimension must be a concrete value.

        Args:
            model_text: Model in ONNX textual syntax.
            *expected: One full shape (sequence of ints) per graph output, in
                output order.
        """
        model = onnx.parser.parse_model(model_text)
        # data_prop=True enables ONNX's (limited) data propagation, so that
        # statically computable shape tensors (e.g. the Mul of constants in
        # test_mi_constant_2) can be used when inferring Reshape/Expand shapes.
        inferred = onnx.shape_inference.infer_shapes(
            model, check_type=True, strict_mode=True, data_prop=True
        )
        outputs = inferred.graph.output
        if expected:
            # Guard against zip() silently truncating a mismatched pairing.
            self.assertEqual(len(outputs), len(expected))
        for output, expected_shape in zip(outputs, expected):
            inferred_type = output.type
            self.assertTrue(inferred_type.HasField("tensor_type"))
            tensor_type = inferred_type.tensor_type
            self.assertTrue(tensor_type.HasField("shape"))
            inferred_shape = tensor_type.shape
            self.assertEqual(len(inferred_shape.dim), len(expected_shape))
            for inferred_dim, expected_dim in zip(inferred_shape.dim, expected_shape):
                self.assertTrue(inferred_dim.HasField("dim_value"))
                self.assertEqual(inferred_dim.dim_value, expected_dim)

    def _check_inference_error(self, model_text: str):
        """Check that model inference raises an InferenceError.

        Inference runs with type checking and strict mode enabled, so that
        inference failures are raised instead of being tolerated.

        Args:
            model_text: Model in ONNX textual syntax.
        """
        model = onnx.parser.parse_model(model_text)
        with self.assertRaises(onnx.shape_inference.InferenceError):
            onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True)

    def test_unknown_op(self):
        """Test that model inference handles unknown ops.

        This special treatment is to support custom ops.
        See comments in shape inference code for details.
        """
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[N] x) => (y)
            {
                y = SomeUnknownOp (x)
            }
        """
        # No output types are inferred for unknown ops.
        # But ensure that the inference does not fail.
        self._check(model)

    def test_mi_basic(self):
        """Test that model inference infers model output type."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17]
            >
            agraph (float[N] x) => (y)
            {
                y = Cast<to=6> (x)
            }
        """
        self._check(model, onnx.TensorProto.INT32)

    def test_mi_function(self):
        """Test use of functions."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17, "local" : 1]
            >
            agraph (float[N] x) => (y)
            {
                y = local.cast(x)
            }
            <
                opset_import: [ "" : 17 ],
                domain: "local"
            >
            cast (x) => (y)
            {
                y = Cast<to=6> (x)
            }
        """
        self._check(model, onnx.TensorProto.INT32)

    def test_mi_function_attr(self):
        """Test use of functions with attribute parameters."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17, "local" : 1]
            >
            agraph (float[N] x) => (y)
            {
                y = local.cast<target=6>(x)
            }
            <
                opset_import: [ "" : 17 ],
                domain: "local"
            >
            cast<target>(x) => (y)
            {
                y = Cast<to:int = @target> (x)
            }
        """
        self._check(model, onnx.TensorProto.INT32)

    def test_mi_function_subgraph_attr(self):
        """Test use of function attributes within subgraphs."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17, "local" : 1]
            >
            agraph (float[N] x, bool flag) => (y)
            {
                y = local.cast<target=6>(x, flag)
            }
            <
                opset_import: [ "" : 17 ],
                domain: "local"
            >
            cast<target>(x, flag) => (y)
            {
                y = If (flag) <
                    then_branch = g1 () => (z_then) { z_then = Cast<to:int = @target> (x) },
                    else_branch = g2 () => (z_else) { z_else = Cast<to:int = @target> (x) }
                    >
            }
        """
        self._check(model, onnx.TensorProto.INT32)

    def test_mi_function_multiple_calls(self):
        """Test use of multiple invocation of functions."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17, "local" : 1]
            >
            agraph (float[N] x, bool flag) => (y, z)
            {
                y = local.cast<target=6>(x, flag)
                z = local.cast<target=7>(x, flag)
            }
            <
                opset_import: [ "" : 17 ],
                domain: "local"
            >
            cast<target>(x, flag) => (y)
            {
                y = If (flag) <
                    then_branch = g1 () => (z_then) { z_then = Cast<to:int = @target> (x) },
                    else_branch = g2 () => (z_else) { z_else = Cast<to:int = @target> (x) }
                    >
            }
        """
        self._check(model, onnx.TensorProto.INT32, onnx.TensorProto.INT64)

    def test_mi_constant(self):
        """Test that a Constant shape input gives Reshape a fully known shape."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17]
            >
            mymodel (float[4, 8, 16] x) => (y) {
                shape = Constant<value_ints=[8,4,16]>()
                y = Reshape(x, shape)
            }
        """
        self._check_shape(model, [8, 4, 16])

    def test_mi_constant_2(self):
        """Test that a shape computed from constants (via Mul) is propagated to Reshape."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17]
            >
            mymodel (float[4, 8, 16] x) => (y) {
                shape = Constant<value_ints=[4,2,8]>()
                two = Constant<value_int=2>()
                shape2 = Mul(shape, two)
                y = Reshape(x, shape2)
            }
        """
        self._check_shape(model, [8, 4, 16])

    def test_mi_constant_in_function(self):
        """Test that constant shape inputs inside a model-local function are propagated."""
        model = """
            <
                ir_version: 7,
                opset_import: [ "" : 17, "local" : 1]
            >
            main (float x) => (y, z) {
                y, z = local.expand(x)
            }
            <
                opset_import: [ "" : 17 ],
                domain: "local"
            >
            expand (x) => (y, z) {
                shape1 = Constant<value = int64[2] {4,4}>()
                shape2 = Constant<value = int64[3] {8,8,8}>()
                z = Expand (x, shape2)
                y = Expand (x, shape1)
            }
        """
        self._check_shape(model, [4, 4], [8, 8, 8])

    def test_mi_function_default_attr(self):
        """Test use of default values of function attributes."""
        model = """
            <ir_version: 7, opset_import: [ "" : 17, "local" : 1]>
            agraph (float[N] x) => (y, z)
            {
                y = local.cast <target=6> (x) # casts to INT32 type (encoding value 6)
                z = local.cast (x) # uses default-attribute value of 1 (FLOAT type)
            }

            <opset_import: [ "" : 17 ], domain: "local">
            cast <target: int = 1> (x) => (y)
            {
                y = Cast <to:int = @target> (x)
            }
        """
        self._check(model, onnx.TensorProto.INT32, onnx.TensorProto.FLOAT)

    def test_mi_overloaded_function(self):
        """Test use of overloaded functions (same name, distinct overload ids)."""
        model = """
            <ir_version: 10, opset_import: [ "" : 17, "local" : 1]>
            agraph (float[N] x) => (y, z)
            {
                y = local.cast:to_int32 (x)
                z = local.cast:to_int64 (x)
            }

            <opset_import: [ "" : 17 ], domain: "local", overload: "to_int32">
            cast (x) => (y)
            {
                y = Cast<to=6> (x)
            }

            <opset_import: [ "" : 17 ], domain: "local", overload: "to_int64">
            cast (x) => (y)
            {
                y = Cast<to=7> (x)
            }
        """
        self._check(model, onnx.TensorProto.INT32, onnx.TensorProto.INT64)
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner.
    unittest.main()