Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
import numpy as np | |
from onnx.reference.custom_element_types import ( | |
bfloat16, | |
float8e4m3fn, | |
float8e4m3fnuz, | |
float8e5m2, | |
float8e5m2fnuz, | |
int4, | |
uint4, | |
) | |
from onnx.reference.op_run import OpRun, RefAttrName | |
def _check_dtype(val): # type: ignore | |
a = val.dtype | |
if not isinstance(a, np.dtype) and a not in { | |
bfloat16, | |
float8e4m3fn, | |
float8e4m3fnuz, | |
float8e5m2, | |
float8e5m2fnuz, | |
uint4, | |
int4, | |
np.int8, | |
np.uint8, | |
np.float16, | |
np.float32, | |
np.float64, | |
np.int32, | |
np.int64, | |
np.int16, | |
np.uint16, | |
np.uint32, | |
np.bool_, | |
np.str_, | |
np.uint64, | |
bool, | |
str, | |
}: | |
raise TypeError( | |
f"Type ({a}, {type(a)}) is not a numpy type (operator 'Constant')" | |
) | |
class ConstantCommon(OpRun):
    """Base class shared by the ``Constant_*`` implementations in this file."""

    def _check(self, cst):  # type: ignore
        """Return ``cst`` unchanged, refusing tuple values.

        NOTE(review): tuples are presumably rejected because ``_run`` results
        are themselves tuples of outputs, so a tuple constant would be
        ambiguous — confirm.

        Raises:
            TypeError: if ``cst`` is a tuple.
        """
        if not isinstance(cst, tuple):
            return cst
        raise TypeError(f"Unexpected type {type(cst)} for a constant.")
class Constant_1(ConstantCommon):
    """``Constant`` operator for opset 1: outputs the tensor stored in ``value``.

    Run-time attribute overrides (function attributes) are not supported; the
    only override tolerated is this node's own ``value`` object passed back
    unchanged.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        self.cst = self.value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Raises:
            RuntimeError: if any override other than ``value`` set to this
                node's very own ``value`` object is supplied.
        """
        if overridden_attributes:
            # Overrides are tolerated only when they are exactly {"value": self.value}.
            only_own_value = (
                len(overridden_attributes) == 1
                and "value" in overridden_attributes
                and overridden_attributes["value"] is self.value
            )
            if not only_own_value:
                raise RuntimeError(
                    "Function attributes are not implemented for opset <= 11. Use opset > 12."
                )
        return (self._check(self.cst),)
class Constant_9(Constant_1):
    """``Constant`` operator for opset 9; behaves exactly like :class:`Constant_1`."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        super().__init__(onnx_node, run_params)
class Constant_11(ConstantCommon):
    """``Constant`` operator for opset 11.

    Outputs ``sparse_value`` when it is set, otherwise ``value``. Run-time
    attribute overrides (function attributes) are not supported; the only
    override tolerated is this node's own ``value`` object passed back
    unchanged.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is None:
            self.cst = self.value  # type: ignore
        else:
            self.cst = self.sparse_value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Raises:
            RuntimeError: if any override other than ``value`` set to this
                node's very own ``value`` object is supplied.
        """
        # ``value`` may be absent when the node only carries ``sparse_value``
        # (``__init__`` already treats attributes as optional), so read it
        # defensively to raise the intended RuntimeError, not AttributeError.
        own_value = getattr(self, "value", None)
        if overridden_attributes and (
            len(overridden_attributes) > 1
            or "value" not in overridden_attributes
            or overridden_attributes["value"] is not own_value
        ):
            raise RuntimeError(
                "Function attributes are not implemented for opset <= 11. Use opset > 12."
            )
        return (self._check(self.cst),)
class Constant_12(ConstantCommon):
    """``Constant`` operator for opset 12 and later.

    The constant comes from the first attribute that is set, in this order:
    ``sparse_value``, ``value``, then ``value_float``, ``value_floats``,
    ``value_int``, ``value_ints``, ``value_string``, ``value_strings``. The
    typed ``value_*`` attributes are converted to numpy arrays of float32,
    int64 or str respectively.

    When the node uses linked attributes (``self.has_linked_attribute``,
    provided by the ``OpRun`` base class), the value is resolved at run time
    from the attributes passed to :meth:`_run`.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is not None:
            self.cst_name = "sparse_value"
            self.cst = self.sparse_value  # type: ignore
            self.cst_convert = lambda v: v
        elif getattr(self, "value", None) is not None:
            self.cst_name = "value"  # type: ignore
            # Stored as-is, whether a tensor or a RefAttrName placeholder for a
            # linked attribute (the latter is resolved in _run).
            self.cst = self.value  # type: ignore
            self.cst_convert = lambda v: v
        else:
            for attr, np_dtype in {
                "value_float": np.float32,
                "value_floats": np.float32,
                "value_int": np.int64,
                "value_ints": np.int64,
                "value_string": np.str_,
                "value_strings": np.str_,
            }.items():
                v = getattr(self, attr, None)
                if v is not None:  # type: ignore
                    self.cst_name = attr
                    self.cst = (
                        v  # type: ignore
                        if isinstance(v, RefAttrName)  # type: ignore
                        else np.array(v, dtype=np_dtype)  # type: ignore
                    )
                    # Bind np_dtype as a default so the lambda keeps this
                    # iteration's dtype rather than the loop's last one.
                    self.cst_convert = lambda v, np_dtype=np_dtype: np.array(  # type: ignore
                        v, dtype=np_dtype
                    )
                    break
        if not hasattr(self, "cst_name"):
            raise AttributeError(
                f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
            )

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the constant as a one-element tuple.

        Without linked attributes, the value captured in ``__init__`` is
        returned. With linked attributes, the value is looked up in
        ``overridden_attributes[self.cst_name]``; numpy arrays are returned
        as-is, anything else goes through ``self.cst_convert``.

        Raises:
            RuntimeError: if linked attributes are in use but no overrides, or
                none named ``self.cst_name``, were supplied.
        """
        if not self.has_linked_attribute:
            return (self._check(self.cst),)
        # ``**overridden_attributes`` is always a dict (never None), so the
        # "no attributes" case must be detected by emptiness.
        if not overridden_attributes:
            raise RuntimeError(
                f"Attributes are empty, cannot retrieve value for {self.cst!r}."
            )
        if self.cst_name not in overridden_attributes:
            raise RuntimeError(
                f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
            )
        value = overridden_attributes[self.cst_name]
        if isinstance(value, np.ndarray):
            return (value,)
        return (self.cst_convert(value),)