File size: 1,599 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0


from onnx.helper import np_dtype_to_tensor_dtype
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_cast import (
    bfloat16,
    cast_to,
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)


def _cast_like(x, y, saturate):
    """Cast ``x`` to the element type of ``y``.

    Only ``y.dtype`` is consulted; the values of ``y`` are ignored.

    The custom numpy dtypes imported from ``op_cast`` (bfloat16, the float8
    variants, int4/uint4) are tried first, in a fixed order. Equality alone
    cannot identify them: for instance ``np.uint16 == bfloat16`` is True. So
    the structured field name reported by ``dtype.descr`` must match as well.
    Any dtype that matches none of them is mapped with
    ``np_dtype_to_tensor_dtype``.

    Args:
        x: array to convert.
        y: array whose dtype selects the target ONNX element type.
        saturate: forwarded unchanged to ``cast_to``.

    Returns:
        A one-element tuple holding the converted array, which is what the
        ``_run`` methods below return directly.
    """
    source_dtype = y.dtype
    # (custom numpy dtype, expected descr field name, ONNX element type),
    # checked in the same priority order as before.
    custom_types = (
        (bfloat16, "bfloat16", TensorProto.BFLOAT16),
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
        (uint4, "uint4", TensorProto.UINT4),
        (int4, "int4", TensorProto.INT4),
    )
    for candidate, field_name, onnx_type in custom_types:
        # Equality first so ``descr`` is only inspected for dtypes sharing
        # the candidate's storage type.
        if source_dtype == candidate and source_dtype.descr[0][0] == field_name:
            target = onnx_type
            break
    else:
        target = np_dtype_to_tensor_dtype(source_dtype)  # type: ignore
    return (cast_to(x, target, saturate),)


class CastLike_15(OpRun):
    """CastLike (opset 15): cast ``x`` to the element type of ``y``.

    This version's ``_run`` takes no ``saturate`` attribute. Saturation is
    therefore always requested from the shared helper.
    """

    def _run(self, x, y):  # type: ignore
        # No attribute to read here: always saturate.
        return _cast_like(x, y, saturate=True)


class CastLike_19(OpRun):
    """CastLike (opset 19): cast ``x`` to the element type of ``y``.

    Adds the ``saturate`` attribute, which is forwarded to ``cast_to``.
    """

    def _run(self, x, y, saturate=None):  # type: ignore
        # The ONNX spec gives ``saturate`` a default of 1, and CastLike_15
        # always saturates. An unset value (None, e.g. on a direct call that
        # omits the attribute) must therefore mean "saturate", not be
        # forwarded as a falsy None. Explicit 0/False/1/True values pass
        # through unchanged.
        if saturate is None:
            saturate = True
        return _cast_like(x, y, saturate)