Spaces:
Running
Running
File size: 1,599 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_cast import (
bfloat16,
cast_to,
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
float8e5m2fnuz,
int4,
uint4,
)
def _cast_like(x, y, saturate):
    """Cast ``x`` to the element type indicated by ``y.dtype``.

    The custom dtypes imported from ``onnx.reference.ops.op_cast`` are
    tried first, in a fixed order; any other dtype is translated with
    ``np_dtype_to_tensor_dtype``.

    Args:
        x: Value to convert; forwarded unchanged to ``cast_to``.
        y: Only its ``dtype`` is inspected to select the target type.
        saturate: Forwarded unchanged to ``cast_to``.

    Returns:
        A 1-tuple holding the result of ``cast_to``.
    """
    like = y.dtype
    # Each row: (custom dtype, expected first field name in ``descr``,
    # matching ONNX element type). Order mirrors the lookup priority.
    special_cases = (
        (bfloat16, "bfloat16", TensorProto.BFLOAT16),
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
        (uint4, "uint4", TensorProto.UINT4),
        (int4, "int4", TensorProto.INT4),
    )
    for custom_dtype, field_name, elem_type in special_cases:
        # dtype equality alone is ambiguous (np.uint16 == bfloat16 is True,
        # just like np.uint16 == np.uint16), so the first field name in
        # ``descr`` is checked as well to tell the custom types apart.
        if like == custom_dtype and like.descr[0][0] == field_name:
            to = elem_type
            break
    else:
        # No custom dtype matched: use the generic dtype translation.
        to = np_dtype_to_tensor_dtype(like)  # type: ignore
    return (cast_to(x, to, saturate),)
class CastLike_15(OpRun):
    """CastLike runner whose ``_run`` takes no ``saturate`` argument."""

    def _run(self, x, y):  # type: ignore
        """Cast ``x`` to the dtype of ``y`` with ``saturate`` fixed to True.

        Returns:
            The 1-tuple produced by ``_cast_like``.
        """
        return _cast_like(x, y, saturate=True)
class CastLike_19(OpRun):
    """CastLike runner whose ``_run`` accepts an optional ``saturate``."""

    def _run(self, x, y, saturate=None):  # type: ignore
        """Cast ``x`` to the dtype of ``y``.

        Args:
            x: Value to convert.
            y: Value whose dtype selects the target type.
            saturate: Passed through unchanged to ``_cast_like``
                (``None`` when not supplied by the caller).

        Returns:
            The 1-tuple produced by ``_cast_like``.
        """
        return _cast_like(x, y, saturate=saturate)
|