Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
2.81 kB
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect
class Compress(Base):
    """Test cases for the ONNX ``Compress`` operator.

    Each ``export_*`` method builds a ``Compress`` node on a fixed 3x2
    float32 tensor, computes the reference result with ``np.compress``, and
    passes the node, its inputs and the reference output to ``expect``.

    The condition is built directly as a boolean array (the dtype that is fed
    to the node), so the reference result is computed from exactly the tensor
    the node receives.
    """

    @staticmethod
    def export_compress_0() -> None:
        """Keep the rows (axis=0) whose condition entry is True."""
        node = onnx.helper.make_node(
            "Compress",
            inputs=["input", "condition"],
            outputs=["output"],
            axis=0,
        )
        data = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
        condition = np.array([0, 1, 1], dtype=bool)
        output = np.compress(condition, data, axis=0)
        # Expected output:
        # [[3. 4.]
        #  [5. 6.]]
        expect(
            node,
            inputs=[data, condition],
            outputs=[output],
            name="test_compress_0",
        )

    @staticmethod
    def export_compress_1() -> None:
        """Keep the columns (axis=1) whose condition entry is True."""
        node = onnx.helper.make_node(
            "Compress",
            inputs=["input", "condition"],
            outputs=["output"],
            axis=1,
        )
        data = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
        condition = np.array([0, 1], dtype=bool)
        output = np.compress(condition, data, axis=1)
        # Expected output:
        # [[2.]
        #  [4.]
        #  [6.]]
        expect(
            node,
            inputs=[data, condition],
            outputs=[output],
            name="test_compress_1",
        )

    @staticmethod
    def export_compress_default_axis() -> None:
        """Select from the flattened input when no ``axis`` attribute is set."""
        node = onnx.helper.make_node(
            "Compress",
            inputs=["input", "condition"],
            outputs=["output"],
        )
        data = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
        # Without ``axis``, np.compress works on the flattened array
        # [1, 2, 3, 4, 5, 6]; the condition may be shorter than that.
        condition = np.array([0, 1, 0, 0, 1], dtype=bool)
        output = np.compress(condition, data)
        # Expected output:
        # [2. 5.]
        expect(
            node,
            inputs=[data, condition],
            outputs=[output],
            name="test_compress_default_axis",
        )

    @staticmethod
    def export_compress_negative_axis() -> None:
        """Use a negative axis (-1, the last axis) to select columns."""
        node = onnx.helper.make_node(
            "Compress",
            inputs=["input", "condition"],
            outputs=["output"],
            axis=-1,
        )
        data = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
        condition = np.array([0, 1], dtype=bool)
        output = np.compress(condition, data, axis=-1)
        # Expected output (same as axis=1 for this 2-D input):
        # [[2.]
        #  [4.]
        #  [6.]]
        expect(
            node,
            inputs=[data, condition],
            outputs=[output],
            name="test_compress_negative_axis",
        )