# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import numpy as np
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
from onnx.reference.custom_element_types import (
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_quantize_linear import reshape_input
class _CommonDequantizeLinear(OpRun):
    """Shared implementation of DequantizeLinear across opset versions.

    Computes ``y = (x - x_zero_point) * x_scale`` in float32, then casts the
    result back to the dtype of ``x_scale``. Float8 inputs are decoded with the
    dedicated converters and must have a zero (or absent) zero point.
    """

    def get_x_type(self, x: np.ndarray) -> int:
        """Return the TensorProto dtype enum for *x*.

        The custom element types (float8 variants, int4/uint4) are numpy
        structured dtypes distinguished by the field name in ``dtype.descr``;
        anything else is resolved through ``np_dtype_to_tensor_dtype``.
        """
        custom_types = (
            (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
            (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
            (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
            (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
            (uint4, "uint4", TensorProto.UINT4),
            (int4, "int4", TensorProto.INT4),
        )
        for dtype, tag, proto_type in custom_types:
            # Both checks are needed: several custom dtypes share the same
            # underlying storage type, so the field tag disambiguates them.
            if x.dtype == dtype and x.dtype.descr[0][0] == tag:
                return proto_type
        return np_dtype_to_tensor_dtype(x.dtype)

    def _run(
        self,
        x: np.ndarray,
        x_scale: np.ndarray,
        x_zero_point: Optional[np.ndarray] = None,
        axis: Optional[int] = None,
        block_size: Optional[int] = None,
    ):  # type: ignore
        x_type = self.get_x_type(x)
        # Dispatch table mapping each float8 variant to its decoder call.
        fp8_converters = {
            TensorProto.FLOAT8E4M3FN: lambda v: float8e4m3_to_float32(v),
            TensorProto.FLOAT8E4M3FNUZ: lambda v: float8e4m3_to_float32(v, uz=True),
            TensorProto.FLOAT8E5M2: lambda v: float8e5m2_to_float32(v),
            TensorProto.FLOAT8E5M2FNUZ: lambda v: float8e5m2_to_float32(
                v, fn=True, uz=True
            ),
        }
        if x_type in fp8_converters:
            if x_zero_point is not None:
                # Float8 dequantization has no zero-point shift; a provided
                # zero point must be uniformly zero.
                zp_bytes = x_zero_point.astype(np.uint8)
                if zp_bytes.min() != zp_bytes.max() or zp_bytes.min() != np.uint8(0):
                    raise ValueError(
                        "x_zero_point is not null but should be zero for float8 types."
                    )
            dx = fp8_converters[x_type](x)
        elif x_zero_point is not None:
            zero_type = self.get_x_type(x_zero_point)
            if x_type != zero_type:
                raise ValueError(
                    f"Type mismatch {x_type} != {zero_type} in DequantizeLinear."
                )
            dx = x.astype(np.float32) - reshape_input(
                x_zero_point, x.shape, axis, block_size
            )
        else:
            dx = x.astype(np.float32)
        # reshape_input broadcasts the scale to x's shape for per-axis /
        # block-wise quantization; output takes the scale's dtype.
        y = dx * reshape_input(x_scale, x.shape, axis, block_size)
        return (y.astype(x_scale.dtype),)
class DequantizeLinear_19(_CommonDequantizeLinear):
    """DequantizeLinear for opset 19: per-tensor or per-axis scales only."""

    def _run(self, x, x_scale, x_zero_point=None, axis=None):
        # Opset 19 predates block-wise quantization, so the scale must be a
        # scalar or a 1-D vector.
        if x_scale.ndim > 1:
            raise ValueError("Input 2 must be a vector or a number.")
        return super()._run(x, x_scale, x_zero_point, axis)
class DequantizeLinear_21(_CommonDequantizeLinear):
    """DequantizeLinear for opset 21: adds block-wise dequantization."""

    def _run(self, *inputs, axis=None, block_size=None):  # type: ignore
        # inputs are (x, y_scale[, zero_point]); forwarded untouched to the
        # shared implementation along with the opset-21 attributes.
        return super()._run(*inputs, axis=axis, block_size=block_size)  # type: ignore