# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import unittest

# TODO: remove the following ignore after mypy upgrade in ONNX
from shape_inference_test import TestShapeInferenceHelper  # type: ignore

import onnx.parser
from onnx import TensorProto
from onnx.helper import make_node, make_tensor, make_tensor_value_info


class TestDataPropagation(TestShapeInferenceHelper):
    def test_expand_symbolic_input(self) -> None:
        """Expand's runtime shape input is resolved by propagating the value of Shape(y)."""
        graph = self._make_graph(
            [("x", TensorProto.INT32, (3, 1, 2)), ("y", TensorProto.INT32, (1, 4, 2))],
            [
                make_node("Shape", ["y"], ["shape"]),
                make_node("Expand", ["x", "shape"], ["z"]),
            ],
            [],
        )
        self._assert_inferred(
            graph,
            [
                make_tensor_value_info("shape", TensorProto.INT64, (3,)),
                make_tensor_value_info("z", TensorProto.INT32, (3, 4, 2)),
            ],
            data_prop=True,
        )

    def test_constantofshape_with_symbolic_shape(self) -> None:
        """ConstantOfShape takes its output shape from the propagated value of Shape(x)."""
        graph = self._make_graph(
            [("x", TensorProto.FLOAT, (3, 4, 5))],
            [
                make_node("Shape", ["x"], ["shape"]),
                make_node(
                    "ConstantOfShape",
                    ["shape"],
                    ["y"],
                    value=make_tensor("value", TensorProto.INT32, (1,), (2,)),
                ),
            ],
            [],
        )
        self._assert_inferred(
            graph,
            [
                make_tensor_value_info("shape", TensorProto.INT64, (3,)),
                make_tensor_value_info("y", TensorProto.INT32, (3, 4, 5)),
            ],
            data_prop=True,
        )

    def test_model_data_propagation(self) -> None:
        """Infer the shape of z by propagating the value of xshape."""
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[4, 1, 16] x, float[1, 8, 16] y) => () {
                xshape = Shape (x)
                z = Expand (y, xshape)
            }
            """
        )
        self._assert_inferred(
            model,
            [
                make_tensor_value_info("xshape", TensorProto.INT64, (3,)),
                make_tensor_value_info("z", TensorProto.FLOAT, (4, 8, 16)),
            ],
            data_prop=True,
        )
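
    # Illustrative sketch, not part of the original suite: the same propagation
    # exercised by test_model_data_propagation can be requested through the
    # public API by passing data_prop=True to onnx.shape_inference.infer_shapes.
    # The test name and the local import below are assumptions made for this sketch.
    def test_infer_shapes_api_data_prop_sketch(self) -> None:
        """Hedged example: call infer_shapes(..., data_prop=True) directly."""
        from onnx import shape_inference  # local import to keep the sketch self-contained

        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[4, 1, 16] x, float[1, 8, 16] y) => () {
                xshape = Shape (x)
                z = Expand (y, xshape)
            }
            """
        )
        inferred = shape_inference.infer_shapes(model, data_prop=True)
        # With data propagation enabled the value of xshape is known, so the
        # shape of z = Expand(y, xshape) resolves to (4, 8, 16).
        z_info = next(vi for vi in inferred.graph.value_info if vi.name == "z")
        dims = [d.dim_value for d in z_info.type.tensor_type.shape.dim]
        self.assertEqual(dims, [4, 8, 16])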

    def test_data_prop_via_function(self) -> None:
        """Test value propagation through function calls.

        The underlying example is the same as in test_model_data_propagation above.
        """
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18, "local" : 1 ]>
            agraph (float[4, 1, 16] x, float[1, 8, 16] y) => () {
                xshape = local.GetShape (x)
                z = Expand (y, xshape)
            }
            <domain: "local", opset_import: [ "" : 18 ]>
            GetShape (x) => (shapeval) {
                shapeval = Shape(x)
            }
            """
        )
        self._assert_inferred(
            model,
            [
                make_tensor_value_info("xshape", TensorProto.INT64, (3,)),
                make_tensor_value_info("z", TensorProto.FLOAT, (4, 8, 16)),
            ],
            data_prop=True,
        )

    def test_multiple_calls_to_function(self) -> None:
        """Test that value propagation handles multiple calls to the same function correctly.

        The underlying example is the same as in test_model_data_propagation above.
        """
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18, "local" : 1 ]>
            agraph (float[4, 1, 16] x, float[1, 8, 16] y) => () {
                yshape = local.GetShape (y)
                xshape = local.GetShape (x)
                z = Expand (y, xshape)
                w = Expand (y, yshape)
            }
            <domain: "local", opset_import: [ "" : 18 ]>
            GetShape (x) => (shapeval) {
                shapeval = Shape(x)
            }
            """
        )
        self._assert_inferred(
            model,
            [
                make_tensor_value_info("yshape", TensorProto.INT64, (3,)),
                make_tensor_value_info("xshape", TensorProto.INT64, (3,)),
                make_tensor_value_info("z", TensorProto.FLOAT, (4, 8, 16)),
                make_tensor_value_info("w", TensorProto.FLOAT, (1, 8, 16)),
            ],
            data_prop=True,
        )


if __name__ == "__main__":
    unittest.main()