# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the Python bindings of ONNX operator schemas (``onnx.defs``).

Covers lookup of built-in schemas, construction of ``OpSchema`` and its nested
helper types (``FormalParameter``, ``TypeConstraintParam``, ``Attribute``), and
registering / deregistering custom schemas at runtime.
"""
import contextlib
import unittest
from typing import List, Sequence

import parameterized

import onnx
from onnx import defs


class TestSchema(unittest.TestCase):
    """Lookups of schemas for built-in operators."""

    def test_get_schema(self) -> None:
        defs.get_schema("Relu")

    def test_typecheck(self) -> None:
        schema = defs.get_schema("Conv")
        # Check the returned object, not merely that the lookup did not raise.
        self.assertIsInstance(schema, defs.OpSchema)
        self.assertGreater(len(schema.type_constraints), 0)

    def test_attr_default_value(self) -> None:
        v = defs.get_schema("BatchNormalization").attributes["epsilon"].default_value
        self.assertIsInstance(v, onnx.AttributeProto)
        self.assertEqual(v.type, onnx.AttributeProto.FLOAT)

    def test_function_body(self) -> None:
        self.assertIsInstance(
            defs.get_schema("Selu").function_body, onnx.FunctionProto
        )


class TestOpSchema(unittest.TestCase):
    """Construction of ``defs.OpSchema`` objects from Python."""

    def test_init(self) -> None:
        # The constructor creates an OpSchema object.
        schema = defs.OpSchema("test_op", "test_domain", 1)
        self.assertIsInstance(schema, defs.OpSchema)

    def test_init_with_inputs(self) -> None:
        op_schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=[defs.OpSchema.FormalParameter("input1", "T")],
            type_constraints=[("T", ["tensor(int64)"], "")],
        )
        self.assertEqual(op_schema.name, "test_op")
        self.assertEqual(op_schema.domain, "test_domain")
        self.assertEqual(op_schema.since_version, 1)
        self.assertEqual(len(op_schema.inputs), 1)
        self.assertEqual(op_schema.inputs[0].name, "input1")
        self.assertEqual(op_schema.inputs[0].type_str, "T")
        self.assertEqual(len(op_schema.type_constraints), 1)
        self.assertEqual(op_schema.type_constraints[0].type_param_str, "T")
        self.assertEqual(
            op_schema.type_constraints[0].allowed_type_strs, ["tensor(int64)"]
        )

    def test_init_creates_multi_input_output_schema(self) -> None:
        op_schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=[
                defs.OpSchema.FormalParameter("input1", "T"),
                defs.OpSchema.FormalParameter("input2", "T"),
            ],
            outputs=[
                defs.OpSchema.FormalParameter("output1", "T"),
                defs.OpSchema.FormalParameter("output2", "T"),
            ],
            type_constraints=[("T", ["tensor(int64)"], "")],
            attributes=[
                defs.OpSchema.Attribute(
                    "attr1", defs.OpSchema.AttrType.INTS, "attr1 description"
                )
            ],
        )
        self.assertEqual(len(op_schema.inputs), 2)
        self.assertEqual(op_schema.inputs[0].name, "input1")
        self.assertEqual(op_schema.inputs[0].type_str, "T")
        self.assertEqual(op_schema.inputs[1].name, "input2")
        self.assertEqual(op_schema.inputs[1].type_str, "T")
        self.assertEqual(len(op_schema.outputs), 2)
        self.assertEqual(op_schema.outputs[0].name, "output1")
        self.assertEqual(op_schema.outputs[0].type_str, "T")
        self.assertEqual(op_schema.outputs[1].name, "output2")
        self.assertEqual(op_schema.outputs[1].type_str, "T")
        self.assertEqual(len(op_schema.type_constraints), 1)
        self.assertEqual(op_schema.type_constraints[0].type_param_str, "T")
        self.assertEqual(
            op_schema.type_constraints[0].allowed_type_strs, ["tensor(int64)"]
        )
        self.assertEqual(len(op_schema.attributes), 1)
        self.assertEqual(op_schema.attributes["attr1"].name, "attr1")
        self.assertEqual(
            op_schema.attributes["attr1"].type, defs.OpSchema.AttrType.INTS
        )
        self.assertEqual(op_schema.attributes["attr1"].description, "attr1 description")

    def test_init_without_optional_arguments(self) -> None:
        op_schema = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(op_schema.name, "test_op")
        self.assertEqual(op_schema.domain, "test_domain")
        self.assertEqual(op_schema.since_version, 1)
        self.assertEqual(len(op_schema.inputs), 0)
        self.assertEqual(len(op_schema.outputs), 0)
        self.assertEqual(len(op_schema.type_constraints), 0)

    def test_name(self) -> None:
        # ``name`` is required and must be a string.
        with self.assertRaises(TypeError):
            defs.OpSchema(domain="test_domain", since_version=1)  # type: ignore
        with self.assertRaises(TypeError):
            defs.OpSchema(123, "test_domain", 1)  # type: ignore
        schema = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(schema.name, "test_op")

    def test_domain(self) -> None:
        # ``domain`` is required and must be a string.
        with self.assertRaises(TypeError):
            defs.OpSchema(name="test_op", since_version=1)  # type: ignore
        with self.assertRaises(TypeError):
            defs.OpSchema("test_op", 123, 1)  # type: ignore
        schema = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(schema.domain, "test_domain")

    def test_since_version(self) -> None:
        # ``since_version`` is required and must be an integer.
        with self.assertRaises(TypeError):
            defs.OpSchema("test_op", "test_domain")  # type: ignore
        schema = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(schema.since_version, 1)

    def test_doc(self) -> None:
        schema = defs.OpSchema("test_op", "test_domain", 1, doc="test_doc")
        self.assertEqual(schema.doc, "test_doc")

    def test_inputs(self) -> None:
        # ``inputs`` is optional and is a sequence of FormalParameter objects.
        inputs = [
            defs.OpSchema.FormalParameter(
                name="input1", type_str="T", description="The first input."
            )
        ]
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=inputs,
            type_constraints=[("T", ["tensor(int64)"], "")],
        )
        self.assertEqual(len(schema.inputs), 1)
        self.assertEqual(schema.inputs[0].name, "input1")
        self.assertEqual(schema.inputs[0].type_str, "T")
        self.assertEqual(schema.inputs[0].description, "The first input.")

    def test_outputs(self) -> None:
        # ``outputs`` is optional and is a sequence of FormalParameter objects.
        outputs = [
            defs.OpSchema.FormalParameter(
                name="output1", type_str="T", description="The first output."
            )
        ]
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            outputs=outputs,
            type_constraints=[("T", ["tensor(int64)"], "")],
        )
        self.assertEqual(len(schema.outputs), 1)
        self.assertEqual(schema.outputs[0].name, "output1")
        self.assertEqual(schema.outputs[0].type_str, "T")
        self.assertEqual(schema.outputs[0].description, "The first output.")


class TestFormalParameter(unittest.TestCase):
    """Construction of ``defs.OpSchema.FormalParameter``."""

    def test_init(self) -> None:
        name = "input1"
        type_str = "tensor(float)"
        description = "The first input."
        param_option = defs.OpSchema.FormalParameterOption.Single
        is_homogeneous = True
        min_arity = 1
        differentiation_category = defs.OpSchema.DifferentiationCategory.Unknown
        formal_parameter = defs.OpSchema.FormalParameter(
            name,
            type_str,
            description,
            param_option=param_option,
            is_homogeneous=is_homogeneous,
            min_arity=min_arity,
            differentiation_category=differentiation_category,
        )

        self.assertEqual(formal_parameter.name, name)
        self.assertEqual(formal_parameter.type_str, type_str)
        self.assertEqual(formal_parameter.description, description)
        # The constructor keyword is ``param_option``; the getter is ``option``.
        self.assertEqual(formal_parameter.option, param_option)
        self.assertEqual(formal_parameter.is_homogeneous, is_homogeneous)
        self.assertEqual(formal_parameter.min_arity, min_arity)
        self.assertEqual(
            formal_parameter.differentiation_category, differentiation_category
        )


class TestTypeConstraintParam(unittest.TestCase):
    """Construction of ``defs.OpSchema.TypeConstraintParam``."""

    @parameterized.parameterized.expand(
        [
            ("single_type", "T", ["tensor(float)"], "Test description"),
            (
                "double_types",
                "T",
                ["tensor(float)", "tensor(int64)"],
                "Test description",
            ),
            ("tuple", "T", ("tensor(float)", "tensor(int64)"), "Test description"),
        ]
    )
    def test_init(
        self,
        _: str,  # case label consumed by parameterized for the test name
        type_param_str: str,
        allowed_types: Sequence[str],
        description: str,
    ) -> None:
        type_constraint = defs.OpSchema.TypeConstraintParam(
            type_param_str, allowed_types, description
        )
        self.assertEqual(type_constraint.description, description)
        # Any sequence (including a tuple) is exposed back as a list.
        self.assertEqual(type_constraint.allowed_type_strs, list(allowed_types))
        self.assertEqual(type_constraint.type_param_str, type_param_str)


class TestAttribute(unittest.TestCase):
    """Construction of ``defs.OpSchema.Attribute``."""

    def test_init(self) -> None:
        name = "test_attr"
        type_ = defs.OpSchema.AttrType.STRINGS
        description = "Test attribute"
        attribute = defs.OpSchema.Attribute(name, type_, description)

        self.assertEqual(attribute.name, name)
        self.assertEqual(attribute.type, type_)
        self.assertEqual(attribute.description, description)

    def test_init_with_default_value(self) -> None:
        # Borrow a real AttributeProto default from a built-in schema.
        default_value = (
            defs.get_schema("BatchNormalization").attributes["epsilon"].default_value
        )
        self.assertIsInstance(default_value, onnx.AttributeProto)
        attribute = defs.OpSchema.Attribute("attr1", default_value, "attr1 description")
        self.assertEqual(default_value, attribute.default_value)
        self.assertEqual("attr1", attribute.name)
        self.assertEqual("attr1 description", attribute.description)


@parameterized.parameterized_class(
    [
        # Register into an existing domain (the default "" domain).
        {
            "op_type": "CustomOp",
            "op_version": 5,
            "op_domain": "",
            "trap_op_version": [1, 2, 6, 7],
        },
        # Register into a new domain.
        {
            "op_type": "CustomOp",
            "op_version": 5,
            "op_domain": "test",
            "trap_op_version": [1, 2, 6, 7],
        },
    ]
)
class TestOpSchemaRegister(unittest.TestCase):
    """Runtime registration and deregistration of custom operator schemas.

    Each parameterization registers ``op_type`` in ``op_domain`` at
    ``op_version``, plus decoy ("trap") schemas at ``trap_op_version`` that
    must not be picked when ``op_version`` is requested.
    """

    op_type: str
    op_version: int
    op_domain: str
    # Versions of decoy schemas registered alongside ``op_version``.
    trap_op_version: List[int]

    def setUp(self) -> None:
        # Each test must start with ``op_type`` absent from ``op_domain``.
        self.assertFalse(onnx.defs.has(self.op_type, self.op_domain))

    def tearDown(self) -> None:
        # Deregister every version a test may have registered; versions that
        # were never registered raise SchemaError, which is ignored here.
        for version in [*self.trap_op_version, self.op_version]:
            with contextlib.suppress(onnx.defs.SchemaError):
                onnx.defs.deregister_schema(self.op_type, version, self.op_domain)

    def test_register_multi_schema(self) -> None:
        versions = [*self.trap_op_version, self.op_version]
        for version in versions:
            op_schema = defs.OpSchema(self.op_type, self.op_domain, version)
            onnx.defs.register_schema(op_schema)
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))
        for version in versions:
            # Each registered schema must also be retrievable afterwards.
            registered_op = onnx.defs.get_schema(self.op_type, version, self.op_domain)
            expected = defs.OpSchema(self.op_type, self.op_domain, version)
            self.assertEqual(str(registered_op), str(expected))

    def test_using_the_specified_version_in_onnx_check(self) -> None:
        model_text = f"""
            <
                ir_version: 7,
                opset_import: [
                    "{self.op_domain}" : {self.op_version}
                ]
            >
            agraph (float[N, 128] X, int32 Y) => (float[N] Z)
            {{
                Z = {self.op_domain}.{self.op_type}(X, Y)
            }}
            """
        model = onnx.parser.parse_model(model_text)
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
            inputs=[
                defs.OpSchema.FormalParameter("input1", "T"),
                defs.OpSchema.FormalParameter("input2", "int32"),
            ],
            outputs=[
                defs.OpSchema.FormalParameter("output1", "T"),
            ],
            type_constraints=[("T", ["tensor(float)"], "")],
            attributes=[
                defs.OpSchema.Attribute(
                    "attr1", defs.OpSchema.AttrType.INTS, "attr1 description"
                )
            ],
        )
        # Without a registered schema, the checker rejects the custom node.
        with self.assertRaises(onnx.checker.ValidationError):
            onnx.checker.check_model(model, check_custom_domain=True)
        onnx.defs.register_schema(op_schema)
        # The trap schemas would make the check fail if the checker selected
        # one of them instead of the schema registered at ``op_version``.
        for version in self.trap_op_version:
            onnx.defs.register_schema(
                defs.OpSchema(
                    self.op_type,
                    self.op_domain,
                    version,
                    outputs=[
                        defs.OpSchema.FormalParameter("output1", "int32"),
                    ],
                )
            )
        onnx.checker.check_model(model, check_custom_domain=True)

    def test_register_schema_raises_error_when_registering_a_schema_twice(
        self,
    ) -> None:
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
        )
        onnx.defs.register_schema(op_schema)
        with self.assertRaises(onnx.defs.SchemaError):
            onnx.defs.register_schema(op_schema)

    def test_deregister_the_specified_schema(self) -> None:
        for version in [*self.trap_op_version, self.op_version]:
            op_schema = defs.OpSchema(self.op_type, self.op_domain, version)
            onnx.defs.register_schema(op_schema)
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))
        onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain)
        # Only the requested version is removed; the trap versions remain.
        for version in self.trap_op_version:
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))
        # A lookup at ``op_version`` may still succeed by falling back to a
        # lower trap version; if so, it must not be the removed schema.
        if onnx.defs.has(self.op_type, self.op_version, self.op_domain):
            schema = onnx.defs.get_schema(
                self.op_type, self.op_version, self.op_domain
            )
            self.assertLess(schema.since_version, self.op_version)

    def test_deregister_schema_raises_error_when_opschema_does_not_exist(
        self,
    ) -> None:
        with self.assertRaises(onnx.defs.SchemaError):
            onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain)

    def test_legacy_schema_accessible_after_deregister(self) -> None:
        # NOTE(review): despite its name, this test never deregisters anything;
        # it checks that every lookup path returns the newly registered schema.
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
        )
        onnx.defs.register_schema(op_schema)
        schema_a = onnx.defs.get_schema(
            op_schema.name, op_schema.since_version, op_schema.domain
        )
        schema_b = onnx.defs.get_schema(op_schema.name, op_schema.domain)

        def filter_schema(schemas):
            """Keep only the schemas named like the registered custom op."""
            return [op for op in schemas if op.name == op_schema.name]

        schema_c = filter_schema(onnx.defs.get_all_schemas())
        schema_d = filter_schema(onnx.defs.get_all_schemas_with_history())
        self.assertEqual(len(schema_c), 1)
        self.assertEqual(len(schema_d), 1)
        # Compare via ``str`` so the storage behind every returned schema
        # object is actually read, not just its identity.
        self.assertEqual(str(schema_a), str(op_schema))
        self.assertEqual(str(schema_b), str(op_schema))
        self.assertEqual(str(schema_c[0]), str(op_schema))
        self.assertEqual(str(schema_d[0]), str(op_schema))


if __name__ == "__main__":
    unittest.main()