# Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 from onnx.helper import np_dtype_to_tensor_dtype from onnx.onnx_pb import TensorProto from onnx.reference.op_run import OpRun from onnx.reference.ops.op_cast import ( bfloat16, cast_to, float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4, uint4, ) def _cast_like(x, y, saturate): if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 to = TensorProto.BFLOAT16 elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn": to = TensorProto.FLOAT8E4M3FN elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz": to = TensorProto.FLOAT8E4M3FNUZ elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2": to = TensorProto.FLOAT8E5M2 elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz": to = TensorProto.FLOAT8E5M2FNUZ elif y.dtype == uint4 and y.dtype.descr[0][0] == "uint4": to = TensorProto.UINT4 elif y.dtype == int4 and y.dtype.descr[0][0] == "int4": to = TensorProto.INT4 else: to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore return (cast_to(x, to, saturate),) class CastLike_15(OpRun): def _run(self, x, y): # type: ignore return _cast_like(x, y, True) class CastLike_19(OpRun): def _run(self, x, y, saturate=None): # type: ignore return _cast_like(x, y, saturate)