# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from onnx import subbyte
from onnx.helper import (
    float32_to_bfloat16,
    float32_to_float8e4m3,
    float32_to_float8e5m2,
    tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import (
    bfloat16_to_float32,
    float8e4m3_to_float32,
    float8e5m2_to_float32,
)
from onnx.onnx_pb import TensorProto
from onnx.reference.custom_element_types import (
    bfloat16,
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)
from onnx.reference.op_run import OpRun

# Float8 source types, tried in this order:
# (numpy dtype, descr field name, TensorProto type, decoder, decoder kwargs).
_FLOAT8_DECODERS = (
    (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN, float8e4m3_to_float32, {}),
    (
        float8e4m3fnuz,
        "e4m3fnuz",
        TensorProto.FLOAT8E4M3FNUZ,
        float8e4m3_to_float32,
        {"uz": True},
    ),
    (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2, float8e5m2_to_float32, {}),
    (
        float8e5m2fnuz,
        "e5m2fnuz",
        TensorProto.FLOAT8E5M2FNUZ,
        float8e5m2_to_float32,
        {"fn": True, "uz": True},
    ),
)

# 4-bit integer types, tried in this order:
# (numpy dtype, descr field name, TensorProto type, signed flag).
_INT4_TYPES = (
    (uint4, "uint4", TensorProto.UINT4, False),
    (int4, "int4", TensorProto.INT4, True),
)

# Float8 target types, tried in this order:
# (TensorProto type, numpy dtype, encoder, encoder kwargs besides ``saturate``).
_FLOAT8_ENCODERS = (
    (TensorProto.FLOAT8E4M3FN, float8e4m3fn, float32_to_float8e4m3, {}),
    (TensorProto.FLOAT8E4M3FNUZ, float8e4m3fnuz, float32_to_float8e4m3, {"uz": True}),
    (TensorProto.FLOAT8E5M2, float8e5m2, float32_to_float8e5m2, {}),
    (
        TensorProto.FLOAT8E5M2FNUZ,
        float8e5m2fnuz,
        float32_to_float8e5m2,
        {"fn": True, "uz": True},
    ),
)


def _has_custom_dtype(x, np_type, field_name):
    """Tell whether ``x`` is stored with the custom dtype ``np_type``.

    Both the dtype comparison and the first field name of ``x.dtype.descr``
    must match.
    NOTE(review): the name check is presumably needed because the custom
    dtypes compare equal to the plain integer dtype they wrap -- confirm
    against ``onnx.reference.custom_element_types``.
    """
    return x.dtype == np_type and x.dtype.descr[0][0] == field_name


def _decode_elementwise(x, to, decode, **decode_kwargs):
    """Decode every element of ``x`` to float32, then cast to ``to``.

    Args:
        x: array whose elements ``decode`` understands.
        to: TensorProto element type of the result.
        decode: callable applied to one element at a time.
        **decode_kwargs: extra keyword arguments passed to ``decode``.

    Returns:
        An array shaped like ``x`` whose dtype is
        ``tensor_dtype_to_np_dtype(to)``.
    """
    flat = x.ravel()
    as_float = np.empty(flat.shape[0], dtype=np.float32)
    for idx, raw in enumerate(flat):
        as_float[idx] = decode(raw, **decode_kwargs)
    target = tensor_dtype_to_np_dtype(to)
    return as_float.astype(target).reshape(x.shape)


def _encode_elementwise(x, np_type, encode, **encode_kwargs):
    """Cast ``x`` to float32 and encode every element into ``np_type``.

    Args:
        x: array convertible to float32 with ``astype``.
        np_type: numpy dtype of the result.
        encode: callable applied to one float32 element at a time.
        **encode_kwargs: extra keyword arguments passed to ``encode``.

    Returns:
        An array of dtype ``np_type`` shaped like ``x``.
    """
    as_float = x.astype(np.float32).ravel()
    encoded = np.empty(as_float.shape, dtype=np_type).ravel()
    for idx, value in enumerate(as_float):
        encoded[idx] = encode(value, **encode_kwargs)
    return encoded.reshape(x.shape)


def cast_to(x, to, saturate):  # noqa: PLR0911
    """Cast array ``x`` to the TensorProto element type ``to``.

    Sources stored as bfloat16 or float8 are decoded one element at a time
    to float32 and then cast with ``astype``. Targets that are bfloat16,
    4-bit integers or float8 are encoded one element at a time from float32.
    Everything else goes through ``astype``. When the source already has the
    requested custom type, ``x`` itself is returned (no copy).

    Args:
        x: numpy array to convert.
        to: TensorProto element type to convert to.
        saturate: forwarded as the ``saturate`` keyword to the float8
            encoders; unused for every other target type.

    Returns:
        The converted array, with the same shape as ``x``.
    """
    # Source is bfloat16.
    if _has_custom_dtype(x, bfloat16, "bfloat16"):
        if to == TensorProto.BFLOAT16:
            return x
        return _decode_elementwise(x, to, bfloat16_to_float32)

    # Source is one of the float8 flavours.
    for np_type, field_name, proto_type, decode, extra in _FLOAT8_DECODERS:
        if _has_custom_dtype(x, np_type, field_name):
            if to == proto_type:
                return x
            return _decode_elementwise(x, to, decode, **extra)

    # Target is bfloat16.
    if to == TensorProto.BFLOAT16:
        return _encode_elementwise(x, bfloat16, float32_to_bfloat16, truncate=True)

    # 4-bit integers. For each type the source check runs before the target
    # check, and uint4 is handled entirely before int4; keep this order.
    for np_type, field_name, proto_type, signed in _INT4_TYPES:
        if _has_custom_dtype(x, np_type, field_name):
            if to == proto_type:
                return x
            return x.astype(tensor_dtype_to_np_dtype(to))
        if to == proto_type:
            return _encode_elementwise(
                x, np_type, subbyte.float32_to_4bit_unpacked, signed=signed
            )

    # Target is one of the float8 flavours.
    for proto_type, np_type, encode, extra in _FLOAT8_ENCODERS:
        if to == proto_type:
            return _encode_elementwise(
                x, np_type, encode, saturate=saturate, **extra
            )

    if to == TensorProto.STRING:
        return x.astype(np.str_)

    # Any remaining type maps directly onto a numpy dtype.
    return x.astype(tensor_dtype_to_np_dtype(to))


class Cast_1(OpRun):
    """Cast variant whose ``_run`` takes no ``saturate`` argument."""

    def _run(self, x, to=None):  # type: ignore
        """Return a 1-tuple holding ``x`` cast to ``to`` with ``saturate=True``."""
        result = cast_to(x, to, saturate=True)
        return (result,)


class Cast_19(OpRun):
    """Cast variant whose ``_run`` accepts a ``saturate`` argument."""

    def _run(self, x, to=None, saturate=None):  # type: ignore
        """Return a 1-tuple holding ``x`` cast to ``to``, forwarding ``saturate``."""
        result = cast_to(x, to, saturate)
        return (result,)