File size: 4,445 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Sequence

from shape_inference_test import TestShapeInferenceHelper

import onnx
import onnx.helper
import onnx.parser
import onnx.shape_inference
from onnx import AttributeProto, TypeProto

# Tensor types with an unspecified shape (shape=None), shared by the test
# cases below. Element types are spelled with the TensorProto enum members
# rather than their raw integer values (FLOAT=1, UINT8=2, INT8=3, INT32=6,
# FLOAT16=10) so each line is self-describing.
float_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
uint8_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.UINT8, None)
int8_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT8, None)
int32_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, None)
float16_type_ = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT16, None)
# A default-constructed TypeProto is used to stand in for a missing optional
# input in the tests below.
no_type_ = TypeProto()


class TestFunctionInference(TestShapeInferenceHelper):
    """Tests for ``onnx.shape_inference.infer_function_output_types``.

    Each test parses a function from its textual form and checks the output
    types inferred for a given list of input types and attribute values.
    """

    def _check(
        self,
        function_text: str,
        input_types: Sequence[TypeProto],
        attributes: Sequence[AttributeProto],
        expected_output_types: Sequence[TypeProto],
    ) -> None:
        """Assert that inference on ``function_text`` yields the expected types.

        Args:
            function_text: Textual definition of the function, as accepted by
                ``onnx.parser.parse_function``.
            input_types: Types supplied for the function's inputs.
            attributes: Attribute values supplied for the function call.
            expected_output_types: Types the inferred outputs are compared
                against, position by position.
        """
        fn_proto = onnx.parser.parse_function(function_text)
        inferred_types = onnx.shape_inference.infer_function_output_types(
            fn_proto, input_types, attributes
        )
        # Check the count first: zip() below would silently drop extras.
        self.assertEqual(len(expected_output_types), len(inferred_types))
        for want, got in zip(expected_output_types, inferred_types):
            self._compare_value_infos(want, got)

    def _check_fails(
        self,
        function_text: str,
        input_types: Sequence[TypeProto],
        attributes: Sequence[AttributeProto],
    ) -> None:
        """Assert that inference on ``function_text`` raises ``InferenceError``.

        Args:
            function_text: Textual definition of the function, as accepted by
                ``onnx.parser.parse_function``.
            input_types: Types supplied for the function's inputs.
            attributes: Attribute values supplied for the function call.
        """
        # Parsing happens outside the assertRaises scope: only the inference
        # call itself is expected to fail.
        fn_proto = onnx.parser.parse_function(function_text)
        with self.assertRaises(onnx.shape_inference.InferenceError):
            onnx.shape_inference.infer_function_output_types(
                fn_proto, input_types, attributes
            )

    def test_fi_basic(self):
        """Matching input types propagate; mismatched input types fail."""
        source = """

            <opset_import: [ "" : 18 ], domain: "local">

            f (y, z) => (w) {

                x = Add(y, z)

                w = Mul(x, y)

            }

        """
        for elem_type in (float_type_, int32_type_):
            self._check(source, [elem_type, elem_type], [], [elem_type])
        self._check_fails(source, [float_type_, int32_type_], [])

    def test_fi_attribute(self):
        """The ``dtype`` attribute forwarded to Cast selects the output type."""
        source = """

            <opset_import: [ "" : 18 ], domain: "local">

            CastTo <dtype> (x) => (y) {

                y = Cast <to : int = @dtype> (x)

            }

        """
        # (dtype attribute value, expected output type) pairs.
        cases = (
            (6, int32_type_),
            (10, float16_type_),
        )
        for dtype_value, expected_type in cases:
            dtype_attr = onnx.helper.make_attribute("dtype", dtype_value)
            self._check(source, [float_type_], [dtype_attr], [expected_type])

    def test_fi_optional_input(self):
        """Missing optional inputs, trailing or not, are handled by inference."""
        reduce_source = """

            <opset_import: [ "" : 18 ], domain: "local">

            DoReduce (x, axes) => (y) {

                y = ReduceMax (x, axes)

            }

        """
        # We can omit the type for a missing trailing optional parameter
        self._check(reduce_source, [float_type_], [], [float_type_])
        # Or, we can pass in a default-value of TypeProto() for a missing optional parameter
        self._check(reduce_source, [float_type_, no_type_], [], [float_type_])

        quantize_source = """

            <opset_import: [ "" : 18 ], domain: "local">

            Quantize (x, scale, zero_point) => (y) {

                y = QuantizeLinear (x, scale, zero_point)

            }

        """
        # If the optional third parameter is specified, it determines the output type.
        for zero_point_type in (int8_type_, uint8_type_):
            self._check(
                quantize_source,
                [float_type_, float_type_, zero_point_type],
                [],
                [zero_point_type],
            )
        # If the optional third parameter is omitted, the output type is uint8 (default).
        self._check(
            quantize_source, [float_type_, float_type_, no_type_], [], [uint8_type_]
        )

        clip_source = """

            <opset_import: [ "" : 18 ], domain: "local">

            DoClip (x, min, max) => (y) {

                y = Clip (x, min, max)

            }

        """
        # A test-case with a non-trailing missing optional parameter
        self._check(
            clip_source, [float_type_, no_type_, float_type_], [], [float_type_]
        )

        # A failing test-case with a non-trailing missing optional parameter
        self._check_fails(clip_source, [float_type_, no_type_, int8_type_], [])


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()