# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
from io import BytesIO | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
from onnx import load | |
from onnx.defs import onnx_opset_version | |
from onnx.external_data_helper import ExternalDataInfo, uses_external_data | |
from onnx.model_container import ModelContainer | |
from onnx.onnx_pb import ( | |
FunctionProto, | |
GraphProto, | |
ModelProto, | |
NodeProto, | |
TensorProto, | |
TypeProto, | |
) | |
from onnx.reference.op_run import ( | |
OpFunctionContextDependant, | |
OpRun, | |
OpRunExpand, | |
RuntimeContextError, | |
to_array_extended, | |
) | |
from onnx.reference.ops_optimized import optimized_operators | |
class ReferenceEvaluator:
    r"""Computes the outputs of an ONNX proto (`ModelProto`, `FunctionProto`, `GraphProto`, `NodeProto`).

    This is a pure python implementation of ONNX specifications.
    Mismatches may remain between the official specifications and the implementation here.
    In the case of such a mismatch, the official spec overrides this implementation.

    Args:
        proto: :class:`onnx.ModelProto`, :class:`onnx.GraphProto`,
            :class:`onnx.FunctionProto`, :class:`onnx.NodeProto`,
            filename or bytes
        verbose: display intermediate results on the standard output
            during the execution
        opsets: if *proto* is an instance of *GraphProto*, opsets must
            be defined by a dictionary mapping a domain name to its
            opset version
        functions: known onnx functions
        new_ops: this runtime can be used to test the implementations of
            new operators, *new_ops* is a list of classes derived from
            :class:`OpRun <onnx.reference.op_run.OpRun>`, every class
            must define the static attribute `domain`, there may be
            multiple implementations for the same operator, the first
            one in the list is used.
        optimized: some operators have two implementations, a naive one
            corresponding to definition of the mathematical definition
            of the operator, another one more efficient. This is the
            case for operator Conv. The naive version is ten times
            slower than the optimized one using a decomposition into
            *Conv = im2col + Gemm*. If True, all optimized kernels are
            added in `new_ops` and are used instead of the inner
            implementation if list *new_ops* does not already contain
            one.

    The class maps every node to its associated implementation.
    When a subgraph of a function is met,
    it uses this class to execute the subgraph or the function.

    Next example shows how to run `ReferenceEvaluator` with an onnx model
    stored in file `model.onnx`.

    ::

        import numpy as np
        from onnx.reference import ReferenceEvaluator

        X = np.array(...)
        sess = ReferenceEvaluator("model.onnx")
        results = sess.run(None, {"X": X})
        print(results[0])  # display the first result

    Parameter *verbose* may be used to show intermediate results.

    ::

        import numpy as np
        from onnx.reference import ReferenceEvaluator

        X = np.array(...)
        sess = ReferenceEvaluator("model.onnx", verbose=1)
        results = sess.run(None, {"X": X})
        print(results[0])  # display the first result

    The class can use any implementation available in folder
    `ops <https://github.com/onnx/onnx/tree/main/onnx/reference/ops>`_.
    Adding an implementation requires two changes. The first one is
    the implementation itself. Any existing node can be used as a template.
    The second is one line in file `_op_list.py
    <https://github.com/onnx/onnx/tree/main/onnx/reference/ops/_op_list.py>`_
    to import the file and let the reference evaluator know it exists.

    This class can also be used to test an implementation of
    a custom operator. Let's assume this new operator
    is `InvAlpha` from domain `custom`. The implementation
    must take place in a class inheriting from
    :class:`OpRun <onnx.reference.op_run.OpRun>`.
    It must also define attribute `op_domain`.
    Here is an example which computes :math:`\frac{1}{X + \alpha}`.

    .. exec_code::

        from onnx.reference.op_run import OpRun

        class InvAlpha(OpRun):

            op_domain = "custom"

            def _run(self, x, alpha=None):  # type: ignore
                # None must be the default value, it is automatically
                # replaced by class OpRun with either the default value
                # specified in the NodeProto or an attribute value defined
                # in a `FunctionProto`.
                return (1 / (x + alpha),)

    `alpha` is an attribute. It can be defined by the onnx node or
    be defined by the function using this node. It is safe to assume
    that attributes are known at the same time as the input.
    Class `ReferenceEvaluator` must know about this new implementation
    and this can be done by specified argument *new_ops*.

    ::

        sess = ReferenceEvaluator(onnx_model, new_ops=[InvAlpha])
        got = sess.run(None, {"X": x})[0]

    A specific node can be simply evaluated.

    .. exec_code::

        import numpy as np
        from onnx.reference.ops._op_list import Celu

        x = np.array([[0, 1], [-1, 2]], dtype=np.float32)
        y = Celu.eval(x, alpha=0.5)
        print(y)

    This can also be expressed as:

    .. exec_code::

        import numpy as np
        from onnx.reference.ops import load_op

        Celu = load_op("", "Celu")  # domain is ""
        x = np.array([[0, 1], [-1, 2]], dtype=np.float32)
        y = Celu.eval(x, alpha=0.5)
        print(y)

    It is possible to overwrite an existing operator.
    The class name must be the same. The domain does not have
    to be specified for the default domain. However, by default,
    class `OpRun` will load the most recent for this operator.
    It can be explicitly specified by adding static attribute
    `op_schema` of type :class:`OpSchema
    <onnx.onnx_cpp2py_export.defs.OpSchema>`.

    ::

        from onnx.reference.op_run.op_conv import Conv as _Conv

        class Conv(_Conv):
            op_schema = instance_of_OpSchema()
            def _run(self, ...):
                ...

    An operator may be different in a later opset. In that case,
    a new implementation needs to be registered. `Pad_11`, `Pad_18`.
    `Pad_11` is the implementation chosen for opsets in [11, 17].
    `Pad_18` is selected for any greater opset. Both classes must be
    imported into file `_op_list.py` to register their existence to the
    runtime.

    An operator may have a reference implementation such as `CastLike`
    and still be defined as a function. By default, the reference implementation
    is used. This behaviour can be changed by adding a class to the list
    of overwritten operators. It must inherit from :class:`OpRunExpand`.

    ::

        from onnx.reference.op_run import OpRunExpand

        class CastLike(OpRunExpand):
            op_domain = ""

        ref = ReferenceEvaluator(model, new_ops=[CastLike])
        # ...

    This mechanism is used in unit tests to check the function
    implementation a schema may define.
    """
    def __init__(  # type: ignore
        self,
        proto: Any,
        opsets: Optional[Dict[str, int]] = None,
        functions: Optional[List[Union["ReferenceEvaluator", FunctionProto]]] = None,  # type: ignore
        verbose: int = 0,
        new_ops: Optional[List[OpRun]] = None,
        optimized: bool = True,
    ):
        """Normalizes *proto* into a graph/node description and loads the kernels.

        See the class docstring for the meaning of every argument.
        Raises ``ValueError`` when *opsets*/*functions* are given together
        with a proto type that already carries them, and ``TypeError``
        for an unsupported *proto* type.
        """
        if optimized:
            # Prepend the optimized kernels unless the caller already
            # supplied an implementation with the same identity.
            if new_ops is None:
                new_ops = optimized_operators.copy()
            else:
                set_new_ops = set(new_ops)
                for op in optimized_operators:
                    if op not in set_new_ops:
                        new_ops.append(op)
        self.output_types_ = None
        self.input_types_ = None
        if isinstance(proto, ModelContainer):
            # Large models: keep the container to resolve in-memory
            # external initializers later, work on its inner ModelProto.
            self.container_ = proto
            proto = self.container_.model_proto
        else:
            self.container_ = None
            # A filename or raw bytes are deserialized into a ModelProto.
            if isinstance(proto, str):
                with open(proto, "rb") as f:
                    proto = load(f)
            elif isinstance(proto, bytes):
                proto = load(BytesIO(proto))
        self.proto_ = proto
        self.functions_: Dict[Tuple[str, str], ReferenceEvaluator] = {}
        self.attributes_: List[str] = []
        if isinstance(proto, ModelProto):
            # Opsets and functions come from the model itself.
            self.onnx_graph_ = proto.graph
            self.opsets_ = {d.domain: d.version for d in proto.opset_import}
            if opsets is not None:
                raise ValueError("opsets must be None if proto is ModelProto.")
            if functions is not None:
                raise ValueError("functions must be None if proto is ModelProto.")
            functions = proto.functions  # type: ignore[assignment]
        elif isinstance(proto, GraphProto):
            # A bare graph carries no opset information: the caller must
            # provide it.
            self.onnx_graph_ = proto
            if not isinstance(opsets, dict):
                raise ValueError("opsets must be a dictionary if proto is GraphProto.")
            self.opsets_ = opsets
        elif isinstance(proto, FunctionProto):
            self.onnx_graph_ = None  # type: ignore
            self.opsets_ = {d.domain: d.version for d in proto.opset_import}
            if opsets is not None:
                raise ValueError("opsets must be None if proto is FunctionProto.")
            # Attribute names the function expects to receive at call time.
            self.attributes_ = list(proto.attribute)
        elif isinstance(proto, NodeProto):
            # A single node: assume the latest opset for the default
            # domain, version 1 otherwise.
            self.onnx_graph_ = None  # type: ignore
            self.opsets_ = {
                proto.domain: 1 if proto.domain != "" else onnx_opset_version()
            }
        else:
            raise TypeError(f"Unexpected type {type(proto)} for proto.")
        if self.onnx_graph_:
            self.input_names_ = [i.name for i in self.onnx_graph_.input]
            self.input_types_ = [i.type for i in self.onnx_graph_.input]
            self.output_names_ = [o.name for o in self.onnx_graph_.output]
            self.output_types_ = [i.type for i in self.onnx_graph_.output]
            self.inits_ = list(self.onnx_graph_.initializer) + list(
                self.onnx_graph_.sparse_initializer  # type: ignore
            )
            self.nodes_ = self.onnx_graph_.node
            # Collect every known result type (inputs + shape inference
            # results stored in value_info) for context dependent kernels.
            all_types = {i.name: i.type for i in self.onnx_graph_.input}
            if hasattr(self.proto_, "value_info"):
                for shape_type in self.proto_.value_info:
                    all_types[shape_type.name] = shape_type.type
            self.all_types_ = all_types
        else:
            # FunctionProto / NodeProto: names only, no typed graph.
            self.input_names_ = list(proto.input)
            self.output_names_ = list(proto.output)
            self.inits_ = []
            if isinstance(proto, NodeProto):
                self.nodes_ = [proto]  # type: ignore[assignment]
            else:
                self.nodes_ = proto.node
        if functions is not None:
            # Each FunctionProto is wrapped into its own evaluator; it can
            # itself call the functions registered before it.
            for f in functions:  # type: ignore
                if isinstance(f, FunctionProto):
                    self.functions_[f.domain, f.name] = self.__class__(
                        f, verbose=verbose, functions=list(self.functions_.values())
                    )
                elif isinstance(f, ReferenceEvaluator):
                    onx = f.proto_  # type: ignore
                    self.functions_[onx.domain, onx.name] = f
                else:
                    raise TypeError(f"Unexpected type {type(f)!r} for a function.")
        self.verbose = verbose
        self.new_ops_: Dict[Tuple[str, str], OpRun] = {}
        if new_ops is not None:
            # Index the custom kernels by (domain, class name).
            for cl in new_ops:
                if not hasattr(cl, "op_domain"):
                    raise AttributeError(
                        f"Class {cl} must define attribute 'op_domain'."
                    )
                if not issubclass(cl, OpRun):  # type: ignore
                    raise TypeError(f"Class {cl} must inherit from OpRun (in new_ops).")
                key = cl.op_domain, cl.__name__  # type: ignore
                if key in self.new_ops_:
                    # Already an implementation, the first one is used.
                    continue
                self.new_ops_[key] = cl
        self._init()
def retrieve_external_data(self, initializer: TensorProto) -> np.array: | |
"""Returns a tensor saved as external.""" | |
info = ExternalDataInfo(initializer) | |
location = info.location | |
if self.container_ and self.container_.is_in_memory_external_initializer( | |
location | |
): | |
# It comes from a large container. | |
return self.container_[location] | |
# Otherwise, the data is on disk. | |
if self.container_ is not None: | |
raise RuntimeError( | |
"ReferenceEvaluator assumes a LargeContainer was loaded with its external tensor." | |
) | |
raise RuntimeError( | |
"An instance of LargeContainer should be created before using ReferenceEvaluator." | |
) | |
def _log_arg(self, a: Any) -> Any: | |
if isinstance(a, (str, int, float)): | |
return a | |
if isinstance(a, np.ndarray): | |
if self.verbose < 4: # noqa: PLR2004 | |
return f"{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]" | |
elements = a.ravel().tolist() | |
if len(elements) > 5: # noqa: PLR2004 | |
elements = elements[:5] | |
return f"{a.dtype}:{a.shape}:{','.join(map(str, elements))}..." | |
return f"{a.dtype}:{a.shape}:{elements}" | |
if hasattr(a, "append"): | |
return ", ".join(map(self._log_arg, a)) | |
return a | |
def _log(self, level: int, pattern: str, *args: List[Any]) -> None: | |
if level < self.verbose: | |
new_args = [self._log_arg(a) for a in args] | |
print(pattern % tuple(new_args)) | |
def input_names(self): # type: ignore | |
"""Returns the input names.""" | |
return self.input_names_ | |
def input_types(self): # type: ignore | |
"""Returns the input types if any specified.""" | |
return self.input_types_ | |
def output_names(self): # type: ignore | |
"""Returns the output names.""" | |
return self.output_names_ | |
def output_types(self): # type: ignore | |
"""Returns the output types.""" | |
return self.output_types_ | |
def opsets(self): # type: ignore | |
"""Returns the opsets.""" | |
return self.opsets_ | |
def has_linked_attribute(self): | |
"""Checks if the graph has a linked attribute (= an attribute whose value is defined | |
by a function attribute. | |
""" | |
return any(node.has_linked_attribute for node in self.rt_nodes_) | |
def __str__(self) -> str: | |
return f"{self.__class__.__name__}({', '.join(self.input_names)}) -> {', '.join(self.output_names)}" | |
def get_result_types(self, name: str, exc: bool = True) -> Any: | |
if self.all_types_ is None: | |
raise RuntimeError( | |
f"Unable to return type for name {name!r}. Run shape_inference first." | |
) | |
if name not in self.all_types_: | |
if exc: | |
raise RuntimeError( | |
f"Unable to return type for name {name!r}, it was not found in {sorted(self.all_types_)}." | |
) | |
return None | |
return self.all_types_[name] | |
    def _init(self) -> None:
        """Loads the implementation for every node in the graph.

        Converts every initializer into a numpy value (resolving external
        data through the container when needed) and instantiates one
        runtime kernel per node into ``self.rt_nodes_``.
        """
        self.rt_inits_ = {}
        self.rt_nodes_ = []
        for init in self.inits_:
            self.rt_inits_[init.name] = (
                self.retrieve_external_data(init)
                if uses_external_data(init)
                else to_array_extended(init)
            )
        # Shared context handed to every kernel at construction time.
        run_params = {
            "log": lambda pattern, *args: self._log(10, pattern, *args),
            "opsets": self.opsets,
            "verbose": self.verbose,
            "new_ops": self.new_ops_,
            "existing_functions": self.functions_.copy(),
            "evaluator_cls": self.__class__,
        }
        if self.input_types_:
            # Rebuild the name -> type map from the graph inputs and any
            # shape inference results stored in value_info.
            all_types = {i.name: i.type for i in self.onnx_graph_.input}
            if hasattr(self.proto_, "value_info"):
                for shape_type in self.proto_.value_info:
                    all_types[shape_type.name] = shape_type.type
            self.all_types_ = all_types
        else:
            self.all_types_ = None  # type: ignore
        for node in self.nodes_:
            try:
                cl = self._load_impl(node)
            except RuntimeContextError as e:
                # A node has a context dependent implementation.
                # Shape inference must be run to get the input types.
                if self.all_types_:
                    it = [self.get_result_types(i, exc=False) for i in node.input]
                    if None in it:
                        # One input does not exist. It must be done while executing the graph.
                        cl = lambda *args, parent=self: OpFunctionContextDependant(  # noqa: E731
                            *args, parent=parent
                        )
                    else:
                        cl = self._load_impl(node, it)  # type: ignore
                else:
                    raise RuntimeContextError(
                        f"No implementation was found for node type {node.op_type!r} from domain {node.domain!r}. "
                        f"If this node has a context dependent implementation, you should run function infer_shapes "
                        f"before calling ReferenceEvaluator."
                    ) from e
            try:
                inst = cl(node, run_params)
            except TypeError as e:
                raise TypeError(
                    f"Unable to instantiate class {cl!r} with "
                    f"run_params={run_params} and node={node}."
                ) from e
            self.rt_nodes_.append(inst)
    def _load_impl(  # noqa: PLR0911
        self, node: NodeProto, input_types: Optional[TypeProto] = None
    ) -> Any:
        """Loads the implementation for a specified runtime.

        Resolution order: custom kernels registered through *new_ops*,
        then the built-in operators of the node's domain, then the known
        functions. Raises ``RuntimeError`` for an unknown domain and
        ``NotImplementedError`` when no implementation can be found.
        """
        if node.domain not in self.opsets:
            raise RuntimeError(
                f"Domain {node.domain!r} (node type: {node.op_type!r}) "
                f"is not specified. Known opsets: {self.opsets!r}."
            )
        version = self.opsets[node.domain]
        key = node.domain, node.op_type
        expand = False
        if key in self.new_ops_:
            # This operator has a custom implementation.
            # This mechanism can be used to implement a custom onnx node
            # or to overwrite an existing one.
            cl = self.new_ops_[key]
            if not issubclass(cl, OpRunExpand):
                return cl
            # It must be replaced by its implementation defined in its schema.
            expand = True
        if node.domain == "":
            # Main opset: local import avoids an import cycle with the
            # ops package.
            from onnx.reference.ops import load_op

            try:
                return load_op(
                    node.domain,
                    node.op_type,
                    version,
                    expand=expand,
                    evaluator_cls=self.__class__,
                )
            except RuntimeContextError:
                # Context dependent operator: retry with the input types
                # if the caller provided them, propagate otherwise.
                if input_types is None:
                    raise
                return load_op(
                    node.domain,
                    node.op_type,
                    version,
                    node=node,
                    input_types=input_types,  # type: ignore[arg-type]
                    expand=expand,
                    evaluator_cls=self.__class__,
                )
        if expand:
            raise NotImplementedError(
                f"Expanding an operator with its function definition "
                f"is only implemented for the main opset. Remove operator "
                f"{node.domain},{node.op_type} from the list of inlined operator."
            )
        if node.domain == "ai.onnx.preview.training":
            from onnx.reference.ops.aionnx_preview_training import load_op as load_op_pt

            return load_op_pt(
                node.domain, node.op_type, version, evaluator_cls=self.__class__
            )
        if node.domain == "experimental":
            from onnx.reference.ops.experimental import load_op as load_op_exp

            return load_op_exp(
                node.domain, node.op_type, version, evaluator_cls=self.__class__
            )
        if node.domain == "ai.onnx.ml":
            from onnx.reference.ops.aionnxml import load_op as load_op_ml

            return load_op_ml(
                node.domain, node.op_type, version, evaluator_cls=self.__class__
            )
        # It has to be a function.
        if key in self.functions_:
            from onnx.reference.ops import load_op

            impl = self.functions_[key]
            return load_op(
                node.domain,
                node.op_type,
                version,
                custom=impl,
                evaluator_cls=self.__class__,
            )
        raise NotImplementedError(
            f"Node type {node.op_type!r} from domain {node.domain!r} "
            f"is unknown, known functions: {sorted(self.functions_)}."
        )
def run(self, output_names, feed_inputs: Dict[str, Any], attributes: Optional[Dict[str, Any]] = None): # type: ignore | |
"""Executes the onnx model. | |
Args: | |
output_names: requested outputs by names, None for all | |
feed_inputs: dictionary `{ input name: input value }` | |
attributes: attributes value if the instance runs a | |
FunctionProto | |
Returns: | |
list of requested outputs | |
""" | |
if output_names is None: | |
output_names = self.output_names | |
if isinstance(self.proto_, FunctionProto) and attributes is None: | |
raise TypeError() | |
# step 1: inputs and initializers | |
results = {"": None} # optional input | |
results.update(self.rt_inits_) # type: ignore[arg-type] | |
results.update(feed_inputs) | |
for k, v in self.rt_inits_.items(): | |
self._log(2, " +C %s: %s", k, v) # type: ignore[arg-type] | |
for k, v in feed_inputs.items(): | |
self._log(2, " +I %s: %s", k, v) # type: ignore[arg-type] | |
# step 2: execute nodes | |
for node in self.rt_nodes_: | |
self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output) | |
for i in node.input: | |
if i not in results: | |
raise RuntimeError( | |
f"Unable to find input {i!r} in known results {sorted(results)}, " | |
f"self.rt_inits_ has {sorted(self.rt_inits_)}, " | |
f"feed_inputs has {sorted(feed_inputs)}." | |
) | |
inputs = [results[i] for i in node.input] | |
linked_attributes = {} | |
if node.has_linked_attribute and attributes: | |
linked_attributes["linked_attributes"] = attributes | |
if node.need_context(): | |
outputs = node.run(*inputs, context=results, **linked_attributes) | |
else: | |
outputs = node.run(*inputs, **linked_attributes) | |
for name, value in zip(node.output, outputs): | |
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type] | |
results[name] = value | |
# return the results | |
for name in output_names: | |
if name not in results: | |
raise RuntimeError( | |
f"Unable to find output name {name!r} in {sorted(results)}, proto is\n{self.proto_}" | |
) | |
return [results[name] for name in output_names] | |