Spaces:
Running
Running
File size: 4,316 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from .core import _map_mt_args_kwargs, _wrap_result
__all__ = [] # type: ignore[var-annotated]
UNARY_NAMES = [
"abs",
"absolute",
"acos",
"arccos",
"acosh",
"arccosh",
"angle",
"asin",
"arcsin",
"asinh",
"arcsinh",
"atan",
"arctan",
"atanh",
"arctanh",
"bitwise_not",
"ceil",
"clamp",
"clip",
"conj_physical",
"cos",
"cosh",
"deg2rad",
"digamma",
"erf",
"erfc",
"erfinv",
"exp",
"exp2",
"expm1",
"fix",
"floor",
"frac",
"lgamma",
"log",
"log10",
"log1p",
"log2",
"logit",
"i0",
"isnan",
"nan_to_num",
"neg",
"negative",
"positive",
"pow",
"rad2deg",
"reciprocal",
"round",
"rsqrt",
"sigmoid",
"sign",
"sgn",
"signbit",
"sin",
"sinc",
"sinh",
"sqrt",
"square",
"tan",
"tanh",
"trunc",
]
INPLACE_UNARY_NAMES = [
n + "_"
for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
]
# Ops we explicitly know are NOT supported by the MaskedTensor unary path yet,
# either because code generation for them is missing or because their
# semantics are more involved. Kept as a plain list, in this exact order.
UNARY_NAMES_UNSUPPORTED = """
    atan2 arctan2
    bitwise_left_shift bitwise_right_shift
    copysign float_power fmod frexp gradient imag ldexp lerp
    logical_not hypot igamma igammac mvlgamma nextafter polygamma
    real remainder true_divide xlogy
""".split()
def _unary_helper(fn, args, kwargs, inplace):
    """Apply the unary operator ``fn`` to the data of a MaskedTensor.

    ``args[0]`` is expected to be the MaskedTensor operand: ``fn`` runs on its
    data, and the result is paired with its mask. Any further positional
    arguments (which must not be tensors) are forwarded to ``fn`` unchanged.

    For ``torch.sparse_coo`` and ``torch.sparse_csr`` data, ``fn`` is applied
    to the stored values only, and the result is rebuilt from the original
    indices and the original size.

    Args:
        fn: the operator to apply (an attribute of ``torch.ops.aten``).
        args: positional arguments of the intercepted call.
        kwargs: keyword arguments of the intercepted call; must be empty.
        inplace: if True, store the result back into ``args[0]`` via
            ``_set_data_mask`` and return ``args[0]``; otherwise return the
            object produced by ``_wrap_result``.

    Raises:
        ValueError: if any keyword arguments are passed.
        TypeError: if any positional argument after the first is a tensor.
    """
    if kwargs:
        raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
                         "If you need support for this, please open an issue on Github.")
    if any(torch.is_tensor(a) for a in args[1:]):
        raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")

    # Project every MaskedTensor argument onto its mask / its data. kwargs is
    # known to be empty at this point, so the mapped kwargs are discarded.
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    layout = args[0].layout
    if layout == torch.sparse_coo:
        # Coalesce once; indices, size and values are all read from the
        # same coalesced tensor, and fn only sees the stored values.
        sparse_data = data_args[0].coalesce()
        indices = sparse_data.indices()
        size = sparse_data.size()
        data_args[0] = sparse_data.values()
        result_data = torch.sparse_coo_tensor(indices, fn(*data_args), size=size)
    elif layout == torch.sparse_csr:
        sparse_data = data_args[0]
        crow = sparse_data.crow_indices()
        col = sparse_data.col_indices()
        size = sparse_data.size()
        data_args[0] = sparse_data.values()
        # Pass the original size explicitly: when omitted, sparse_csr_tensor
        # infers the smallest size that holds the stored elements, which
        # drops trailing all-empty columns so the result would no longer
        # match the input's shape (and presumably the mask's).
        result_data = torch.sparse_csr_tensor(crow, col, fn(*data_args), size=size)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
def _torch_unary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def unary_fn(*args, **kwargs):
return _unary_helper(fn, args, kwargs, inplace=False)
return unary_fn
def _torch_inplace_unary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def unary_fn(*args, **kwargs):
return _unary_helper(fn, args, kwargs, inplace=True)
return unary_fn
# Dispatch tables: key is the object returned by getattr(torch.ops.aten, name),
# value is the MaskedTensor wrapper built for that same name.
NATIVE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_unary(op_name)
    for op_name in UNARY_NAMES
}
NATIVE_INPLACE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_unary(op_name)
    for op_name in INPLACE_UNARY_NAMES
}

# The supported aten ops as plain lists, in the tables' insertion order.
NATIVE_UNARY_FNS = [*NATIVE_UNARY_MAP]
NATIVE_INPLACE_UNARY_FNS = [*NATIVE_INPLACE_UNARY_MAP]
def _is_native_unary(fn):
    """Return True if ``fn`` is one of the registered out-of-place or in-place unary ops."""
    if fn in NATIVE_UNARY_FNS:
        return True
    return fn in NATIVE_INPLACE_UNARY_FNS
def _apply_native_unary(fn, *args, **kwargs):
    """Run the MaskedTensor implementation registered for ``fn``.

    Out-of-place ops are checked first, then in-place ops. Returns
    ``NotImplemented`` when ``fn`` is in neither table.
    """
    for registered_fns, impl_table in (
        (NATIVE_UNARY_FNS, NATIVE_UNARY_MAP),
        (NATIVE_INPLACE_UNARY_FNS, NATIVE_INPLACE_UNARY_MAP),
    ):
        if fn in registered_fns:
            return impl_table[fn](*args, **kwargs)
    return NotImplemented
|