Spaces:
Sleeping
Sleeping
File size: 5,477 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from onnx.reference.custom_element_types import (
bfloat16,
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
float8e5m2fnuz,
int4,
uint4,
)
from onnx.reference.op_run import OpRun, RefAttrName
def _check_dtype(val):  # type: ignore
    """Check that ``val.dtype`` is acceptable for operator 'Constant'.

    Any ``numpy.dtype`` instance is accepted as is. A dtype that is not a
    ``numpy.dtype`` instance must belong to the explicit allow-list below
    (custom ONNX element types, numpy scalar types, ``bool`` and ``str``).

    Raises:
        TypeError: if the dtype is neither a ``numpy.dtype`` instance nor a
            member of the allow-list.
    """
    a = val.dtype

    # Fast path: genuine numpy dtypes never need the allow-list lookup.
    if isinstance(a, np.dtype):
        return

    accepted = {
        # custom element types
        bfloat16,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
        uint4,
        int4,
        # numpy scalar types
        np.int8,
        np.uint8,
        np.float16,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
        np.int16,
        np.uint16,
        np.uint32,
        np.bool_,
        np.str_,
        np.uint64,
        # builtin types
        bool,
        str,
    }
    if a in accepted:
        return

    raise TypeError(
        f"Type ({a}, {type(a)}) is not a numpy type (operator 'Constant')"
    )
class ConstantCommon(OpRun):
    """Shared base class for the versioned 'Constant' implementations."""

    def _check(self, cst):  # type: ignore
        """Return ``cst`` unchanged unless it is a tuple.

        Raises:
            TypeError: if ``cst`` is a tuple.
        """
        if not isinstance(cst, tuple):
            return cst
        raise TypeError(f"Unexpected type {type(cst)} for a constant.")
class Constant_1(ConstantCommon):
    """'Constant' implementation whose payload is the ``value`` attribute."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Store the ``value`` attribute as the constant and validate its dtype."""
        super().__init__(onnx_node, run_params)
        self.cst = self.value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        The only override tolerated is a single ``value`` entry that is the
        very same object as ``self.value``; anything else is rejected.

        Raises:
            RuntimeError: if any other attribute override is supplied.
        """
        if overridden_attributes:
            only_value = (
                len(overridden_attributes) == 1 and "value" in overridden_attributes
            )
            # Identity (not equality) is required: the override must be the
            # object already held by this node.
            if not only_value or overridden_attributes["value"] is not self.value:
                raise RuntimeError(
                    "Function attributes are not implemented for opset <= 11. Use opset > 12."
                )
        checked = self._check(self.cst)
        return (checked,)
class Constant_9(Constant_1):
    """'Constant' variant that behaves exactly like :class:`Constant_1`."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Delegate construction entirely to :class:`Constant_1`."""
        super().__init__(onnx_node, run_params)
class Constant_11(ConstantCommon):
    """'Constant' implementation reading ``sparse_value`` or ``value``."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Pick the constant payload and validate its dtype.

        ``sparse_value`` is used when it exists and is not ``None``;
        otherwise the ``value`` attribute is used.
        """
        super().__init__(onnx_node, run_params)
        sparse = getattr(self, "sparse_value", None)
        self.cst = self.value if sparse is None else sparse  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        The only override tolerated is a single ``value`` entry that is the
        very same object as ``self.value``; anything else is rejected.

        Raises:
            RuntimeError: if any other attribute override is supplied.
        """
        if overridden_attributes:
            only_value = (
                len(overridden_attributes) == 1 and "value" in overridden_attributes
            )
            # Identity (not equality) is required: the override must be the
            # object already held by this node.
            if not only_value or overridden_attributes["value"] is not self.value:
                raise RuntimeError(
                    "Function attributes are not implemented for opset <= 11. Use opset > 12."
                )
        checked = self._check(self.cst)
        return (checked,)
class Constant_12(ConstantCommon):
    """'Constant' implementation accepting tensor and scalar/list attributes.

    Exactly one source attribute is selected at construction time, in this
    priority order: ``sparse_value``, ``value``, then the first non-``None``
    among ``value_float``, ``value_floats``, ``value_int``, ``value_ints``,
    ``value_string`` and ``value_strings``.

    Instance attributes set by ``__init__``:
        cst_name: name of the attribute the constant was taken from.
        cst: the constant itself, or the ``RefAttrName`` placeholder when a
            typed attribute holds one.
        cst_convert: callable turning a raw attribute value into the value
            returned by ``_run`` (identity for ``sparse_value``/``value``,
            ``np.array(..., dtype=...)`` for the typed attributes).
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Select the attribute that carries the constant.

        Raises:
            AttributeError: if none of the supported attributes is set.
        """
        ConstantCommon.__init__(self, onnx_node, run_params)
        if hasattr(self, "sparse_value") and self.sparse_value is not None:  # type: ignore
            self.cst_name = "sparse_value"
            self.cst = self.sparse_value  # type: ignore
            self.cst_convert = lambda v: v
        elif hasattr(self, "value") and self.value is not None:  # type: ignore
            self.cst_name = "value"  # type: ignore
            # Stored as is, whether it is a RefAttrName placeholder or an
            # actual value.
            self.cst = self.value  # type: ignore
            self.cst_convert = lambda v: v
        else:
            for attr, np_dtype in {
                "value_float": np.float32,
                "value_floats": np.float32,
                "value_int": np.int64,
                "value_ints": np.int64,
                "value_string": np.str_,
                "value_strings": np.str_,
            }.items():
                if hasattr(self, attr) and getattr(self, attr) is not None:  # type: ignore
                    self.cst_name = attr
                    v = getattr(self, attr)
                    # A RefAttrName placeholder is kept untouched; a concrete
                    # value is converted to an array of the matching dtype.
                    self.cst = (
                        v  # type: ignore
                        if isinstance(v, RefAttrName)  # type: ignore
                        else np.array(v, dtype=np_dtype)  # type: ignore
                    )
                    # np_dtype is bound as a default argument so the lambda
                    # keeps this iteration's dtype (no late binding).
                    self.cst_convert = lambda v, np_dtype=np_dtype: np.array(  # type: ignore
                        v, dtype=np_dtype
                    )
                    break
        if not hasattr(self, "cst_name"):
            raise AttributeError(
                f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
            )

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the constant as a one-element tuple.

        When ``self.has_linked_attribute`` is true, the value is read from
        ``overridden_attributes[self.cst_name]``: a ``numpy.ndarray`` is
        returned as is, anything else goes through ``self.cst_convert``.
        Otherwise the constant stored at construction time is returned.

        Raises:
            RuntimeError: if a linked attribute is expected but no attributes
                were supplied, or the expected one is missing.
        """
        if self.has_linked_attribute:
            # ``**overridden_attributes`` is always a dict, never None, so
            # emptiness is the condition that can actually occur.
            if not overridden_attributes:
                raise RuntimeError(
                    f"Attributes are empty, cannot retrieve value for {self.cst!r}."
                )
            if self.cst_name not in overridden_attributes:
                raise RuntimeError(
                    f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
                )
            value = overridden_attributes[self.cst_name]
            if isinstance(value, np.ndarray):
                return (value,)
            return (self.cst_convert(value),)
        return (self._check(self.cst),)
|