Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import contextlib | |
import unittest | |
from typing import List, Sequence | |
import parameterized | |
import onnx | |
from onnx import defs | |
class TestSchema(unittest.TestCase):
    """Smoke tests for looking up built-in operator schemas with ``defs.get_schema``."""

    def test_get_schema(self) -> None:
        """Looking up "Relu" must not raise."""
        defs.get_schema("Relu")

    def test_typecheck(self) -> None:
        """Looking up "Conv" must not raise."""
        defs.get_schema("Conv")

    def test_attr_default_value(self) -> None:
        """The default of BatchNormalization's ``epsilon`` is a FLOAT AttributeProto."""
        v = defs.get_schema("BatchNormalization").attributes["epsilon"].default_value
        # assertIsInstance reports a clearer failure than comparing type objects
        # and matches the style used by the other test classes in this module.
        self.assertIsInstance(v, onnx.AttributeProto)
        self.assertEqual(v.type, onnx.AttributeProto.FLOAT)

    def test_function_body(self) -> None:
        """The ``function_body`` of the "Selu" schema is a FunctionProto."""
        self.assertIsInstance(
            defs.get_schema("Selu").function_body, onnx.FunctionProto
        )
class TestOpSchema(unittest.TestCase):
    """Checks construction of ``defs.OpSchema`` and the values it exposes."""

    def test_init(self):
        """Name, domain and version alone are enough to build an OpSchema."""
        built = defs.OpSchema("test_op", "test_domain", 1)
        self.assertIsInstance(built, defs.OpSchema)

    def test_init_with_inputs(self) -> None:
        """One input and one type constraint are both reflected on the schema."""
        only_input = defs.OpSchema.FormalParameter("input1", "T")
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=[only_input],
            type_constraints=[("T", ["tensor(int64)"], "")],
        )

        # Identity of the schema.
        self.assertEqual(schema.name, "test_op")
        self.assertEqual(schema.domain, "test_domain")
        self.assertEqual(schema.since_version, 1)

        # The single formal input.
        self.assertEqual(len(schema.inputs), 1)
        self.assertEqual(schema.inputs[0].name, "input1")
        self.assertEqual(schema.inputs[0].type_str, "T")

        # The single type constraint.
        self.assertEqual(len(schema.type_constraints), 1)
        self.assertEqual(schema.type_constraints[0].type_param_str, "T")
        self.assertEqual(
            schema.type_constraints[0].allowed_type_strs, ["tensor(int64)"]
        )

    def test_init_creates_multi_input_output_schema(self) -> None:
        """Two inputs, two outputs, one constraint and one attribute round-trip."""
        formal_inputs = [
            defs.OpSchema.FormalParameter("input1", "T"),
            defs.OpSchema.FormalParameter("input2", "T"),
        ]
        formal_outputs = [
            defs.OpSchema.FormalParameter("output1", "T"),
            defs.OpSchema.FormalParameter("output2", "T"),
        ]
        constraints = [("T", ["tensor(int64)"], "")]
        attrs = [
            defs.OpSchema.Attribute(
                "attr1", defs.OpSchema.AttrType.INTS, "attr1 description"
            )
        ]
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=formal_inputs,
            outputs=formal_outputs,
            type_constraints=constraints,
            attributes=attrs,
        )

        self.assertEqual(len(schema.inputs), 2)
        for position, expected_name in enumerate(("input1", "input2")):
            self.assertEqual(schema.inputs[position].name, expected_name)
            self.assertEqual(schema.inputs[position].type_str, "T")

        self.assertEqual(len(schema.outputs), 2)
        for position, expected_name in enumerate(("output1", "output2")):
            self.assertEqual(schema.outputs[position].name, expected_name)
            self.assertEqual(schema.outputs[position].type_str, "T")

        self.assertEqual(len(schema.type_constraints), 1)
        self.assertEqual(schema.type_constraints[0].type_param_str, "T")
        self.assertEqual(
            schema.type_constraints[0].allowed_type_strs, ["tensor(int64)"]
        )

        self.assertEqual(len(schema.attributes), 1)
        self.assertEqual(schema.attributes["attr1"].name, "attr1")
        self.assertEqual(
            schema.attributes["attr1"].type, defs.OpSchema.AttrType.INTS
        )
        self.assertEqual(schema.attributes["attr1"].description, "attr1 description")

    def test_init_without_optional_arguments(self) -> None:
        """Omitting the optional arguments leaves inputs, outputs and constraints empty."""
        schema = defs.OpSchema("test_op", "test_domain", 1)

        self.assertEqual(schema.name, "test_op")
        self.assertEqual(schema.domain, "test_domain")
        self.assertEqual(schema.since_version, 1)
        self.assertEqual(len(schema.inputs), 0)
        self.assertEqual(len(schema.outputs), 0)
        self.assertEqual(len(schema.type_constraints), 0)

    def test_name(self):
        """``name`` is mandatory, must be a string, and is readable afterwards."""
        # Leaving the name out is rejected.
        self.assertRaises(
            TypeError, defs.OpSchema, domain="test_domain", since_version=1
        )
        # A non-string name is rejected.
        self.assertRaises(TypeError, defs.OpSchema, 123, "test_domain", 1)

        valid = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(valid.name, "test_op")

    def test_domain(self):
        """``domain`` is mandatory, must be a string, and is readable afterwards."""
        # Leaving the domain out is rejected.
        self.assertRaises(TypeError, defs.OpSchema, name="test_op", since_version=1)
        # A non-string domain is rejected.
        self.assertRaises(TypeError, defs.OpSchema, "test_op", 123, 1)

        valid = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(valid.domain, "test_domain")

    def test_since_version(self):
        """``since_version`` is mandatory and is readable afterwards."""
        self.assertRaises(TypeError, defs.OpSchema, "test_op", "test_domain")

        valid = defs.OpSchema("test_op", "test_domain", 1)
        self.assertEqual(valid.since_version, 1)

    def test_doc(self):
        """The ``doc`` keyword is exposed unchanged through ``schema.doc``."""
        documented = defs.OpSchema("test_op", "test_domain", 1, doc="test_doc")
        self.assertEqual(documented.doc, "test_doc")

    def test_inputs(self):
        """An input built with keyword arguments keeps its name, type and description."""
        described_input = defs.OpSchema.FormalParameter(
            name="input1", type_str="T", description="The first input."
        )
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            inputs=[described_input],
            type_constraints=[("T", ["tensor(int64)"], "")],
        )

        self.assertEqual(len(schema.inputs), 1)
        self.assertEqual(schema.inputs[0].name, "input1")
        self.assertEqual(schema.inputs[0].type_str, "T")
        self.assertEqual(schema.inputs[0].description, "The first input.")

    def test_outputs(self):
        """An output built with keyword arguments keeps its name, type and description."""
        described_output = defs.OpSchema.FormalParameter(
            name="output1", type_str="T", description="The first output."
        )
        schema = defs.OpSchema(
            "test_op",
            "test_domain",
            1,
            outputs=[described_output],
            type_constraints=[("T", ["tensor(int64)"], "")],
        )

        self.assertEqual(len(schema.outputs), 1)
        self.assertEqual(schema.outputs[0].name, "output1")
        self.assertEqual(schema.outputs[0].type_str, "T")
        self.assertEqual(schema.outputs[0].description, "The first output.")
class TestFormalParameter(unittest.TestCase):
    """Checks that ``defs.OpSchema.FormalParameter`` exposes its constructor arguments."""

    def test_init(self):
        """Every value passed to the constructor can be read back from the parameter."""
        option = defs.OpSchema.FormalParameterOption.Single
        category = defs.OpSchema.DifferentiationCategory.Unknown

        param = defs.OpSchema.FormalParameter(
            "input1",
            "tensor(float)",
            "The first input.",
            param_option=option,
            is_homogeneous=True,
            min_arity=1,
            differentiation_category=category,
        )

        self.assertEqual(param.name, "input1")
        self.assertEqual(param.type_str, "tensor(float)")
        self.assertEqual(param.description, "The first input.")
        # The ``param_option`` keyword is read back through the ``option`` property.
        self.assertEqual(param.option, option)
        self.assertEqual(param.is_homogeneous, True)
        self.assertEqual(param.min_arity, 1)
        self.assertEqual(param.differentiation_category, category)
class TestTypeConstraintParam(unittest.TestCase):
    """Checks that ``defs.OpSchema.TypeConstraintParam`` exposes its constructor arguments."""

    # unittest calls test methods with ``self`` only, so the extra parameters of
    # ``test_init`` must be supplied by ``parameterized.expand``. The first tuple
    # element is the case name appended to the generated test's name; it is
    # received as ``_`` and otherwise unused.
    @parameterized.parameterized.expand(
        [
            ("single_type", "T", ["tensor(float)"], "Test description"),
            (
                "double_types",
                "T",
                ["tensor(float)", "tensor(int64)"],
                "Test description",
            ),
            # A tuple exercises the ``list(allowed_types)`` normalisation below.
            ("tuple", "T", ("tensor(float)", "tensor(int64)"), "Test description"),
        ]
    )
    def test_init(
        self,
        _: str,
        type_param_str: str,
        allowed_types: Sequence[str],
        description: str,
    ) -> None:
        """Build a constraint from the case's values and read each one back.

        Args:
            _: Case name consumed by ``parameterized``; not used in the body.
            type_param_str: Type parameter name, e.g. ``"T"``.
            allowed_types: Allowed type strings, as any sequence.
            description: Free-text description of the constraint.
        """
        type_constraint = defs.OpSchema.TypeConstraintParam(
            type_param_str, allowed_types, description
        )
        self.assertEqual(type_constraint.description, description)
        # Compared against a list so that tuple inputs are accepted as well.
        self.assertEqual(type_constraint.allowed_type_strs, list(allowed_types))
        self.assertEqual(type_constraint.type_param_str, type_param_str)
class TestAttribute(unittest.TestCase):
    """Checks the two ways this module constructs ``defs.OpSchema.Attribute``."""

    def test_init(self):
        """An attribute built from (name, type, description) exposes all three."""
        attr_type = defs.OpSchema.AttrType.STRINGS
        attr = defs.OpSchema.Attribute("test_attr", attr_type, "Test attribute")

        self.assertEqual(attr.name, "test_attr")
        self.assertEqual(attr.type, attr_type)
        self.assertEqual(attr.description, "Test attribute")

    def test_init_with_default_value(self):
        """An attribute built from an AttributeProto default keeps that default."""
        # Borrow an existing AttributeProto from a built-in schema as the default.
        batch_norm = defs.get_schema("BatchNormalization")
        epsilon_default = batch_norm.attributes["epsilon"].default_value
        self.assertIsInstance(epsilon_default, onnx.AttributeProto)

        attr = defs.OpSchema.Attribute("attr1", epsilon_default, "attr1 description")

        self.assertEqual(epsilon_default, attr.default_value)
        self.assertEqual("attr1", attr.name)
        self.assertEqual("attr1 description", attr.description)
# The class attributes below are only annotated, never assigned, so they must be
# injected per generated subclass by ``parameterized_class``; without it ``setUp``
# fails with AttributeError. ``trap_op_version`` straddles ``op_version`` so the
# tests can check that exactly the requested version is selected or removed.
@parameterized.parameterized_class(
    [
        # Register into the already existing default domain.
        {
            "op_type": "CustomOp",
            "op_version": 5,
            "op_domain": "",
            "trap_op_version": [1, 2, 6, 7],
        },
        # Register into a brand-new domain.
        {
            "op_type": "CustomOp",
            "op_version": 5,
            "op_domain": "test",
            "trap_op_version": [1, 2, 6, 7],
        },
    ]
)
class TestOpSchemaRegister(unittest.TestCase):
    """Tests for ``onnx.defs.register_schema`` / ``onnx.defs.deregister_schema``."""

    op_type: str
    op_version: int
    op_domain: str
    # Versions of extra "trap" schemas registered next to the real one to check
    # that version selection picks the intended schema.
    trap_op_version: List[int]

    def setUp(self) -> None:
        """Assert that no schema for the op is registered before the test runs."""
        self.assertFalse(onnx.defs.has(self.op_type, self.op_domain))

    def tearDown(self) -> None:
        """Deregister every version a test may have registered."""
        for version in [*self.trap_op_version, self.op_version]:
            # Not every test registers every version, so a failed deregistration
            # is expected here and deliberately ignored.
            with contextlib.suppress(onnx.defs.SchemaError):
                onnx.defs.deregister_schema(self.op_type, version, self.op_domain)

    def test_register_multi_schema(self) -> None:
        """Several versions of one op can be registered and fetched individually."""
        all_versions = [*self.trap_op_version, self.op_version]

        for version in all_versions:
            onnx.defs.register_schema(
                defs.OpSchema(self.op_type, self.op_domain, version)
            )
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))

        for version in all_versions:
            # Also make sure each schema is accessible after registration.
            registered_op = onnx.defs.get_schema(
                self.op_type, version, self.op_domain
            )
            expected_schema = defs.OpSchema(self.op_type, self.op_domain, version)
            self.assertEqual(str(registered_op), str(expected_schema))

    def test_using_the_specified_version_in_onnx_check(self) -> None:
        """The checker validates a node against the schema of the imported opset version."""
        # Named ``model_text`` rather than ``input`` to avoid shadowing the builtin.
        model_text = f"""
            <
                ir_version: 7,
                opset_import: [
                    "{self.op_domain}" : {self.op_version}
                ]
            >
            agraph (float[N, 128] X, int32 Y) => (float[N] Z)
            {{
                Z = {self.op_domain}.{self.op_type}<attr1=[1,2]>(X, Y)
            }}
            """
        model = onnx.parser.parse_model(model_text)
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
            inputs=[
                defs.OpSchema.FormalParameter("input1", "T"),
                defs.OpSchema.FormalParameter("input2", "int32"),
            ],
            outputs=[
                defs.OpSchema.FormalParameter("output1", "T"),
            ],
            type_constraints=[("T", ["tensor(float)"], "")],
            attributes=[
                defs.OpSchema.Attribute(
                    "attr1", defs.OpSchema.AttrType.INTS, "attr1 description"
                )
            ],
        )

        # Before registration the custom op is unknown, so the check must fail.
        with self.assertRaises(onnx.checker.ValidationError):
            onnx.checker.check_model(model, check_custom_domain=True)

        onnx.defs.register_schema(op_schema)
        # The trap schemas declare no inputs and a different output type, so the
        # check below would raise if the checker selected one of them.
        for version in self.trap_op_version:
            onnx.defs.register_schema(
                defs.OpSchema(
                    self.op_type,
                    self.op_domain,
                    version,
                    outputs=[
                        defs.OpSchema.FormalParameter("output1", "int32"),
                    ],
                )
            )
        onnx.checker.check_model(model, check_custom_domain=True)

    def test_register_schema_raises_error_when_registering_a_schema_twice(
        self,
    ) -> None:
        """Registering the same schema a second time raises ``SchemaError``."""
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
        )
        onnx.defs.register_schema(op_schema)
        with self.assertRaises(onnx.defs.SchemaError):
            onnx.defs.register_schema(op_schema)

    def test_deregister_the_specified_schema(self) -> None:
        """Deregistering one version leaves the other registered versions in place."""
        for version in [*self.trap_op_version, self.op_version]:
            onnx.defs.register_schema(
                defs.OpSchema(self.op_type, self.op_domain, version)
            )
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))

        onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain)

        for version in self.trap_op_version:
            self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain))

        # A lookup at ``op_version`` may still succeed if the trap list contains
        # a smaller version; in that case it must resolve to that older schema.
        if onnx.defs.has(self.op_type, self.op_version, self.op_domain):
            schema = onnx.defs.get_schema(
                self.op_type, self.op_version, self.op_domain
            )
            self.assertLess(schema.since_version, self.op_version)

    def test_deregister_schema_raises_error_when_opschema_does_not_exist(
        self,
    ) -> None:
        """Deregistering a schema that was never registered raises ``SchemaError``."""
        with self.assertRaises(onnx.defs.SchemaError):
            onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain)

    def test_legacy_schema_accessible_after_deregister(self) -> None:
        """A registered schema is reachable through every lookup API.

        NOTE(review): despite the name, nothing is deregistered inside this
        test body; removal only happens in ``tearDown``.
        """
        op_schema = defs.OpSchema(
            self.op_type,
            self.op_domain,
            self.op_version,
        )
        onnx.defs.register_schema(op_schema)

        schema_a = onnx.defs.get_schema(
            op_schema.name, op_schema.since_version, op_schema.domain
        )
        schema_b = onnx.defs.get_schema(op_schema.name, op_schema.domain)

        def filter_schema(schemas):
            """Keep only the schemas whose name matches the op under test."""
            return [op for op in schemas if op.name == op_schema.name]

        schema_c = filter_schema(onnx.defs.get_all_schemas())
        schema_d = filter_schema(onnx.defs.get_all_schemas_with_history())
        self.assertEqual(len(schema_c), 1)
        self.assertEqual(len(schema_d), 1)

        # Avoid memory residue and access storage as much as possible
        self.assertEqual(str(schema_a), str(op_schema))
        self.assertEqual(str(schema_b), str(op_schema))
        self.assertEqual(str(schema_c[0]), str(op_schema))
        self.assertEqual(str(schema_d[0]), str(op_schema))
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()