# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import numpy as np

from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
from onnx.reference.custom_element_types import (
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_quantize_linear import reshape_input
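
# DequantizeLinear computes y = (x - x_zero_point) * x_scale.  Float8 inputs are
# decoded to float32 first and must use a zero point of zero.  Opset 19 accepts a
# scalar or 1-D (per-axis) scale; opset 21 additionally supports int4/uint4 inputs
# and blocked quantization via the `block_size` attribute.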


class _CommonDequantizeLinear(OpRun):
    def get_x_type(self, x: np.ndarray) -> int:
        """Map the numpy dtype of ``x`` to the matching TensorProto element type."""
        tensor_dtype = None
        if x.dtype == float8e4m3fn and x.dtype.descr[0][0] == "e4m3fn":
            tensor_dtype = TensorProto.FLOAT8E4M3FN
        elif x.dtype == float8e4m3fnuz and x.dtype.descr[0][0] == "e4m3fnuz":
            tensor_dtype = TensorProto.FLOAT8E4M3FNUZ
        elif x.dtype == float8e5m2 and x.dtype.descr[0][0] == "e5m2":
            tensor_dtype = TensorProto.FLOAT8E5M2
        elif x.dtype == float8e5m2fnuz and x.dtype.descr[0][0] == "e5m2fnuz":
            tensor_dtype = TensorProto.FLOAT8E5M2FNUZ
        elif x.dtype == uint4 and x.dtype.descr[0][0] == "uint4":
            tensor_dtype = TensorProto.UINT4
        elif x.dtype == int4 and x.dtype.descr[0][0] == "int4":
            tensor_dtype = TensorProto.INT4
        else:
            tensor_dtype = np_dtype_to_tensor_dtype(x.dtype)
        return tensor_dtype

    def _run(
        self,
        x: np.ndarray,
        x_scale: np.ndarray,
        x_zero_point: Optional[np.ndarray] = None,
        axis: Optional[int] = None,
        block_size: Optional[int] = None,
    ):  # type: ignore
        x_type = self.get_x_type(x)
        fp8_type = x_type in {
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        }
        if x_zero_point is not None and not fp8_type:
            # Integer path: subtract the (broadcast) zero point before scaling.
            zero_type = self.get_x_type(x_zero_point)
            if x_type != zero_type:
                raise ValueError(
                    f"Type mismatch {x_type} != {zero_type} in DequantizeLinear."
                )
            dx = x.astype(np.float32) - reshape_input(
                x_zero_point, x.shape, axis, block_size
            )
        else:
            if fp8_type and x_zero_point is not None:
                # Float8 types admit no zero point other than zero.
                u_x_zero_point = x_zero_point.astype(np.uint8)
                umi = u_x_zero_point.min()
                uma = u_x_zero_point.max()
                if umi != uma or umi != np.uint8(0):
                    raise ValueError(
                        "x_zero_point is not null but should be zero for float8 types."
                    )
            # Decode float8 payloads to float32; any other type is cast directly.
            if x_type == TensorProto.FLOAT8E4M3FN:
                dx = float8e4m3_to_float32(x)
            elif x_type == TensorProto.FLOAT8E4M3FNUZ:
                dx = float8e4m3_to_float32(x, uz=True)
            elif x_type == TensorProto.FLOAT8E5M2:
                dx = float8e5m2_to_float32(x)
            elif x_type == TensorProto.FLOAT8E5M2FNUZ:
                dx = float8e5m2_to_float32(x, fn=True, uz=True)
            else:
                dx = x.astype(np.float32)
        # Apply the scale and return the result in the scale's floating-point type.
        y = dx * reshape_input(x_scale, x.shape, axis, block_size)
        return (y.astype(x_scale.dtype),)


class DequantizeLinear_19(_CommonDequantizeLinear):
    def _run(self, x, x_scale, x_zero_point=None, axis=None):
        if len(x_scale.shape) > 1:
            raise ValueError("Input 2 must be a vector or a number.")
        return super()._run(x, x_scale, x_zero_point, axis)


class DequantizeLinear_21(_CommonDequantizeLinear):
    def _run(self, *args, axis=None, block_size=None):  # type: ignore
        # args: x, y_scale, zero_point
        return super()._run(*args, axis=axis, block_size=block_size)  # type: ignore
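

# ---------------------------------------------------------------------------
# Illustrative usage only, not part of the upstream ONNX source: a minimal
# sketch of how this reference implementation is typically exercised through
# onnx.reference.ReferenceEvaluator.  The tensor names ("x", "x_scale",
# "x_zero_point", "y") and the sample values are chosen for the example.
if __name__ == "__main__":
    from onnx import helper
    from onnx.reference import ReferenceEvaluator

    # Single DequantizeLinear node, per-tensor (scalar) scale and zero point.
    node = helper.make_node(
        "DequantizeLinear",
        inputs=["x", "x_scale", "x_zero_point"],
        outputs=["y"],
    )
    graph = helper.make_graph(
        [node],
        "dequantize_example",
        [
            helper.make_tensor_value_info("x", TensorProto.UINT8, [4]),
            helper.make_tensor_value_info("x_scale", TensorProto.FLOAT, []),
            helper.make_tensor_value_info("x_zero_point", TensorProto.UINT8, []),
        ],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])

    sess = ReferenceEvaluator(model)
    x = np.array([0, 128, 200, 255], dtype=np.uint8)
    x_scale = np.array(0.5, dtype=np.float32)
    x_zero_point = np.array(128, dtype=np.uint8)
    # y = (x - x_zero_point) * x_scale -> [-64.0, 0.0, 36.0, 63.5]
    print(sess.run(None, {"x": x, "x_scale": x_scale, "x_zero_point": x_zero_point}))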