# Copyright (c) Meta Platforms, Inc. and affiliates
import torch

from .core import (
    _map_mt_args_kwargs,
    _masks_match,
    _tensors_match,
    _wrap_result,
    is_masked_tensor,
)

__all__ = []  # type: ignore[var-annotated]
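
# Names of the pointwise binary and comparison ops that MaskedTensor supports;
# each name must resolve to an overload packet on torch.ops.aten.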
BINARY_NAMES = [
"add",
"atan2",
"arctan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"div",
"divide",
"floor_divide",
"fmod",
"logaddexp",
"logaddexp2",
"mul",
"multiply",
"nextafter",
"remainder",
"sub",
"subtract",
"true_divide",
"eq",
"ne",
"le",
"ge",
"greater",
"greater_equal",
"gt",
"less_equal",
"lt",
"less",
"maximum",
"minimum",
"fmax",
"fmin",
"not_equal",
]
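
# In-place variants ("add_", "mul_", ...); ops such as logaddexp and minimum
# are excluded because they have no in-place aten counterpart.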
INPLACE_BINARY_NAMES = [
    n + "_"
    for n in (
        list(
            set(BINARY_NAMES)
            - {
                "logaddexp",
                "logaddexp2",
                "equal",
                "fmin",
                "minimum",
                "maximum",
                "fmax",
            }
        )
    )
]
def _get_at_least_one_mask(a, b):
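    """Return the mask of whichever of `a` and `b` is a MaskedTensor.

    Raises if neither input is a MaskedTensor or if their masks disagree.
    """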
    if not is_masked_tensor(a) and not is_masked_tensor(b):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    if is_masked_tensor(a):
        return a.get_mask()
    return b.get_mask()
def _binary_helper(fn, args, kwargs, inplace):
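    """Run the aten op `fn` on the underlying data of the first two args.

    Strided data is passed through as-is; sparse COO/CSR inputs with matching
    indices have `fn` applied to their values(), and the result is rebuilt
    with the indices of the lhs. The output mask is taken from whichever
    input is a MaskedTensor (the masks are checked to match beforehand).
    """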
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
            )
    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )
    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
    args0_layout = data_args[0].layout
    same_layout = (
        (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]))
        and (args0_layout == data_args[1].layout)
    )
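    # For sparse layouts, apply fn to the underlying values() and rebuild a
    # sparse tensor of the same layout from the lhs indices.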
    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )
            data_args[1] = data_args[1].values()

        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)
    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(data_args[0].col_indices(), data_args[1].col_indices())
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_csr_tensor(crow, col, v)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    else:
        result_mask = _get_at_least_one_mask(*args[:2])
        # sparse tensors don't have strides, so we can only expand if the layout is strided
        if args0_layout == torch.strided:
            result_mask = result_mask.expand_as(result_data)
        return _wrap_result(result_data, result_mask)
def _torch_binary(fn_name):
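    """Build an out-of-place MaskedTensor wrapper around torch.ops.aten.<fn_name>."""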
    fn = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(fn, args, kwargs, inplace=False)

    return binary_fn
def _torch_inplace_binary(fn_name):
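    """Build an in-place MaskedTensor wrapper around torch.ops.aten.<fn_name>."""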
    fn = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(fn, args, kwargs, inplace=True)

    return binary_fn
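
# Dispatch tables mapping each supported aten op to its MaskedTensor wrapper.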
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, name): _torch_inplace_binary(name)
    for name in INPLACE_BINARY_NAMES
}
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
def _is_native_binary(fn):
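    """Return True if `fn` is one of the aten binary ops handled by this module."""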
    return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
def _apply_native_binary(fn, *args, **kwargs):
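    """Dispatch `fn` through the binary maps, or return NotImplemented."""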
    if fn in NATIVE_BINARY_FNS:
        return NATIVE_BINARY_MAP[fn](*args, **kwargs)
    if fn in NATIVE_INPLACE_BINARY_FNS:
        return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
    return NotImplemented
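
# Illustrative usage sketch (a hypothetical example, not executed here):
#
#   from torch.masked import masked_tensor
#   mt = masked_tensor(torch.tensor([1.0, 2.0]), torch.tensor([True, False]))
#   out = torch.add(mt, mt)  # handled via NATIVE_BINARY_MAP[torch.ops.aten.add]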