# (removed: web-scrape page artifacts — status text, file size, commit hash,
#  and a line-number gutter that were not part of the original source file)
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
import torch
from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor
__all__ = [] # type: ignore[var-annotated]
def _masked_all_all(data, mask=None):
if mask is None:
return data.all()
return data.masked_fill(~mask, True).all()
def _masked_all_dim(data, dim, keepdim=False, mask=None):
if mask is None:
return torch.all(data, dim=dim, keepdim=keepdim)
return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
def _masked_all(*args, **kwargs):
    """Dispatch ``all`` to the full or per-dimension masked implementation.

    A call shaped as ``(data, mask=...)`` — one positional and one keyword
    argument — is a full reduction; anything else carries a ``dim`` and is
    forwarded to the per-dimension variant.
    """
    is_full_reduction = len(args) == 1 and len(kwargs) == 1
    if is_full_reduction:
        (data,) = args
        return _masked_all_all(data, mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)
def _multidim_any(mask, dim, keepdim):
if isinstance(dim, int):
return _multidim_any(mask, [dim], keepdim)
for d in sorted(dim, reverse=True):
mask = torch.any(mask, dim=d, keepdim=keepdim)
return mask
def _get_masked_fn(fn):
if fn == "all":
return _masked_all
return getattr(torch.masked, fn)
def _torch_reduce_all(fn):
    """Return a wrapper that reduces an entire MaskedTensor with the masked op named ``fn``."""

    def reduce_all(self):
        """Reduce ``self`` over all elements and wrap the result as a MaskedTensor."""
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        # For sparse inputs, reduce over the mask's stored values rather than the sparse tensor itself.
        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
        # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
        # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
        # Therefore, this implementation calculates it using the strides.
        if fn == "all":
            result_data = masked_fn(data, mask=mask)
        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            # Position of the winning element within the flat list of stored values.
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            # NOTE(review): the `not is_sparse_coo` arm looks unreachable here,
            # since this branch is guarded by `self.is_sparse_coo()` — confirm.
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            # Multi-dimensional coordinates of the winning element.
            idx = indices.unbind(1)[sparse_idx]
            # Row-major strides computed from the dense shape. NOTE(review):
            # `/` is float division, so `stride` (and hence the summed index)
            # is floating point — confirm this is intended.
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            # Flatten the coordinates into a single linear index.
            result_data = torch.sum(idx * stride)
        # we simply pass in the values for sparse COO/CSR tensors
        elif self.is_sparse:
            result_data = masked_fn(masked_tensor(data.values(), mask))
        else:
            result_data = masked_fn(self, mask=mask)
        # The scalar result counts as valid iff at least one element was unmasked.
        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all
def _torch_reduce_dim(fn):
    """Return a wrapper that reduces a MaskedTensor along given dimension(s)
    with the masked op named ``fn``."""

    def reduce_dim(self, dim, keepdim=False, dtype=None):
        """Reduce ``self`` along ``dim`` and wrap the result as a MaskedTensor.

        Returns ``NotImplemented`` (with a warning) for sparse inputs, whose
        per-dimension reductions are not implemented.

        Raises:
            TypeError: if ``self`` is not a MaskedTensor.
        """
        # Fix: validate the type BEFORE touching any MaskedTensor-specific
        # state. The original read `self.is_sparse` first, so the type check
        # below was only reached for inputs that already quacked like tensors.
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")
        if self.is_sparse:
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "In the case that the semantics for the operator are not trivial, it would be appreciated "
                "to also include a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            # _masked_all takes data directly and has no dtype parameter.
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            # Fix: reuse the mask fetched above instead of calling
            # self.get_mask() a second time.
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=mask
            )
        # The output mask marks positions where any input element was unmasked.
        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim
def _torch_reduce(fn):
def reduce_fn(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0:
return _torch_reduce_all(fn)(args[0])
return _torch_reduce_dim(fn)(*args, **kwargs)
return reduce_fn
def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
return input, dim, keepdim, dtype
def _torch_grad_reduce(fn):
def grad_reduce(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0:
return _torch_reduce_all(fn)(args[0])
# TODO: autograd.Function doesn't support kwarg
input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
return grad_reduce
# Reduction names supported on MaskedTensor. "all" is special-cased by
# _get_masked_fn; the rest resolve to the corresponding torch.masked ops.
REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]
# Dispatch tables keyed by the callable that reaches the MaskedTensor handler:
# aten-level ops get the plain wrappers, while torch.* and Tensor.* entry
# points get the "grad" variants (which unpack kwargs for autograd — see
# the TODO in _torch_grad_reduce).
NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
# Key lists exposed for cheap "is this one of ours?" membership checks.
NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
def _is_reduction(fn):
    """Return True if ``fn`` is a reduction registered in any dispatch table."""
    tables = (NATIVE_REDUCE_MAP, TORCH_REDUCE_MAP, TENSOR_REDUCE_MAP)
    return any(fn in table for table in tables)
def _apply_reduction(fn, *args, **kwargs):
    """Invoke the registered handler for ``fn``, trying the aten table first,
    then torch.*, then Tensor.*; return NotImplemented if none matches."""
    for table in (NATIVE_REDUCE_MAP, TORCH_REDUCE_MAP, TENSOR_REDUCE_MAP):
        handler = table.get(fn)
        if handler is not None:
            return handler(*args, **kwargs)
    return NotImplemented
# (removed: trailing web-scrape gutter artifact)