# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import numpy as np
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
from onnx.reference.custom_element_types import (
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
float8e5m2fnuz,
int4,
uint4,
)
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_quantize_linear import reshape_input
class _CommonDequantizeLinear(OpRun):
    """Shared implementation of DequantizeLinear used by the opset variants.

    Maps a (possibly custom-typed) quantized tensor back to the scale's
    floating-point dtype: ``y = (x - zero_point) * scale`` for integer
    types, ``y = float32(x) * scale`` for float8 types (whose zero point,
    when given, must be all zeros).
    """

    # Custom structured dtypes paired with the field name that disambiguates
    # them (they are all uint8-backed, so the dtype comparison alone is not
    # enough) and the ONNX tensor type they encode. Order mirrors the
    # original elif chain.
    _CUSTOM_DTYPES = (
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
        (uint4, "uint4", TensorProto.UINT4),
        (int4, "int4", TensorProto.INT4),
    )

    def get_x_type(self, x: np.ndarray) -> int:
        """Return the TensorProto element type of *x*, resolving the custom
        structured dtypes first and falling back to the standard mapping."""
        for custom_dtype, field_name, tensor_type in self._CUSTOM_DTYPES:
            if x.dtype == custom_dtype and x.dtype.descr[0][0] == field_name:
                return tensor_type
        return np_dtype_to_tensor_dtype(x.dtype)

    def _run(
        self,
        x: np.ndarray,
        x_scale: np.ndarray,
        x_zero_point: Optional[np.ndarray] = None,
        axis: Optional[int] = None,
        block_size: Optional[int] = None,
    ):  # type: ignore
        x_type = self.get_x_type(x)
        is_float8 = x_type in (
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        )
        if x_zero_point is None or is_float8:
            if is_float8 and x_zero_point is not None:
                # Float8 types carry no zero point; any non-zero byte is an error.
                as_uint8 = x_zero_point.astype(np.uint8)
                if as_uint8.min() != np.uint8(0) or as_uint8.max() != np.uint8(0):
                    raise ValueError(
                        "x_zero_point is not null but should be zero for float8 types."
                    )
            if x_type == TensorProto.FLOAT8E4M3FN:
                dx = float8e4m3_to_float32(x)
            elif x_type == TensorProto.FLOAT8E4M3FNUZ:
                dx = float8e4m3_to_float32(x, uz=True)
            elif x_type == TensorProto.FLOAT8E5M2:
                dx = float8e5m2_to_float32(x)
            elif x_type == TensorProto.FLOAT8E5M2FNUZ:
                dx = float8e5m2_to_float32(x, fn=True, uz=True)
            else:
                dx = x.astype(np.float32)
        else:
            # Integer path: the zero point's tensor type must match the input's.
            zero_type = self.get_x_type(x_zero_point)
            if x_type != zero_type:
                raise ValueError(
                    f"Type mismatch {x_type} != {zero_type} in DequantizeLinear."
                )
            dx = x.astype(np.float32) - reshape_input(
                x_zero_point, x.shape, axis, block_size
            )
        # Scale broadcasts per-tensor, per-axis, or per-block depending on
        # axis/block_size; the output adopts the scale's floating-point dtype.
        y = dx * reshape_input(x_scale, x.shape, axis, block_size)
        return (y.astype(x_scale.dtype),)
class DequantizeLinear_19(_CommonDequantizeLinear):
    """DequantizeLinear for opset 19: per-tensor or per-axis only."""

    def _run(self, x, x_scale, x_zero_point=None, axis=None):
        # Opset 19 restricts the scale to a scalar or a 1-D vector
        # (block dequantization only arrives in opset 21).
        if x_scale.ndim > 1:
            raise ValueError("Input 2 must be a vector or a number.")
        return super()._run(x, x_scale, x_zero_point, axis)
class DequantizeLinear_21(_CommonDequantizeLinear):
    """DequantizeLinear for opset 21, which adds blocked dequantization."""

    def _run(self, *args, axis=None, block_size=None):  # type: ignore
        # Positional inputs are (x, y_scale) or (x, y_scale, zero_point);
        # forwarded unchanged to the shared implementation.
        return super()._run(*args, axis=axis, block_size=block_size)  # type: ignore