File size: 1,886 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Sequence

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


class Concat(Base):
    @staticmethod
    def export() -> None:
        """Generate backend test cases for the ``Concat`` operator.

        For each entry in ``test_cases`` (1-D, 2-D and 3-D inputs), the listed
        nested lists are converted to float32 arrays and concatenated along
        every axis of the first input's rank.

        The non-negative axes ``0 .. rank-1`` are emitted first, then the
        negative axes ``-rank .. -1``. The reference output comes from
        ``np.concatenate``, and each case is registered through ``expect`` as
        ``test_concat_<case>_axis_<i>`` or
        ``test_concat_<case>_axis_negative_<abs(i)>``.
        """
        test_cases: Dict[str, Sequence[Any]] = {
            "1d": ([1, 2], [3, 4]),
            "2d": ([[1, 2], [3, 4]], [[5, 6], [7, 8]]),
            "3d": (
                [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                [[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
            ),
        }

        for test_case, values_ in test_cases.items():
            values = [np.asarray(v, dtype=np.float32) for v in values_]
            # Graph input names depend only on the number of inputs, not on
            # the axis, so build them once per test case.
            in_args = ["value" + str(k) for k in range(len(values))]
            rank = len(values[0].shape)

            # Non-negative axes first, then their negative equivalents; this
            # preserves the order in which test cases were originally emitted.
            for i in [*range(rank), *range(-rank, 0)]:
                if i >= 0:
                    suffix = "_axis_" + str(i)
                else:
                    suffix = "_axis_negative_" + str(abs(i))
                node = onnx.helper.make_node(
                    "Concat", inputs=in_args, outputs=["output"], axis=i
                )
                output = np.concatenate(values, i)
                expect(
                    node,
                    inputs=list(values),
                    outputs=[output],
                    name="test_concat_" + test_case + suffix,
                )