Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| # Copyright (c) ONNX Project Contributors | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import numpy as np | |
| from onnx import subbyte | |
| from onnx.helper import ( | |
| float32_to_bfloat16, | |
| float32_to_float8e4m3, | |
| float32_to_float8e5m2, | |
| tensor_dtype_to_np_dtype, | |
| ) | |
| from onnx.numpy_helper import ( | |
| bfloat16_to_float32, | |
| float8e4m3_to_float32, | |
| float8e5m2_to_float32, | |
| ) | |
| from onnx.onnx_pb import TensorProto | |
| from onnx.reference.custom_element_types import ( | |
| bfloat16, | |
| float8e4m3fn, | |
| float8e4m3fnuz, | |
| float8e5m2, | |
| float8e5m2fnuz, | |
| int4, | |
| uint4, | |
| ) | |
| from onnx.reference.op_run import OpRun | |
def _decode_to_float32(x, decode):
    """Decode every element of a custom-float tensor into a float32 array.

    ``decode`` is applied element by element to the flattened input; the
    returned array has the same shape as ``x``.
    """
    flat = x.ravel()
    out = np.empty(flat.shape[0], dtype=np.float32)
    for i in range(flat.shape[0]):
        out[i] = decode(flat[i])
    return out.reshape(x.shape)


def _encode_from_float32(x, dtype, encode):
    """Encode every element of ``x`` (viewed as float32) into ``dtype``.

    ``encode`` maps one float32 scalar to the value stored in an array of
    ``dtype``; the returned array has the same shape as ``x``.
    """
    flat = x.astype(np.float32).ravel()
    out = np.empty(flat.shape, dtype=dtype).ravel()
    for i in range(out.shape[0]):
        out[i] = encode(flat[i])
    return out.reshape(x.shape)


def cast_to(x, to, saturate):  # noqa: PLR0911
    """Cast ``x`` to the ONNX element type ``to``.

    Args:
        x: input numpy array. It may be stored in one of the custom element
            types from ``onnx.reference.custom_element_types`` (bfloat16,
            float8 variants, int4/uint4); those are recognised by dtype *and*
            by the field name in ``dtype.descr``.
        to: target ``TensorProto`` element type.
        saturate: forwarded to the float8 encoders
            (``float32_to_float8e4m3`` / ``float32_to_float8e5m2``); it is not
            used for any other target type.

    Returns:
        An array with the same shape as ``x`` holding the converted values.
        Casting a custom-typed input to its own type returns ``x`` unchanged.

    An input stored in a custom element type is first decoded to a plain
    numpy array and then cast again by this function. Conversions between
    two custom types (e.g. float8 -> bfloat16, bfloat16 -> float8,
    uint4 -> int4) therefore go through the same target encoder as native
    inputs, instead of a raw ``astype`` on the storage integers.
    """
    # ---- Source side: decode custom element types, then re-dispatch. ----
    # The re-dispatched arrays have plain dtypes whose ``descr`` field name is
    # empty, so they never match these source branches again.
    if x.dtype == bfloat16 and x.dtype.descr[0][0] == "bfloat16":
        if to == TensorProto.BFLOAT16:
            return x
        return cast_to(_decode_to_float32(x, bfloat16_to_float32), to, saturate)

    f8 = {
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN): float8e4m3_to_float32,
        (
            float8e4m3fnuz,
            "e4m3fnuz",
            TensorProto.FLOAT8E4M3FNUZ,
        ): lambda *args: float8e4m3_to_float32(*args, uz=True),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2): float8e5m2_to_float32,
        (
            float8e5m2fnuz,
            "e5m2fnuz",
            TensorProto.FLOAT8E5M2FNUZ,
        ): lambda *args: float8e5m2_to_float32(*args, fn=True, uz=True),
    }
    for (dt, st, proto_type), cvt in f8.items():
        if x.dtype == dt and x.dtype.descr[0][0] == st:
            if to == proto_type:
                return x
            return cast_to(_decode_to_float32(x, cvt), to, saturate)

    # (custom dtype, descr name, TensorProto type, signed, plain storage dtype)
    i4 = [
        (uint4, "uint4", TensorProto.UINT4, False, np.uint8),
        (int4, "int4", TensorProto.INT4, True, np.int8),
    ]
    for np_type, np_desc, tensor_type, _signed, storage in i4:
        if x.dtype == np_type and x.dtype.descr[0][0] == np_desc:
            if to == tensor_type:
                return x
            # 4-bit values are stored unpacked in the storage integer type,
            # so viewing them as that plain type preserves their values.
            return cast_to(x.astype(storage), to, saturate)

    # ---- Target side: x now has a native numpy dtype. ----
    if to == TensorProto.BFLOAT16:
        return _encode_from_float32(
            x, bfloat16, lambda v: float32_to_bfloat16(v, truncate=True)
        )

    for np_type, _desc, tensor_type, signed, _storage in i4:
        if to == tensor_type:
            return _encode_from_float32(
                x,
                np_type,
                lambda v, signed=signed: subbyte.float32_to_4bit_unpacked(
                    v, signed=signed
                ),
            )

    f8back = {
        TensorProto.FLOAT8E4M3FN: (
            float8e4m3fn,
            lambda v: float32_to_float8e4m3(v, saturate=saturate),
        ),
        TensorProto.FLOAT8E4M3FNUZ: (
            float8e4m3fnuz,
            lambda v: float32_to_float8e4m3(v, uz=True, saturate=saturate),
        ),
        TensorProto.FLOAT8E5M2: (
            float8e5m2,
            lambda v: float32_to_float8e5m2(v, saturate=saturate),
        ),
        TensorProto.FLOAT8E5M2FNUZ: (
            float8e5m2fnuz,
            lambda v: float32_to_float8e5m2(v, fn=True, uz=True, saturate=saturate),
        ),
    }
    for dt, (npdt, cvt) in f8back.items():
        if to == dt:
            return _encode_from_float32(x, npdt, cvt)

    if to == TensorProto.STRING:
        return x.astype(np.str_)

    return x.astype(tensor_dtype_to_np_dtype(to))
class Cast_1(OpRun):
    """Cast operator variant that has no ``saturate`` attribute.

    Because saturation cannot be switched off here, conversions to float8
    types always request saturation from ``cast_to``.
    """

    def _run(self, x, to=None):  # type: ignore
        # No ``saturate`` attribute exists for this variant: always saturate.
        converted = cast_to(x, to, saturate=True)
        return (converted,)
class Cast_19(OpRun):
    """Cast operator variant that exposes the ``saturate`` attribute."""

    def _run(self, x, to=None, saturate=None):  # type: ignore
        # The attribute value is handed to ``cast_to`` unchanged; it only
        # matters when the target is a float8 type.
        converted = cast_to(x, to, saturate)
        return (converted,)
