Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch | |
from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor | |
__all__ = [] # type: ignore[var-annotated] | |
# Names of ``torch.ops.aten`` binary ops that get a MaskedTensor handler in
# this module; each name is looked up with ``getattr(torch.ops.aten, name)``.
BINARY_NAMES = [
    "add",
    "atan2",
    "arctan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "div",
    "divide",
    "floor_divide",
    "fmod",
    "logaddexp",
    "logaddexp2",
    "mul",
    "multiply",
    "nextafter",
    "remainder",
    "sub",
    "subtract",
    "true_divide",
    "eq",
    "ne",
    "le",
    "ge",
    "greater",
    "greater_equal",
    "gt",
    "less_equal",
    "lt",
    "less",
    "maximum",
    "minimum",
    "fmax",
    "fmin",
    "not_equal",
]

# In-place ("<name>_") variants of BINARY_NAMES, minus the ops excluded below
# (they are not registered as in-place ops here).
# Built by filtering BINARY_NAMES in order rather than via a set difference so
# the resulting list order is deterministic across interpreter runs instead of
# depending on string hash randomization.
# NOTE: "equal" is not in BINARY_NAMES, so excluding it has no effect; it is
# kept to mirror the original exclusion list.
INPLACE_BINARY_NAMES = [
    name + "_"
    for name in BINARY_NAMES
    if name
    not in {
        "logaddexp",
        "logaddexp2",
        "equal",
        "fmin",
        "minimum",
        "maximum",
        "fmax",
    }
]
def _get_at_least_one_mask(a, b):
    """Return the mask of ``a`` if it is a MaskedTensor, else the mask of ``b``.

    Raises:
        TypeError: if neither ``a`` nor ``b`` is a MaskedTensor.
        ValueError: if ``_masks_match(a, b)`` reports that the masks differ.
    """
    # Guard: at least one operand must carry a mask.
    if not (is_masked_tensor(a) or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    # ``a`` wins when both operands are MaskedTensors.
    owner = a if is_masked_tensor(a) else b
    return owner.get_mask()
def _binary_helper(fn, args, kwargs, inplace):
    """Apply the ATen binary op ``fn`` to the data of MaskedTensor operands.

    ``args[0]`` is the lhs and ``args[1]`` the rhs; any further positional
    args are forwarded to ``fn`` but must not be tensors. When the lhs data is
    sparse (COO or CSR), ``fn`` runs on the stored values only and the result
    is rebuilt with the lhs's indices and size.

    Args:
        fn: op called positionally on the unwrapped data arguments.
        args: positional arguments of the original call.
        kwargs: keyword arguments of the original call; must be empty.
        inplace: if True, store the result in ``args[0]`` via
            ``_set_data_mask`` (keeping ``args[0]``'s mask) and return
            ``args[0]``; otherwise return ``_wrap_result`` of the new data and
            the operands' shared mask.

    Raises:
        ValueError: if ``kwargs`` is non-empty, the operand masks differ, or two
            sparse operands of the same layout differ in indices or size.
        TypeError: if a tensor appears after the first two positional args.
    """
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
            )

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    # kwargs were rejected above, so only the positional halves are needed.
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

    args0_layout = data_args[0].layout
    same_layout = (
        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
    ) and (args0_layout == data_args[1].layout)

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )
            # Identical sparsity pattern: the op can run on the values alone.
            data_args[1] = data_args[1].values()

        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)

    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(
                    data_args[0].col_indices(), data_args[1].col_indices()
                )
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            # Same check as the sparse_coo branch: equal index tensors do not
            # imply equal shapes (e.g. differing numbers of trailing columns).
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )
            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        # Pass the lhs size explicitly: without it, sparse_csr_tensor infers the
        # smallest size that fits the indices and drops trailing empty columns.
        result_data = torch.sparse_csr_tensor(crow, col, v, size=size)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # sparse tensors don't have strides so we can only expand if the layout is strided
    if args0_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)
def _torch_binary(fn_name):
    """Build the out-of-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``.

    The returned callable passes its positional args (as a tuple) and keyword
    args (as a dict) to ``_binary_helper`` with ``inplace=False``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def _out_of_place(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=False)

    return _out_of_place
def _torch_inplace_binary(fn_name):
    """Build the in-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``.

    The returned callable passes its positional args (as a tuple) and keyword
    args (as a dict) to ``_binary_helper`` with ``inplace=True``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def _in_place(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=True)

    return _in_place
# Out-of-place dispatch table: aten op -> MaskedTensor handler.
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_binary(op_name)
    for op_name in BINARY_NAMES
}
# In-place dispatch table, keyed by the trailing-underscore aten ops.
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_binary(op_name)
    for op_name in INPLACE_BINARY_NAMES
}
# The keys of each table above, as lists, in insertion order.
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP)
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP)
def _is_native_binary(fn):
    """Return True if ``fn`` is a key of the out-of-place or in-place binary tables."""
    return any(fn in known for known in (NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS))
def _apply_native_binary(fn, *args, **kwargs):
    """Call the registered MaskedTensor binary handler for ``fn``.

    The out-of-place table is consulted before the in-place one. Returns
    ``NotImplemented`` when ``fn`` is in neither.
    """
    for known_fns, handlers in (
        (NATIVE_BINARY_FNS, NATIVE_BINARY_MAP),
        (NATIVE_INPLACE_BINARY_FNS, NATIVE_INPLACE_BINARY_MAP),
    ):
        if fn in known_fns:
            return handlers[fn](*args, **kwargs)
    return NotImplemented