// Copyright (c) ONNX Project Contributors /* * SPDX-License-Identifier: Apache-2.0 */ #include "gtest/gtest.h" #include "onnx/checker.h" #include "onnx/defs/parser.h" #include "onnx/defs/printer.h" using namespace ONNX_NAMESPACE; namespace ONNX_NAMESPACE { namespace Test { template static void Parse(T& parsedData, const char* input) { OnnxParser parser(input); auto status = parser.Parse(parsedData); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; // Extra checks for printer: // Check we can convert data back to text form. std::string text1 = ProtoToString(parsedData); // Check that we can round-trip between the two representations. // We cannot expect equality between text1 and input due to white-space and syntactic sugar, // so, we convert it once more, and check for equality. T temp; status = OnnxParser::Parse(temp, text1.c_str()); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); std::string text2 = ProtoToString(temp); EXPECT_EQ(text1, text2); } template static void ExpectParseFailure(T& parsedData, const char* input) { auto status = OnnxParser::Parse(parsedData, input); EXPECT_FALSE(status.IsOK()); } static void CheckModel(const char* code) { ModelProto model; Parse(model, code); checker::check_model(model); } TEST(ParserTest, EscapeStringLiteral) { OnnxParser parser(R"( "123\"56\\89" )"); std::string s; auto status = parser.ParserBase::Parse(s); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; EXPECT_EQ(s, std::string("123\"56\\89")); } TEST(ParserTest, TypeTest) { TypeProto type; // 1-dimensional tensor type with symbolic dimension: Parse(type, "float[N]"); EXPECT_TRUE(type.has_tensor_type()); int float_type = static_cast(TensorProto_DataType::TensorProto_DataType_FLOAT); int int32_type = static_cast(TensorProto_DataType::TensorProto_DataType_INT32); EXPECT_EQ(type.tensor_type().elem_type(), float_type); EXPECT_TRUE(type.tensor_type().has_shape()); EXPECT_EQ(type.tensor_type().shape().dim_size(), 1); EXPECT_EQ(type.tensor_type().shape().dim(0).dim_param(), "N"); // scalar type: Parse(type, "float"); EXPECT_TRUE(type.has_tensor_type()); EXPECT_EQ(type.tensor_type().elem_type(), float_type); EXPECT_TRUE(type.tensor_type().has_shape()); EXPECT_EQ(type.tensor_type().shape().dim_size(), 0); // tensor type with unknown rank: Parse(type, "float[]"); EXPECT_TRUE(type.has_tensor_type()); EXPECT_EQ(type.tensor_type().elem_type(), float_type); EXPECT_FALSE(type.tensor_type().has_shape()); // 3-dimensional tensor Parse(type, "float[N,M,K]"); EXPECT_EQ(type.tensor_type().shape().dim_size(), 3); // Unspecified dimension (neither symbolic nor constant) Parse(type, "float[N,?,K]"); EXPECT_FALSE(type.tensor_type().shape().dim(1).has_dim_param()); EXPECT_FALSE(type.tensor_type().shape().dim(1).has_dim_value()); // sequence type: Parse(type, "seq(float[])"); EXPECT_TRUE(type.has_sequence_type()); auto& elttype = type.sequence_type().elem_type(); EXPECT_TRUE(elttype.has_tensor_type()); EXPECT_EQ(elttype.tensor_type().elem_type(), float_type); EXPECT_FALSE(elttype.tensor_type().has_shape()); // optional type: Parse(type, "optional(float)"); EXPECT_TRUE(type.has_optional_type()); auto& optelttype = type.optional_type().elem_type(); EXPECT_TRUE(optelttype.has_tensor_type()); EXPECT_EQ(optelttype.tensor_type().elem_type(), float_type); EXPECT_TRUE(optelttype.tensor_type().has_shape()); // optional type: Parse(type, "sparse_tensor(float[1000])"); EXPECT_TRUE(type.has_sparse_tensor_type()); EXPECT_EQ(type.sparse_tensor_type().elem_type(), float_type); EXPECT_EQ(type.sparse_tensor_type().shape().dim_size(), 1); // map type: Parse(type, "map(int32, float[N])"); EXPECT_TRUE(type.has_map_type()); EXPECT_EQ(type.map_type().key_type(), int32_type); auto& valtype = type.map_type().value_type(); EXPECT_TRUE(valtype.has_tensor_type()); EXPECT_EQ(valtype.tensor_type().elem_type(), float_type); EXPECT_EQ(valtype.tensor_type().shape().dim_size(), 1); } TEST(ParserTest, TensorProtoTest) { TensorProto tensorProto; // Concrete tensor-type with numeric dimensions expected: ExpectParseFailure(tensorProto, "int32[] {1, 2, 3, 4, 5}"); // Symbolic dimensions are not allowed. ExpectParseFailure(tensorProto, "int32[N] {1, 2, 3, 4, 5}"); Parse(tensorProto, "int32[5] {1, 2, 3, 4, 5}"); Parse(tensorProto, "int32[5] T {1, 2, 3, 4, 5}"); EXPECT_EQ(tensorProto.name(), "T"); Parse(tensorProto, "float[5] {1, 2.0, 3.1, 4, 5.5}"); Parse(tensorProto, "float[5] {1e1, 2.0e-1, 3.1E-1, 4E+1, 5.5e-10}"); Parse(tensorProto, "string[2] { \"Hello\", \"World\" }"); // String literals with escape character Parse(tensorProto, R"( string[2] { "Use a \"quoted\" word", "Use a backslash \\ like this." } )"); } TEST(ParserTest, AttributeTest) { AttributeProto attr; Parse(attr, "x = 2"); EXPECT_EQ(attr.name(), "x"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INT); EXPECT_EQ(attr.i(), 2); Parse(attr, "x = 0.625"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT); EXPECT_FLOAT_EQ(attr.f(), 0.625); Parse(attr, "x = [2, 4, 6]"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); EXPECT_EQ(attr.ints_size(), 3); Parse(attr, "x = [0.125, 0.625]"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS); EXPECT_EQ(attr.floats_size(), 2); Parse(attr, "x = float[3] {2.1, 4.1, 6.1}"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR); Parse(attr, "x = \"astring\""); EXPECT_EQ(attr.name(), "x"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); EXPECT_EQ(attr.s(), "astring"); Parse(attr, "x = [\"abc\", \"def\"]"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS); Parse(attr, "x : ints = @xyz"); EXPECT_EQ(attr.ref_attr_name(), "xyz"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); Parse(attr, "x : ints = []"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); EXPECT_EQ(attr.ints_size(), 0); Parse(attr, R"ONNX( body = somegraph (float[N] y, float[N] z) => (float[N] w) { x = foo(y, z) w = bar(x, y) } )ONNX"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH); EXPECT_EQ(attr.g().node_size(), 2); Parse(attr, "type = float[3]"); EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TYPE_PROTO); EXPECT_TRUE(attr.tp().has_tensor_type()); int float_type = static_cast(TensorProto_DataType::TensorProto_DataType_FLOAT); EXPECT_EQ(attr.tp().tensor_type().elem_type(), float_type); } TEST(ParserTest, AttrListTest) { const char* code = R"ONNX( < x = 2, w = 3 > )ONNX"; AttrList attributes; Parse(attributes, code); EXPECT_EQ(attributes.size(), 2); EXPECT_EQ(attributes.Get(0).name(), "x"); EXPECT_EQ(attributes.Get(1).name(), "w"); } TEST(ParserTest, DomainOpCallTest) { const char* code = "x = somedomain.foo(y, z)"; NodeProto n; Parse(n, code); } TEST(ParserTest, NodeTest) { const char* code = "x = foo(y, z)"; NodeProto n; Parse(n, code); EXPECT_EQ(n.input_size(), 2); EXPECT_EQ(n.input(0), "y"); EXPECT_EQ(n.input(1), "z"); EXPECT_EQ(n.output_size(), 1); EXPECT_EQ(n.output(0), "x"); EXPECT_EQ(n.op_type(), "foo"); NodeList nl; Parse(nl, R"ONNX( { sub_result = Sub(limit, start) sub_result_casted = Cast(sub_result) delta_casted = Cast(delta) div_result = Div(sub_result_casted, delta_casted) ceil_result = Ceil(div_result) ceil_result_relu = Relu(ceil_result) ceil_result_relu_int = Cast(ceil_result_relu) ceil_result_relu_bool = Cast(ceil_result_relu) variadic_output, output = Loop (ceil_result_relu_int, ceil_result_relu_bool, start) } )ONNX"); } TEST(ParserTest, QualifiedOpNameTest) { const char* code = "x = com.example.foo(y, z)"; NodeProto n; Parse(n, code); EXPECT_EQ(n.domain(), "com.example"); EXPECT_EQ(n.op_type(), "foo"); } TEST(ParserTest, NodeListTest) { const char* code = R"ONNX( { x = foo(y, z) w = bar(x, y) } )ONNX"; GraphProto graph; Parse(*graph.mutable_node(), code); EXPECT_EQ(graph.node_size(), 2); EXPECT_EQ(graph.node(0).op_type(), "foo"); EXPECT_EQ(graph.node(1).op_type(), "bar"); } TEST(ParserTest, NodeAttrTest1) { const char* code = "x = foo (y, z)"; NodeProto n; Parse(n, code); EXPECT_EQ(n.attribute_size(), 3); EXPECT_EQ(n.attribute(0).name(), "a"); EXPECT_EQ(n.attribute(1).name(), "b"); EXPECT_EQ(n.attribute(2).name(), "c"); } TEST(ParserTest, NodeAttrTest2) { const char* code = "x = foo (y, z)"; NodeProto n; Parse(n, code); EXPECT_EQ(n.attribute_size(), 3); } TEST(ParserTest, GraphTest) { const char* code = R"ONNX( agraph (float[N] y, float[N] z) => (float[N] w) { # This is a comment. x = foo(y, z, w1) # More comments. w = bar(x, y, w2) } )ONNX"; GraphProto graph; Parse(graph, code); EXPECT_EQ(graph.name(), "agraph"); EXPECT_EQ(graph.input_size(), 2); EXPECT_EQ(graph.output_size(), 1); EXPECT_EQ(graph.node_size(), 2); EXPECT_EQ(graph.initializer_size(), 2); EXPECT_EQ(graph.value_info_size(), 1); } TEST(ParserTest, GraphPartialTypeTest) { const char* code = R"ONNX( agraph (float[N] y, z) => (float[N] w) { x = foo(y, z) w = bar(x, y) } )ONNX"; GraphProto graph; Parse(graph, code); EXPECT_EQ(graph.name(), "agraph"); EXPECT_EQ(graph.input_size(), 2); EXPECT_EQ(graph.output_size(), 1); } TEST(ParserTest, FunctionTest) { const char* code = R"ONNX( < opset_import: [ "" : 10 ], domain: "ai.onnx.ml", doc_string: "A function test case." > f (y, z) => (w) { x = Add(y, z) w = Mul(x, y) } )ONNX"; FunctionProto fp; Parse(fp, code); EXPECT_EQ(fp.name(), "f"); EXPECT_EQ(fp.input_size(), 2); EXPECT_EQ(fp.output_size(), 1); EXPECT_EQ(fp.node_size(), 2); EXPECT_EQ(fp.attribute_size(), 0); EXPECT_EQ(fp.opset_import_size(), 1); } TEST(ParserTest, FunctionValueInfoTest) { const char* code = R"ONNX( < opset_import: [ "" : 10 ], domain: "ai.onnx.ml", doc_string: "A function test case." > f (float[N] y, float[N] z) => (float[N] w) { x = Add(y, z) w = Mul(x, y) } )ONNX"; FunctionProto fp; Parse(fp, code); EXPECT_EQ(fp.input_size(), 2); EXPECT_EQ(fp.output_size(), 1); ASSERT_EQ(fp.value_info_size(), 3); EXPECT_EQ(fp.value_info(0).name(), "y"); EXPECT_EQ(fp.value_info(1).name(), "z"); EXPECT_EQ(fp.value_info(2).name(), "w"); } TEST(ParserTest, FunctionValueInfoTest2) { const char* code = R"ONNX( < opset_import: [ "" : 10 ], domain: "ai.onnx.ml", doc_string: "A function test case." > f (float[N] y, float[N] z) => (float[N] w) { x = Add(y, z) w = Mul(x, y) } )ONNX"; FunctionProto fp; Parse(fp, code); EXPECT_EQ(fp.input_size(), 2); EXPECT_EQ(fp.value_info_size(), 4); ASSERT_EQ(fp.output_size(), 1); EXPECT_EQ(fp.value_info(0).name(), "y"); EXPECT_EQ(fp.value_info(1).name(), "z"); EXPECT_EQ(fp.value_info(2).name(), "w"); EXPECT_EQ(fp.value_info(3).name(), "x"); } TEST(ParserTest, FunctionValueInfoTest3) { const char* code = R"ONNX( < opset_import: [ "" : 10 ], domain: "ai.onnx.ml", doc_string: "A function test case." > f (float[N] y, z) => (float[N] w) { x = Add(y, z) t = Add(x, x) w = Mul(t, y) } )ONNX"; FunctionProto fp; Parse(fp, code); EXPECT_EQ(fp.input_size(), 2); ASSERT_EQ(fp.value_info_size(), 4); EXPECT_EQ(fp.output_size(), 1); EXPECT_EQ(fp.value_info(0).name(), "y"); EXPECT_EQ(fp.value_info(1).name(), "w"); EXPECT_EQ(fp.value_info(2).name(), "x"); EXPECT_EQ(fp.value_info(3).name(), "t"); } TEST(ParserTest, InitializerTest) { const char* code = R"ONNX( agraph (float y = {1.0}, float[N] z) => (float[N] w) { x = foo(y, z, w1) w = bar(x, y, w2) } )ONNX"; GraphProto graph; Parse(graph, code); EXPECT_EQ(graph.input_size(), 2); EXPECT_EQ(graph.output_size(), 1); EXPECT_EQ(graph.initializer_size(), 3); // y, w1, w2 EXPECT_EQ(graph.value_info_size(), 1); // x } TEST(ParserTest, IfNodeTest) { const char* code = R"ONNX( z = If (b) < then_branch = g1 () => (float[N] z_then) { z_then = foo(y) }, else_branch = g2 () => (float[N] z_else) { z_else = bar(x) } > )ONNX"; NodeProto node; Parse(node, code); EXPECT_EQ(node.input_size(), 1); EXPECT_EQ(node.output_size(), 1); EXPECT_EQ(node.attribute_size(), 2); } TEST(ParserTest, ModelTest) { const char* code = R"ONNX( < ir_version: 7, opset_import: [ "ai.onnx.ml" : 10 ], producer_name: "ParserTest", producer_version: "1.0", domain: "ai.onnx.ml", model_version: 1, doc_string: "A parser test case model.", metadata_props: [ "somekey" : "somevalue", "key2" : "value2" ] > agraph (float[N] y, float[N] z) => (float[N] w) { x = foo(y, z) w = bar(x, y) } )ONNX"; ModelProto model; Parse(model, code); EXPECT_EQ(model.graph().input_size(), 2); EXPECT_EQ(model.graph().output_size(), 1); EXPECT_EQ(model.graph().node_size(), 2); } TEST(ParserTest, ModelCheckTest) { const char* code = R"ONNX( < ir_version: 7, opset_import: [ "" : 10 ] > agraph (float[N, 128] X, float[128,10] W, float[10] B) => (float[N] C) { T = MatMul(X, W) S = Add(T, B) C = Softmax(S) } )ONNX"; CheckModel(code); } TEST(ParserTest, IfModelTest) { const char* code = R"ONNX( < ir_version: 7, opset_import: [ "" : 13 ] > iftest (bool b, float[128] X, float[128] Y) => (float[128] Z) { Z = If (b) < then_branch = g1 () => (float[128] z_then) { z_then = Identity(X) }, else_branch = g2 () => (float[128] z_else) { z_else = Identity(Y) } > } )ONNX"; CheckModel(code); } TEST(ParserTest, FunModelTest) { const char* code = R"ONNX( < ir_version: 8, opset_import: [ "" : 10, "local" : 1 ] > agraph (float[N, 128] X, float[128,10] W, float[10] B) => (float[N] C) { T = local.foo (X, W, B) C = local.square(T) } < opset_import: [ "" : 10 ], domain: "local", doc_string: "Function foo." > foo (x, w, b) => (c) { T = MatMul(x, w) S = Add(T, b) c = Softmax(S) } < opset_import: [ "" : 10 ], domain: "local", doc_string: "Function square." > square (x) => (y) { y = Mul (x, x) } )ONNX"; CheckModel(code); const char* code_function_with_attributes = R"ONNX( < ir_version: 9, opset_import: [ "" : 15, "custom_domain" : 1] > agraph (float[N] x) => (float[N] out) { out = custom_domain.foo(x) } < domain: "custom_domain", opset_import: [ "" : 15], doc_string: "function foo" > foo (X) => (C) { constant_alpha = Constant() constant_gamma = Constant() constant_alpha_x = Mul(constant_alpha, X) C = Add(constant_alpha_x, constant_gamma) } )ONNX"; CheckModel(code_function_with_attributes); } TEST(ParserTest, TypesModelTest1) { const char* code = R"ONNX( < ir_version: 8, opset_import: [ "" : 18 ] > agraph (seq(float[N]) seqX) => (float[M, N] X) { X = ConcatFromSequence < axis = 0, new_axis = 1 >(seqX) } )ONNX"; CheckModel(code); } TEST(ParserTest, TypesModelTest2) { const char* code = R"ONNX( < ir_version: 8, opset_import: [ "" : 18 ] > agraph (float[N] tensorX, seq(float[N]) seqX, map(int32, float[N]) mapX, optional(float[N]) optionalX, sparse_tensor(float[N]) sparseX) => (float[N] X) { X = Identity (tensorX) } )ONNX"; CheckModel(code); } TEST(ParserTest, ExternalDataTest) { const char* code = R"ONNX( agraph (float y = {1.0}, float[N] z) => (w) < float[3, 2] m1 = ["location": "weight_1.bin", "offset": "17"], float[2, 1] m2 = {1.0, 2.0} > { x = Add(y, z) m = Mul(m1, m1) } )ONNX"; GraphProto graph; Parse(graph, code); EXPECT_EQ(graph.input_size(), 2); EXPECT_EQ(graph.output_size(), 1); EXPECT_EQ(graph.initializer_size(), 3); // m1, m2 EXPECT_EQ(graph.value_info_size(), 0); // x EXPECT_EQ(graph.initializer().Get(1).data_location(), TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).key(), "location"); EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).value(), "weight_1.bin"); EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).key(), "offset"); EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).value(), "17"); } } // namespace Test } // namespace ONNX_NAMESPACE