# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import List, Optional

import onnx.shape_inference
from onnx import ModelProto, TensorProto, TensorShapeProto, ValueInfoProto, helper
from onnx.helper import make_model, make_tensor_value_info


class TestSymbolicShape(unittest.TestCase):
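    """Tests that onnx.shape_inference propagates symbolic (dim_param) dimensions."""
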
    def _assert_valueinfo_shape(
        self, onnx_model: ModelProto, value_infos: List[ValueInfoProto]
    ) -> None:
"""Assert onnx_model.value_info should be the same as expected value_infos | |
Instead of exact symbol, use -1 to represent symbolic shape in expected value_infos | |
""" | |
        for expected_vi in value_infos:
            shape = self._get_shape_from_name(onnx_model, expected_vi.name)
            assert shape is not None, f"{onnx_model}"
            if expected_vi.type.HasField("tensor_type"):
                expected_shape = expected_vi.type.tensor_type.shape
            elif expected_vi.type.HasField("sparse_tensor_type"):
                expected_shape = expected_vi.type.sparse_tensor_type.shape
            assert len(shape.dim) == len(expected_shape.dim), f"{onnx_model}"
            for dim_i, dim in enumerate(shape.dim):
                expected_dim = expected_shape.dim[dim_i]
                # -1 means it's a symbolic shape
                if expected_dim.dim_value == -1:
                    # symbolic dimension must exist
                    assert dim.dim_param, f"{onnx_model}"
                else:
                    assert dim.dim_value == expected_dim.dim_value, f"{onnx_model}"

    def _count_unique_dim_param_number(self, onnx_model: ModelProto) -> int:
        """Return the number of unique symbolic dimensions (dim_param values) in the model."""
        symbol_shape_set = set()
        inputs = list(onnx_model.graph.input)
        outputs = list(onnx_model.graph.output)
        valueinfos = list(onnx_model.graph.value_info)
        for v in inputs + outputs + valueinfos:
            for dim in v.type.tensor_type.shape.dim:
                if dim.dim_param:
                    symbol_shape_set.add(dim.dim_param)
        return len(symbol_shape_set)

    def _get_shape_from_name(
        self, onnx_model: ModelProto, name: str
    ) -> Optional[TensorShapeProto]:
        """Get the shape of the value with the given name, from its tensor_type or sparse_tensor_type."""
        inputs = list(onnx_model.graph.input)
        outputs = list(onnx_model.graph.output)
        valueinfos = list(onnx_model.graph.value_info)
        for v in inputs + outputs + valueinfos:
            if v.name == name:
                if v.type.HasField("tensor_type"):
                    return v.type.tensor_type.shape  # type: ignore
                if v.type.HasField("sparse_tensor_type"):
                    return v.type.sparse_tensor_type.shape  # type: ignore
        return None

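    # Concat along axis=1 where one input has a symbolic dim: the inferred
    # intermediate C gets a symbolic second dim, which should match the output's.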
    def test_concat_enable_symbolic(self) -> None:
        concat = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        cast = onnx.helper.make_node(
            "Cast", inputs=["C"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "A"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, None])
            ],
        )
        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (2, -1))]
        )
        # the symbolic shape of C and output should be the same
        assert self._get_shape_from_name(
            inferred_model, "C"
        ) == self._get_shape_from_name(inferred_model, "output")

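    # Two chained Concat nodes: both intermediates C and E get a symbolic
    # second dim, and E's shape should match the output's.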
    def test_two_symbolic_concat(self) -> None:
        concat1 = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        concat2 = helper.make_node(
            "Concat", inputs=["C", "D"], outputs=["E"], name="Concat", axis=1
        )
        cast = onnx.helper.make_node(
            "Cast", inputs=["E"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat1, concat2, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "A"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
                helper.make_tensor_value_info("D", TensorProto.FLOAT, [2, "D"]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, None])
            ],
        )
        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model,
            [
                make_tensor_value_info("C", TensorProto.FLOAT, (2, -1)),
                make_tensor_value_info("E", TensorProto.FLOAT, (2, -1)),
            ],
        )
        # the symbolic shape of E and output should be the same
        assert self._get_shape_from_name(
            inferred_model, "E"
        ) == self._get_shape_from_name(inferred_model, "output")

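    # The graph already uses the symbols unk__0 and unk__1, so shape inference
    # must generate fresh symbols rather than reuse the existing ones.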
    def test_duplicate_symbolic_shape(self) -> None:
        concat1 = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        concat2 = helper.make_node(
            "Concat", inputs=["C", "D"], outputs=["E"], name="Concat", axis=1
        )
        cast = onnx.helper.make_node(
            "Cast", inputs=["E"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat1, concat2, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "unk__0"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
                helper.make_tensor_value_info("D", TensorProto.FLOAT, [2, "unk__1"]),
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, [2, "unk__0"]
                )
            ],
        )
        onnx_model = make_model(graph_def)
        original_count = self._count_unique_dim_param_number(onnx_model)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        inferred_count = self._count_unique_dim_param_number(inferred_model)
        # To avoid duplicates, shape inference should generate the new symbols
        # 'unk__2' and 'unk__3', so the inferred count is the original count + 2
        # original: {'unk__0', 'unk__1'}
        # inferred: {'unk__0', 'unk__1', 'unk__2', 'unk__3'}
        assert inferred_count == original_count + 2, f"{inferred_model}{onnx_model}"

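    # Both inputs have an unknown (None) second dim: inference should assign
    # matching symbolic dims to C and the output.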
    def test_unknown_shape(self) -> None:
        concat = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        cast = onnx.helper.make_node(
            "Cast", inputs=["C"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat, cast],
            inputs=[
                helper.make_tensor_value_info(
                    "A", TensorProto.FLOAT, [3, None]
                ),  # unknown shape
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, None]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, None])
            ],
        )
        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (3, -1))]
        )
        # the symbolic shape of C and output should be the same
        # ('unk__0', 'unk__1')
        assert self._get_shape_from_name(
            inferred_model, "C"
        ) == self._get_shape_from_name(inferred_model, "output")


if __name__ == "__main__":
    unittest.main()