# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]

from typing import List

import onnx.onnx_cpp2py_export.defs as C  # noqa: N812
from onnx import AttributeProto, FunctionProto

# Domain identifiers used as keys into the schema version map.
ONNX_DOMAIN = ""
ONNX_ML_DOMAIN = "ai.onnx.ml"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"

# Thin aliases re-exporting functions implemented by the C++ extension module.
has = C.has_schema
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
deregister_schema = C.deregister_schema


def onnx_opset_version() -> int:
    """Return current opset for domain `ai.onnx`."""
    # Index 1 of the per-domain entry is taken as the current version.
    return C.schema_version_map()[ONNX_DOMAIN][1]


def onnx_ml_opset_version() -> int:
    """Return current opset for domain `ai.onnx.ml`."""
    # Index 1 of the per-domain entry is taken as the current version.
    return C.schema_version_map()[ONNX_ML_DOMAIN][1]


@property  # type: ignore
def _function_proto(self):  # type: ignore
    """Parse the serialized bytes in ``self._function_body`` into a ``FunctionProto``."""
    func_proto = FunctionProto()
    func_proto.ParseFromString(self._function_body)
    return func_proto


OpSchema = C.OpSchema  # type: ignore
# Expose the parsed function body as a read-only property on the extension class.
OpSchema.function_body = _function_proto  # type: ignore


@property  # type: ignore
def _attribute_default_value(self):  # type: ignore
    """Parse the serialized bytes in ``self._default_value`` into an ``AttributeProto``."""
    attr = AttributeProto()
    attr.ParseFromString(self._default_value)
    return attr


# Expose the parsed default value as a read-only property on the extension class.
OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore


def _op_schema_repr(self) -> str:
    """Return a multi-line, constructor-like representation of an ``OpSchema``."""
    return f"""\
OpSchema(
    name={self.name!r},
    domain={self.domain!r},
    since_version={self.since_version!r},
    doc={self.doc!r},
    type_constraints={self.type_constraints!r},
    inputs={self.inputs!r},
    outputs={self.outputs!r},
    attributes={self.attributes!r}
)"""


OpSchema.__repr__ = _op_schema_repr  # type: ignore


def _op_schema_formal_parameter_repr(self) -> str:
    """Return a constructor-like representation of an ``OpSchema.FormalParameter``."""
    return (
        f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, "
        f"description={self.description!r}, param_option={self.option!r}, "
        f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, "
        f"differentiation_category={self.differentiation_category!r})"
    )


OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr  # type: ignore


def _op_schema_type_constraint_param_repr(self) -> str:
    """Return a constructor-like representation of an ``OpSchema.TypeConstraintParam``."""
    return (
        f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, "
        f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})"
    )


OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr  # type: ignore


def _op_schema_attribute_repr(self) -> str:
    """Return a constructor-like representation of an ``OpSchema.Attribute``."""
    return (
        f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, "
        f"default_value={self.default_value!r}, required={self.required!r})"
    )


OpSchema.Attribute.__repr__ = _op_schema_attribute_repr  # type: ignore


def get_function_ops() -> List[OpSchema]:
    """Return operators defined as functions.

    Returns:
        The schemas from ``C.get_all_schemas()`` for which ``has_function`` or
        ``has_context_dependent_function`` is truthy.
    """
    schemas = C.get_all_schemas()
    return [schema for schema in schemas if schema.has_function or schema.has_context_dependent_function]  # type: ignore


SchemaError = C.SchemaError


def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    The function extends available operator set versions for the provided domain if necessary.

    Args:
        schema: The OpSchema to register.
    """
    version_map = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version
    # For a domain not yet in the map, start from a range holding only this version.
    min_version, max_version = version_map.get(domain, (version, version))
    if domain not in version_map or not (min_version <= version <= max_version):
        # Widen the domain's version range so that it includes the schema's version.
        min_version = min(min_version, version)
        max_version = max(max_version, version)
        C.set_domain_to_version(domain, min_version, max_version)
    C.register_schema(schema)