// Copyright (c) ONNX Project Contributors
/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include <string>
#include <unordered_map>

#include "gtest/gtest.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "onnx/shape_inference/implementation.h"

using namespace ONNX_NAMESPACE::shape_inference;

namespace ONNX_NAMESPACE {
namespace Test {
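
// Compares an inferred shape against an expected shape dim by dim. Each dim
// must agree on whether it carries a concrete dim_value or a symbolic
// dim_param; dim_param names themselves are only compared when
// checkSameParam is true.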
inline bool CompareShape(
    const TensorShapeProto& inferredShape,
    const TensorShapeProto& expectedShape,
    bool checkSameParam = false) {
  EXPECT_TRUE(inferredShape.dim_size() == expectedShape.dim_size())
      << "Dim size for inferred and expected shape is different.";
  for (int i = 0; i < inferredShape.dim_size(); i++) {
    EXPECT_TRUE(
        (inferredShape.dim(i).has_dim_value() == expectedShape.dim(i).has_dim_value()) &&
        (inferredShape.dim(i).has_dim_param() == expectedShape.dim(i).has_dim_param()))
        << "Inferred and expected dim values are different.";
    EXPECT_TRUE(
        inferredShape.dim(i).has_dim_value()
            ? inferredShape.dim(i).dim_value() == expectedShape.dim(i).dim_value()
            : (checkSameParam ? inferredShape.dim(i).dim_param() == expectedShape.dim(i).dim_param() : true))
        << "Inferred and expected dims are different.";
  }
  return true;
}
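
// Parses the textual graph, runs each op's registered data propagation
// function in node order, and returns the propagated shape data recorded for
// the graph's single output.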
TensorShapeProto RunDataPropagation(const char* graphCode, int domainVersion = 15) {
  // Parses the graph from graphCode
  GraphProto graph;
  OnnxParser parser(graphCode);
  auto status = parser.Parse(graph);
  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
  EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
  // Constructs name to TypeProto map from value_info, input, output
  std::unordered_map<std::string, TypeProto*> valueTypesByName;
  for (auto& vi : *graph.mutable_value_info()) {
    if (vi.has_type()) {
      valueTypesByName[vi.name()] = vi.mutable_type();
    }
  }
  for (auto& vi : *graph.mutable_input()) {
    if (vi.has_type()) {
      valueTypesByName[vi.name()] = vi.mutable_type();
    }
  }
  for (auto& vi : *graph.mutable_output()) {
    if (vi.has_type()) {
      valueTypesByName[vi.name()] = vi.mutable_type();
    }
  }
  // Constructs name to TensorProto map from initializer
  std::unordered_map<std::string, const TensorProto*> inputDataByName;
  for (const auto& tp : graph.initializer()) {
    inputDataByName[tp.name()] = &tp;
  }
  // Collects data from constant nodes
  for (const auto& n : graph.node()) {
    if (n.op_type() != "Constant" || n.output().size() != 1) {
      continue;
    }
    for (const auto& attr : n.attribute()) {
      if (attr.name() == "value") {
        if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
          inputDataByName[n.output(0)] = &attr.t();
        }
      }
    }
  }
  // Runs data propagation on each node
  std::unordered_map<std::string, TensorShapeProto> generatedShapeDataByName;
  auto* schemaRegistry = OpSchemaRegistry::Instance();
  TensorShapeProto inferredShape;
  for (auto n : graph.node()) {
    // No need to run data propagation on Constant
    if (n.op_type() == "Constant") {
      continue;
    }
    DataPropagationContextImpl dataPropagationCtx(n, valueTypesByName, inputDataByName, generatedShapeDataByName);
    const auto schema = schemaRegistry->GetSchema(n.op_type(), domainVersion, n.domain());
    EXPECT_TRUE(schema->has_data_propagation_function());
    schema->GetDataPropagationFunction()(dataPropagationCtx);
  }
  // Assuming the graph being tested only has 1 output.
  // If this ever changes then fixes are required here.
  const auto outputShapeDataIter = generatedShapeDataByName.find(graph.output(0).name());
  EXPECT_TRUE(outputShapeDataIter != generatedShapeDataByName.cend());
  inferredShape.CopyFrom(outputShapeDataIter->second);
  // Returns the partial shape data for output
  return inferredShape;
}
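
// Note: <to = 7> in Cast selects INT64 (TensorProto_DataType_INT64 == 7).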

TEST(DataPropagationImplTest, ShapeTest) {
  const char* code = R"ONNX(
agraph (int32[7,4,1] x) => (int32[3] y)
{
    xs = Shape(x)
    y = Cast<to = 7>(xs)
}
)ONNX";
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(7);
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  expected_tsp.mutable_dim()->Add()->set_dim_value(1);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SymbolicShapeTest) {
  const char* code = R"ONNX(
agraph (int32[N,3,256,256] x) => (int32[4] y)
{
    xs = Shape(x)
    y = Cast<to = 7>(xs)
}
)ONNX";
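  // The symbolic batch dimension N is propagated as a dim_param, so
  // CompareShape is asked to check dim_param names as well.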
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_param("N");
  expected_tsp.mutable_dim()->Add()->set_dim_value(3);
  expected_tsp.mutable_dim()->Add()->set_dim_value(256);
  expected_tsp.mutable_dim()->Add()->set_dim_value(256);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp, true));
}

TEST(DataPropagationImplTest, CastTest) {
  const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[2] y)
{
    xs = Shape(x)
    y = Cast<to = 7>(xs)
}
)ONNX";
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SqueezeTest) {
  const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[2] z)
{
    xs = Shape(x)
    y = Squeeze(xs)
    z = Cast<to = 7>(y)
}
)ONNX";
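  // Squeeze does not change the element values, so the propagated shape data stays {2,5}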
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, UnsqueezeTest) {
  const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[1,2] w)
{
    xs = Shape(x)
    axis = Constant<value = int64[1] {1}>()
    z = Unsqueeze(xs, axis)
    w = Cast<to = 7>(z)
}
)ONNX";
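  // Unsqueeze only inserts a size-1 axis; the propagated element values remain {2,5}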
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SizeTest) {
  const char* code = R"ONNX(
agraph (int64[1] x) => (int32[1] w)
<int64[3] init = {2,3,5}>
{
    z = Size(init)
    w = Cast<to = 7>(z)
}
)ONNX";
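  // Size(init) is the total element count of the 3-element initializer, i.e. 3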
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(3);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, AddTest) {
  const char* code = R"ONNX(
agraph (int32[2,4,5] x, int32[2,4,5] y) => (int32[3] w)
{
    xs = Shape(x)
    ys = Shape(y)
    z = Add(xs, ys)
    w = Cast<to = 7>(z)
}
)ONNX";
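  // Add({2,4,5}, {2,4,5}) = {4,8,10}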
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  expected_tsp.mutable_dim()->Add()->set_dim_value(8);
  expected_tsp.mutable_dim()->Add()->set_dim_value(10);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, AddSymbolicShapeTest) {
  const char* code = R"ONNX(
agraph (int32[2,4,5] x, int32[2,4,M] y) => (int32[3] w)
{
    xs = Shape(x)
    ys = Shape(y)
    z = Add(xs, ys)
    w = Cast<to = 7>(z)
}
)ONNX";
  // Add({2,4,5}, {2,4,M}) = {4,8,?}
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  expected_tsp.mutable_dim()->Add()->set_dim_value(8);
  // Not computable so do not set value or param
  expected_tsp.mutable_dim()->Add();
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SubTest) {
  const char* code = R"ONNX(
agraph (int32[10,11,6] x, int32[5] y) => (int32[3] w)
{
    xs = Shape(x)
    ys = Shape(y)
    z = Sub(xs, ys)
    w = Cast<to = 7>(z)
}
)ONNX";
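  // Sub({10,11,6}, {5}) broadcasts to {5,6,1}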
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  expected_tsp.mutable_dim()->Add()->set_dim_value(6);
  expected_tsp.mutable_dim()->Add()->set_dim_value(1);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, MulTest) {
  const char* code = R"ONNX(
agraph (int32[2] x, int32[5,1,7] y) => (int32[3] w)
{
    xs = Shape(x)
    ys = Shape(y)
    z = Mul(xs, ys)
    w = Cast<to = 7>(z)
}
)ONNX";
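  // Mul({2}, {5,1,7}) broadcasts to {10,2,14}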
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(10);
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(14);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, ConcatTest) {
  const char* code = R"ONNX(
agraph (int32[1,2] x, int32[3,4] y) => (int32[4] w)
{
    xs = Shape(x)
    ys = Shape(y)
    z = Concat<axis = 0>(xs, ys)
    w = Cast<to = 7>(z)
}
)ONNX";
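  // Concat along axis 0 joins the two shape tensors: {1,2} and {3,4} give {1,2,3,4}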
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(1);
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(3);
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, GatherTest) {
  const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6] x) => (int32[3] w)
{
    xs = Shape(x)
    indices = Constant<value = int64[3] {0,3,5}>()
    z = Gather<axis = 0>(xs, indices)
    w = Cast<to = 7>(z)
}
)ONNX";
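  // Gather with indices {0,3,5} picks those entries of {1,2,3,4,5,6}: {1,4,6}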
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(1);
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  expected_tsp.mutable_dim()->Add()->set_dim_value(6);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, GatherNegativeIndicesTest) {
  const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6] x) => (int32[2] w)
{
    xs = Shape(x)
    indices = Constant<value = int64[2] {-2,-1}>()
    z = Gather<axis = 0>(xs, indices)
    w = Cast<to = 7>(z)
}
)ONNX";
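  // Negative indices count from the end: {-2,-1} of {1,2,3,4,5,6} select {5,6}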
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  expected_tsp.mutable_dim()->Add()->set_dim_value(6);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SliceTest) {
  const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[2] w)
{
    xs = Shape(x)
    starts = Constant<value = int64[1] {1}>()
    ends = Constant<value = int64[1] {7}>()
    axes = Constant<value = int64[1] {0}>()
    steps = Constant<value = int64[1] {3}>()
    z = Slice(xs, starts, ends, axes, steps)
    w = Cast<to = 7>(z)
}
)ONNX";
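  // start=1, end=7, step=3 selects indices 1 and 4 of {1,...,8}, i.e. {2,5}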
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(2);
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SliceDefaultAxesAndStepTest) {
  const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[3] w)
{
    xs = Shape(x)
    starts = Constant<value = int64[1] {2}>()
    ends = Constant<value = int64[1] {5}>()
    z = Slice(xs, starts, ends)
    w = Cast<to = 7>(z)
}
)ONNX";
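  // With axes and steps omitted, Slice uses axis 0 and step 1: elements 2..4 of {1,...,8} are {3,4,5}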
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(3);
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  expected_tsp.mutable_dim()->Add()->set_dim_value(5);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

TEST(DataPropagationImplTest, SliceNegativeStartEndStepTest) {
  const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[3] w)
{
    xs = Shape(x)
    starts = Constant<value = int64[1] {-3}>()
    ends = Constant<value = int64[1] {-7}>()
    axes = Constant<value = int64[1] {0}>()
    steps = Constant<value = int64[1] {-2}>()
    z = Slice(xs, starts, ends, axes, steps)
    w = Cast<to = 7>(z)
}
)ONNX";
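  // start=-3 and end=-7 resolve to indices 5 and 1; step=-2 then selects indices 5 and 3, i.e. {6,4}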
  TensorShapeProto expected_tsp;
  expected_tsp.mutable_dim()->Add()->set_dim_value(6);
  expected_tsp.mutable_dim()->Add()->set_dim_value(4);
  const auto propagated_tsp = RunDataPropagation(code);
  EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}

} // namespace Test
} // namespace ONNX_NAMESPACE