File size: 3,314 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

from typing import Sequence

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.model import expect


class ExpandDynamicShape(Base):
    """Backend test cases for ``Expand`` whose target shape is a graph input.

    The ``shape`` tensor is supplied as a model input rather than a constant.
    Each case therefore builds its own model, whose declared output shape
    matches the broadcast result for that particular ``shape`` value.
    """

    @staticmethod
    def export() -> None:
        def make_graph(
            node: onnx.NodeProto,
            input_shape: Sequence[int],
            shape_shape: Sequence[int],
            output_shape: Sequence[int],
        ) -> onnx.GraphProto:
            """Wrap ``node`` in a graph with inputs ``X``/``shape`` and output ``Y``.

            Args:
                node: The ``Expand`` node consuming ``X`` and ``shape``.
                input_shape: Shape declared for the float input ``X``.
                shape_shape: Shape declared for the int64 ``shape`` input
                    (a 1-D tensor holding the target shape).
                output_shape: Shape declared for the float output ``Y``.

            Returns:
                The assembled graph, named ``"Expand"``.
            """
            return onnx.helper.make_graph(
                nodes=[node],
                name="Expand",
                inputs=[
                    onnx.helper.make_tensor_value_info(
                        "X", onnx.TensorProto.FLOAT, input_shape
                    ),
                    onnx.helper.make_tensor_value_info(
                        "shape", onnx.TensorProto.INT64, shape_shape
                    ),
                ],
                outputs=[
                    onnx.helper.make_tensor_value_info(
                        "Y", onnx.TensorProto.FLOAT, output_shape
                    )
                ],
            )

        node = onnx.helper.make_node("Expand", ["X", "shape"], ["Y"], name="test")
        input_shape = [1, 3, 1]
        x = np.ones(input_shape, dtype=np.float32)

        # Target shapes broadcast against X's [1, 3, 1]:
        # - cases 1-2 have a lower rank than X,
        # - case 3 has the same rank,
        # - case 4 has a higher rank.
        # The 1-based position of each entry becomes the test-name suffix.
        target_shapes = (
            [3, 1],
            [1, 3],
            [3, 1, 3],
            [3, 3, 1, 3],
        )
        for case_number, target in enumerate(target_shapes, start=1):
            shape = np.array(target, dtype=np.int64)
            # Expand broadcasts like numpy, so multiplying by ones of the
            # target shape produces the reference output.
            y = x * np.ones(shape, dtype=np.float32)
            graph = make_graph(node, input_shape, shape.shape, y.shape)
            model = onnx.helper.make_model_gen_version(
                graph,
                producer_name="backend-test",
                opset_imports=[onnx.helper.make_opsetid("", 9)],
            )
            expect(
                model,
                inputs=[x, shape],
                outputs=[y],
                name=f"test_expand_shape_model{case_number}",
            )