Spaces:
Sleeping
Sleeping
File size: 6,004 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""Graph utilities for checking whether an ONNX proto message is legal."""
from __future__ import annotations
# Public API of this module: the per-proto check_* entry points, the two
# shared default context objects, the size limit constant, the re-exported
# ValidationError, and `C` (the compiled checker extension imported below).
__all__ = [
"check_attribute",
"check_function",
"check_graph",
"check_model",
"check_node",
"check_sparse_tensor",
"check_tensor",
"check_value_info",
"DEFAULT_CONTEXT",
"LEXICAL_SCOPE_CONTEXT",
"ValidationError",
"C",
"MAXIMUM_PROTOBUF",
]
import os
import sys
from typing import Any, Callable, TypeVar
from google.protobuf.message import Message
import onnx.defs
import onnx.onnx_cpp2py_export.checker as C # noqa: N812
import onnx.shape_inference
from onnx import (
IR_VERSION,
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
SparseTensorProto,
TensorProto,
ValueInfoProto,
)
# Size limit (in bytes) for a single protobuf file: 2GB. check_model compares
# the size of an in-memory model against this value and asks the caller to
# pass a model path instead when it is exceeded.
MAXIMUM_PROTOBUF = 2000000000
# TODO: The check_* wrappers below serialize the protobuf back into a
# string, only for it to be deserialized again at the call site. That
# round trip is wasteful; stop doing that.
# NB: Please don't edit this context! It is the shared default `ctx`
# argument of the check_* functions below, so mutating it changes the
# behaviour of every call that relies on the default.
DEFAULT_CONTEXT = C.CheckerContext()
DEFAULT_CONTEXT.ir_version = IR_VERSION
# Only the default ("") domain is registered, at the current ONNX opset version.
# TODO: Maybe ONNX-ML should also be defaulted?
DEFAULT_CONTEXT.opset_imports = {"": onnx.defs.onnx_opset_version()}
# Shared default `lexical_scope_ctx` argument for the attribute / node /
# function / graph checkers below.
LEXICAL_SCOPE_CONTEXT = C.LexicalScopeContext()
# Type variable bound to arbitrary callables.
FuncType = TypeVar("FuncType", bound=Callable[..., Any])
def _ensure_proto_type(proto: Message, proto_type: type[Message]) -> None:
if not isinstance(proto, proto_type):
raise TypeError(
f"The proto message needs to be of type '{proto_type.__name__}'"
)
def check_value_info(
    value_info: ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Run the compiled checker on a ``ValueInfoProto``.

    Args:
        value_info: The message to check. It is serialized to bytes before
            being handed to ``C.check_value_info``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``value_info`` is not a ``ValueInfoProto``.
    """
    _ensure_proto_type(value_info, ValueInfoProto)
    serialized = value_info.SerializeToString()
    return C.check_value_info(serialized, ctx)
def check_tensor(tensor: TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Run the compiled checker on a ``TensorProto``.

    Args:
        tensor: The message to check. It is serialized to bytes before being
            handed to ``C.check_tensor``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``tensor`` is not a ``TensorProto``.
    """
    _ensure_proto_type(tensor, TensorProto)
    serialized = tensor.SerializeToString()
    return C.check_tensor(serialized, ctx)
def check_attribute(
    attr: AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Run the compiled checker on an ``AttributeProto``.

    Args:
        attr: The message to check. It is serialized to bytes before being
            handed to ``C.check_attribute``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context forwarded to the compiled
            checker. Defaults to the shared ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``attr`` is not an ``AttributeProto``.
    """
    _ensure_proto_type(attr, AttributeProto)
    serialized = attr.SerializeToString()
    return C.check_attribute(serialized, ctx, lexical_scope_ctx)
def check_node(
    node: NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Run the compiled checker on a ``NodeProto``.

    Args:
        node: The message to check. It is serialized to bytes before being
            handed to ``C.check_node``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context forwarded to the compiled
            checker. Defaults to the shared ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``node`` is not a ``NodeProto``.
    """
    _ensure_proto_type(node, NodeProto)
    serialized = node.SerializeToString()
    return C.check_node(serialized, ctx, lexical_scope_ctx)
def check_function(
    function: FunctionProto,
    ctx: C.CheckerContext | None = None,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Run the compiled checker on a ``FunctionProto``.

    Args:
        function: The message to check. It is serialized to bytes before
            being handed to ``C.check_function``.
        ctx: Checker context forwarded to the compiled checker. When omitted,
            a fresh context is built from the function's own ``opset_import``
            entries.
        lexical_scope_ctx: Lexical scope context forwarded to the compiled
            checker. Defaults to the shared ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``function`` is not a ``FunctionProto``.
    """
    _ensure_proto_type(function, FunctionProto)
    if ctx is None:
        # No context supplied: derive one from the function itself rather
        # than touching the shared default context.
        declared_opsets = function.opset_import
        ctx = C.CheckerContext()
        ctx.ir_version = onnx.helper.find_min_ir_version_for(
            declared_opsets, ignore_unknown=True
        )
        ctx.opset_imports = {
            opset.domain: opset.version for opset in declared_opsets
        }
    serialized = function.SerializeToString()
    C.check_function(serialized, ctx, lexical_scope_ctx)
def check_graph(
    graph: GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Run the compiled checker on a ``GraphProto``.

    Args:
        graph: The message to check. It is serialized to bytes before being
            handed to ``C.check_graph``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context forwarded to the compiled
            checker. Defaults to the shared ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``graph`` is not a ``GraphProto``.
    """
    _ensure_proto_type(graph, GraphProto)
    serialized = graph.SerializeToString()
    return C.check_graph(serialized, ctx, lexical_scope_ctx)
def check_sparse_tensor(
    sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Run the compiled checker on a ``SparseTensorProto``.

    Args:
        sparse: The message to check. It is serialized to bytes before being
            handed to ``C.check_sparse_tensor``.
        ctx: Checker context forwarded to the compiled checker. Defaults to
            the shared ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``sparse`` is not a ``SparseTensorProto``.
    """
    _ensure_proto_type(sparse, SparseTensorProto)
    serialized = sparse.SerializeToString()
    C.check_sparse_tensor(serialized, ctx)
def check_model(
    model: ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception will be raised if the model's ir_version is not set
    properly or is higher than checker's ir_version, or if the model
    has duplicate keys in metadata_props.

    If IR version >= 3, the model must specify opset_import.
    If IR version < 3, the model cannot have any opset_import specified.

    Args:
        model: Model to check. If model is a path, the function checks model
            path first. If the model bytes size is larger than 2GB, function
            should be called using model path.
        full_check: If True, the function also runs shape inference check.
        skip_opset_compatibility_check: If True, the function skips the check for
            opset compatibility.
        check_custom_domain: If True, the function will check all domains. Otherwise
            only check built-in domains.

    Raises:
        ValueError: If ``model`` is given in memory (as ``bytes`` or as a
            proto message) and its serialized form is longer than
            ``MAXIMUM_PROTOBUF`` bytes.
    """
    # A path is handed straight to the path-based checker, which avoids
    # holding the serialized model in Python memory.
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(
            os.fspath(model),
            full_check,
            skip_opset_compatibility_check,
            check_custom_domain,
        )
        return

    # Already-serialized bytes are used as-is; anything else is expected to
    # be a proto message and is serialized here.
    protobuf_string = (
        model if isinstance(model, bytes) else model.SerializeToString()
    )
    # If the protobuf is larger than 2GB, remind users to use the model path
    # instead. The limit applies to the serialized payload, so measure it
    # with len(): sys.getsizeof() reports the interpreter's memory footprint
    # of the bytes object (payload plus object header) and is not available
    # on every Python implementation.
    if len(protobuf_string) > MAXIMUM_PROTOBUF:
        raise ValueError(
            "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
        )
    C.check_model(
        protobuf_string,
        full_check,
        skip_opset_compatibility_check,
        check_custom_domain,
    )
# Re-export the exception class defined by the compiled checker extension so
# that callers can reference it from this module (it is listed in __all__).
ValidationError = C.ValidationError
|