Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates | |
import torch | |
from .core import _map_mt_args_kwargs, _wrap_result | |
__all__ = [] # type: ignore[var-annotated] | |
# Names of the aten unary ops that MaskedTensor handles out of place.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]
# In-place ("<name>_") variants of UNARY_NAMES. "angle", "positive", "signbit"
# and "isnan" are excluded (presumably because aten has no in-place variant
# for them -- TODO confirm). The list follows UNARY_NAMES order so that it is
# identical from run to run; iterating a set of strings would make the order
# depend on per-process hash randomization.
INPLACE_UNARY_NAMES = [
    n + "_"
    for n in UNARY_NAMES
    if n not in {"angle", "positive", "signbit", "isnan"}
]
# Functions we explicitly track as not (yet) supported as MaskedTensor unary
# ops, whether because the required code generation is missing or because
# their semantics are too complex for the generic unary path.
UNARY_NAMES_UNSUPPORTED = [
    "atan2", "arctan2",
    "bitwise_left_shift", "bitwise_right_shift",
    "copysign",
    "float_power",
    "fmod", "frexp",
    "gradient",
    "imag",
    "ldexp", "lerp",
    "logical_not",
    "hypot",
    "igamma", "igammac",
    "mvlgamma",
    "nextafter",
    "polygamma",
    "real",
    "remainder",
    "true_divide",
    "xlogy",
]
def _unary_helper(fn, args, kwargs, inplace):
    """Apply the aten unary operator ``fn`` to the data of a MaskedTensor.

    ``args[0]`` is the MaskedTensor operand: its data and mask are extracted
    with ``_map_mt_args_kwargs``. Any further positional arguments are passed
    to ``fn`` as-is and must not be Tensors. The mask of ``args[0]`` is carried
    over to the result unchanged.

    For sparse COO and sparse CSR layouts, ``fn`` is applied only to the stored
    values. The result is then rebuilt with the input's indices and size, so
    its shape matches the input's.

    Args:
        fn: aten operator to apply to the data.
        args: positional arguments of the op; ``args[0]`` is the MaskedTensor.
        kwargs: keyword arguments of the op; must be empty.
        inplace: if True, store the result back into ``args[0]`` via
            ``_set_data_mask`` and return ``args[0]``; otherwise return
            ``_wrap_result(result_data, mask)``.

    Raises:
        ValueError: if ``kwargs`` is non-empty.
        TypeError: if any argument after the first is a Tensor.
    """
    if kwargs:
        raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
                         "If you need support for this, please open an issue on Github.")
    for a in args[1:]:
        if torch.is_tensor(a):
            raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")

    # Only positional args are used; kwargs were checked to be empty above.
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    layout = args[0].layout
    if layout == torch.sparse_coo:
        # Coalesce once so indices() and values() line up one-to-one.
        coalesced = data_args[0].coalesce()
        size = coalesced.size()
        indices = coalesced.indices()
        data_args[0] = coalesced.values()
        values = fn(*data_args)
        result_data = torch.sparse_coo_tensor(indices, values, size=size)
    elif layout == torch.sparse_csr:
        csr = data_args[0]
        crow = csr.crow_indices()
        col = csr.col_indices()
        size = csr.size()
        data_args[0] = csr.values()
        values = fn(*data_args)
        # Pass the size explicitly: inferring it from the indices would drop
        # trailing all-empty columns and change the tensor's shape.
        result_data = torch.sparse_csr_tensor(crow, col, values, size=size)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
def _torch_unary(fn_name):
    """Return a MaskedTensor handler that applies ``torch.ops.aten.<fn_name>``.

    The aten operator is looked up once, when the handler is built, and every
    call forwards to ``_unary_helper`` with ``inplace=False``.
    """
    op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # Out-of-place: the helper returns a wrapped result instead of
        # writing into args[0].
        return _unary_helper(op, args, kwargs, inplace=False)

    return unary_fn
def _torch_inplace_unary(fn_name):
    """Return a MaskedTensor handler that applies ``torch.ops.aten.<fn_name>`` in place.

    In this module it is built for the ``<name>_`` entries of
    INPLACE_UNARY_NAMES. The aten operator is looked up once, when the
    handler is built.
    """
    op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # In-place: the helper stores the result back into args[0] and
        # returns args[0].
        return _unary_helper(op, args, kwargs, inplace=True)

    return unary_fn
# aten operator -> MaskedTensor handler, for the out-of-place unary ops.
NATIVE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_unary(op_name)
    for op_name in UNARY_NAMES
}
# aten operator -> MaskedTensor handler, for the in-place ("<name>_") ops.
NATIVE_INPLACE_UNARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_unary(op_name)
    for op_name in INPLACE_UNARY_NAMES
}

# The operators handled above, in the maps' insertion order.
NATIVE_UNARY_FNS = [*NATIVE_UNARY_MAP]
NATIVE_INPLACE_UNARY_FNS = [*NATIVE_INPLACE_UNARY_MAP]
def _is_native_unary(fn):
    """Return True if ``fn`` has an out-of-place or in-place unary handler."""
    # Dict membership is O(1). The NATIVE_*_FNS lists hold the same keys but
    # would be scanned linearly on every call.
    return fn in NATIVE_UNARY_MAP or fn in NATIVE_INPLACE_UNARY_MAP
def _apply_native_unary(fn, *args, **kwargs):
    """Call the unary handler registered for ``fn`` with ``args``/``kwargs``.

    The out-of-place table is consulted before the in-place one.

    Returns:
        The handler's result, or ``NotImplemented`` when ``fn`` has no
        registered unary handler.
    """
    # One O(1) lookup per table instead of a linear list scan plus a lookup.
    handler = NATIVE_UNARY_MAP.get(fn)
    if handler is None:
        handler = NATIVE_INPLACE_UNARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)