Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
5.48 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from onnx.reference.custom_element_types import (
bfloat16,
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
float8e5m2fnuz,
int4,
uint4,
)
from onnx.reference.op_run import OpRun, RefAttrName
def _check_dtype(val):  # type: ignore
    """Validate the ``dtype`` of a value used by the 'Constant' operator.

    The value is accepted when ``val.dtype`` is an instance of ``np.dtype``
    or when it is one of the explicitly listed types (the custom element
    types imported at the top of this module, NumPy scalar types, ``bool``
    and ``str``).

    Args:
        val: Any object exposing a ``dtype`` attribute.

    Raises:
        TypeError: If ``val.dtype`` is neither an ``np.dtype`` instance nor
            one of the listed types.
    """
    dtype = val.dtype

    # Common case: a regular NumPy dtype needs no further inspection.
    if isinstance(dtype, np.dtype):
        return

    # Only built when the fast path above did not apply.
    accepted = {
        bfloat16,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
        uint4,
        int4,
        np.int8,
        np.uint8,
        np.float16,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
        np.int16,
        np.uint16,
        np.uint32,
        np.bool_,
        np.str_,
        np.uint64,
        bool,
        str,
    }
    if dtype in accepted:
        return

    raise TypeError(
        f"Type ({dtype}, {type(dtype)}) is not a numpy type (operator 'Constant')"
    )
class ConstantCommon(OpRun):
    """Base class shared by the versioned 'Constant' operator classes."""

    def _check(self, cst):  # type: ignore
        """Return *cst* unchanged, rejecting tuples.

        Args:
            cst: The constant value about to be returned by ``_run``.

        Returns:
            The same object that was passed in.

        Raises:
            TypeError: If *cst* is a ``tuple``.
        """
        if not isinstance(cst, tuple):
            return cst
        raise TypeError(f"Unexpected type {type(cst)} for a constant.")
class Constant_1(ConstantCommon):
    """'Constant' operator whose output is taken from the ``value`` attribute."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Store ``self.value`` as the constant and validate its dtype.

        Args:
            onnx_node: Forwarded unchanged to the base-class constructor.
            run_params: Forwarded unchanged to the base-class constructor.

        Raises:
            TypeError: Propagated from ``_check_dtype`` when the dtype of the
                stored value is not accepted.
        """
        super().__init__(onnx_node, run_params)
        # ``self.value`` is presumably populated by the base constructor
        # from the node attributes -- verify against OpRun.
        self.cst = self.value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Args:
            **overridden_attributes: Either empty, or exactly
                ``{"value": <the very object held in self.value>}``.

        Returns:
            A 1-tuple holding the constant.

        Raises:
            RuntimeError: If any other override is supplied.
            TypeError: Propagated from ``_check`` if the constant is a tuple.
        """
        if overridden_attributes:
            # The only tolerated override is the identical ``value`` object.
            harmless = (
                len(overridden_attributes) <= 1
                and "value" in overridden_attributes
                and overridden_attributes["value"] is self.value
            )
            if not harmless:
                raise RuntimeError(
                    "Function attributes are not implemented for opset <= 11. Use opset > 12."
                )
        return (self._check(self.cst),)
class Constant_9(Constant_1):
    """'Constant' operator variant that reuses the ``Constant_1`` behavior."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Delegate construction entirely to ``Constant_1``.

        Args:
            onnx_node: Forwarded unchanged to ``Constant_1.__init__``.
            run_params: Forwarded unchanged to ``Constant_1.__init__``.
        """
        super().__init__(onnx_node, run_params)
class Constant_11(ConstantCommon):
    """'Constant' operator reading ``sparse_value`` or, failing that, ``value``."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Pick the constant from ``sparse_value``/``value`` and validate it.

        ``sparse_value`` wins when it exists and is not ``None``; otherwise
        ``value`` is used.

        Args:
            onnx_node: Forwarded unchanged to the base-class constructor.
            run_params: Forwarded unchanged to the base-class constructor.

        Raises:
            TypeError: Propagated from ``_check_dtype`` when the dtype of the
                selected constant is not accepted.
        """
        super().__init__(onnx_node, run_params)
        sparse = getattr(self, "sparse_value", None)
        self.cst = self.value if sparse is None else sparse  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Args:
            **overridden_attributes: Either empty, or exactly
                ``{"value": <the very object held in self.value>}``.

        Returns:
            A 1-tuple holding the constant.

        Raises:
            RuntimeError: If any other override is supplied.
            TypeError: Propagated from ``_check`` if the constant is a tuple.
        """
        if overridden_attributes:
            # The only tolerated override is the identical ``value`` object.
            harmless = (
                len(overridden_attributes) <= 1
                and "value" in overridden_attributes
                and overridden_attributes["value"] is self.value
            )
            if not harmless:
                raise RuntimeError(
                    "Function attributes are not implemented for opset <= 11. Use opset > 12."
                )
        return (self._check(self.cst),)
class Constant_12(ConstantCommon):
    """'Constant' operator accepting any one of several value attributes.

    The constant is taken from the first attribute that is present and not
    ``None``, in this order: ``sparse_value``, ``value``, ``value_float``,
    ``value_floats``, ``value_int``, ``value_ints``, ``value_string``,
    ``value_strings``.

    Instance attributes set by ``__init__``:
        cst_name: Name of the attribute the constant was read from.
        cst: The constant itself, or the ``RefAttrName`` placeholder when a
            typed ``value_*`` attribute holds one.
        cst_convert: Callable applied in ``_run`` to an overridden value
            that is not already an ``np.ndarray``.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Select the source attribute and prepare the constant.

        Args:
            onnx_node: Forwarded to the base-class constructor; its
                ``output`` field is used in the error message.
            run_params: Forwarded unchanged to the base-class constructor.

        Raises:
            AttributeError: If none of the supported attributes is set.
        """
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is not None:
            self.cst_name = "sparse_value"
            self.cst = self.sparse_value  # type: ignore
            self.cst_convert = lambda v: v
        elif getattr(self, "value", None) is not None:
            self.cst_name = "value"  # type: ignore
            # Stored as-is whether it is a concrete value or a RefAttrName.
            self.cst = self.value  # type: ignore
            self.cst_convert = lambda v: v
        else:
            # Typed scalar/list attributes and the NumPy dtype each maps to.
            typed_attributes = {
                "value_float": np.float32,
                "value_floats": np.float32,
                "value_int": np.int64,
                "value_ints": np.int64,
                "value_string": np.str_,
                "value_strings": np.str_,
            }
            for attr, np_dtype in typed_attributes.items():
                raw = getattr(self, attr, None)
                if raw is None:
                    continue
                self.cst_name = attr
                # A RefAttrName placeholder is kept untouched; anything else
                # is materialized as an array of the matching dtype.
                self.cst = (
                    raw  # type: ignore
                    if isinstance(raw, RefAttrName)  # type: ignore
                    else np.array(raw, dtype=np_dtype)  # type: ignore
                )
                # Bind np_dtype as a default so the converter keeps the dtype
                # of the attribute that was actually selected.
                self.cst_convert = lambda v, np_dtype=np_dtype: np.array(  # type: ignore
                    v, dtype=np_dtype
                )
                break
        if not hasattr(self, "cst_name"):
            raise AttributeError(
                f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
            )

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the constant as a one-element tuple.

        When ``self.has_linked_attribute`` is true, the value is looked up
        in *overridden_attributes* under ``self.cst_name``; an ``np.ndarray``
        is returned as-is and anything else goes through
        ``self.cst_convert``. Otherwise the constant stored at construction
        time is returned.

        Args:
            **overridden_attributes: Attribute values supplied at run time.

        Returns:
            A 1-tuple holding the constant.

        Raises:
            RuntimeError: If the attribute is linked and no overrides were
                supplied, or ``self.cst_name`` is not among them.
            TypeError: Propagated from ``_check`` if the stored constant is
                a tuple.
        """
        if self.has_linked_attribute:
            # ``**overridden_attributes`` is always a dict and never None,
            # so emptiness is the condition that can actually occur.
            if not overridden_attributes:
                raise RuntimeError(
                    f"Attributes are empty, cannot retrieve value for {self.cst!r}."
                )
            if self.cst_name not in overridden_attributes:
                raise RuntimeError(
                    f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
                )
            value = overridden_attributes[self.cst_name]
            if isinstance(value, np.ndarray):
                return (value,)
            return (self.cst_convert(value),)
        return (self._check(self.cst),)