Spaces:
Sleeping
Sleeping
import warnings | |
# A workaround to support both TorchScript and MyPy: | |
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union | |
import torch | |
from torch import Tensor | |
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor | |
from . import _docs | |
from torch._prims_common import corresponding_real_dtype | |
from torch import sym_float | |
if TYPE_CHECKING:
    # Static type checkers get the precise aliases.
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]

# Public names of this module; starts empty and is extended at runtime
# (see the `__all__.append(...)` call in _apply_docstring_templates).
__all__: List[str] = []
# All masked reduction/normalization operations have the same | |
# signatures. Here we introduce docstring templates that are applied | |
# to docstrings of reduction/normalization functions via | |
# _apply_docstring_templates decorator. | |
def _apply_docstring_templates(func):
    """Attach the pre-generated docstring to ``func`` and publish its name.

    The docstring is looked up on the ``_docs`` module under the attribute
    ``<func name>_docstring``. When no such attribute exists a warning is
    emitted and the function's docstring is left as it was. In both cases
    the function name is appended to ``__all__`` and the same function
    object is returned, so this can be used as a decorator.
    """
    name = func.__name__
    generated = getattr(_docs, f"{name}_docstring", None)
    if generated is not None:
        func.__doc__ = generated
    else:
        warnings.warn(
            f"No documentation string available for {func.__name__}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings."
        )

    # Expose function as public symbol
    __all__.append(name)
    return func
def _generate_docstring(func): | |
"""A utility function called from tools/update_masked_docs.py | |
script to update the module torch.masked._docs.py | |
""" | |
docstring_templates = dict( | |
reduction_signature="""\ | |
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", | |
reduction_descr="""\ | |
Returns {operation name} of all the elements in the :attr:`input` | |
tensor along the given dimension(s) :attr:`dim` while the :attr:`input` | |
elements are masked out according to the boolean tensor | |
:attr:`mask`.""", | |
reduction_args="""\ | |
If :attr:`keepdim` is ``True``, the output tensor is of the same size | |
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of | |
size 1. Otherwise, :attr:`dim` is squeezed (see | |
:func:`torch.squeeze`), resulting in the output tensor having 1 (or | |
``len(dim)``) fewer dimension(s). | |
The boolean tensor :attr:`mask` defines the "validity" of | |
:attr:`input` tensor elements: if :attr:`mask` element is True | |
then the corresponding element in :attr:`input` tensor will be | |
included in {operation name} computation, otherwise the element is | |
ignored. | |
When all elements of :attr:`input` along the given dimension | |
:attr:`dim` are ignored (fully masked-out), the corresponding element | |
of the output tensor will have undefined value: it may or may not | |
correspond to the identity value of {operation name} operation; the | |
choice may correspond to the value that leads to the most efficient | |
storage of :attr:`output` tensor. | |
The mask of the output tensor can be computed as | |
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, | |
dtype=torch.bool)``. | |
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor | |
don't need to match, but they must be :ref:`broadcastable | |
<broadcasting-semantics>` and the dimensionality of the :attr:`mask` | |
tensor must not be greater than of the :attr:`input` tensor. | |
Args: | |
input (Tensor): the input tensor | |
{args_declarations} | |
Keyword args: | |
{kwargs_declarations}""", | |
reduction_example="""\ | |
Example:: | |
>>> input = {example_input} | |
>>> input | |
{indent_example_input} | |
>>> mask = {example_mask} | |
>>> mask | |
{indent_example_mask} | |
>>> {full_function_name}(input, {example_args}, mask=mask) | |
{indent_example_output} | |
""", | |
reduction_identity="""\ | |
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""", | |
reduction_identity_dtype="""\ | |
The identity value of {operation name} operation, which is used to start the | |
reduction, depends on input dtype. For instance, for float32, uint8, | |
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""", | |
normalization_signature="""\ | |
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", | |
normalization_descr="""\ | |
Returns {operation name} of all the slices in the :attr:`input` tensor | |
along :attr:`dim` while the :attr:`input` elements are masked out | |
according to the boolean tensor :attr:`mask`. | |
{definition}""", | |
normalization_args="""\ | |
The boolean tensor :attr:`mask` defines the "validity" of | |
:attr:`input` tensor elements: if :attr:`mask` element is True then | |
the corresponding element in :attr:`input` tensor will be included in | |
{operation name} computation, otherwise the element is ignored. | |
The values of masked-out elements of the output tensor have undefined | |
value: it may or may not be set to zero or nan; the choice may correspond to | |
the value that leads to the most efficient storage of :attr:`output` | |
tensor. | |
The mask of the {operation name} output tensor can be computed as | |
``torch.broadcast_to(mask, input.shape)``. | |
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor | |
don't need to match, but they must be :ref:`broadcastable | |
<broadcasting-semantics>` and the dimensionality of the :attr:`mask` | |
tensor must not be greater than of the :attr:`input` tensor. | |
Args: | |
input (Tensor): the input tensor | |
{args_declarations} | |
Keyword args: | |
{kwargs_declarations}""", | |
normalization_example="""\ | |
Example:: | |
>>> input = {example_input} | |
>>> input | |
{indent_example_input} | |
>>> mask = {example_mask} | |
>>> mask | |
{indent_example_mask} | |
>>> {full_function_name}(input, {example_args}, mask=mask) | |
{indent_example_output} | |
""", | |
) | |
args_and_kwargs = dict( | |
# argument name sufficies separated by double underscore will | |
# be removed in the final documentation string. | |
sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), | |
cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), | |
amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), | |
argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), | |
mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), | |
norm=( | |
( | |
"ord", | |
"dim", | |
), | |
("keepdim=False", "dtype=None", "mask=None"), | |
), | |
var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), | |
std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), | |
logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), | |
softmax=(("dim__as_int",), ("dtype=None", "mask=None")), | |
log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), | |
softmin=(("dim__as_int",), ("dtype=None", "mask=None")), | |
normalize=( | |
( | |
"ord__required", | |
"dim__as_int", | |
), | |
("eps=1e-12", "dtype=None", "mask=None"), | |
), | |
) | |
argument_declarations = dict( | |
dim="""\ | |
dim (int or tuple of ints, optional): the dimension or dimensions to reduce. | |
Default: None that is equivalent to ``tuple(range(input.ndim))``.""", | |
dim__as_int="""\ | |
dim (int): the dimension along which {operation name} is computed.""", | |
ord="""\ | |
ord (int, float, optional): the order of vector norm. Default: 2. | |
See :func:`torch.linalg.vector_norm` for a list of supported norms.""", | |
ord__required="""\ | |
ord (int, float): the order of vector norm. Default: 2. | |
See :func:`torch.linalg.vector_norm` for a list of supported norms.""", | |
unbiased="""\ | |
unbiased (bool): when True, use Bessel’s correction, otherwise, compute | |
the uncorrected sample variance.""", | |
eps="""\ | |
eps (float, optional): small value to avoid division by zero. Default: {default}.""", | |
keepdim="""\ | |
keepdim (bool, optional): whether the output tensor has | |
:attr:`dim` retained or not. Default: {default}.""", | |
dtype="""\ | |
dtype (:class:`torch.dtype`, optional): the desired data type | |
of returned tensor. If specified, the input tensor is | |
casted to :attr:`dtype` before the operation is | |
performed. Default: {default}.""", | |
mask="""\ | |
mask (:class:`torch.Tensor`, optional): the boolean tensor | |
containing the binary mask of validity of input tensor | |
elements. | |
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", | |
) | |
definitions = dict( | |
softmax="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is | |
defined as ``exp(x[i])/sum(exp(x))``.""", | |
log_softmax="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is | |
defined as ``log(exp(x[i])/sum(exp(x)))``.""", | |
softmin="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is | |
defined as ``exp(-x[i])/sum(exp(-x))``.""", | |
normalize="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is | |
defined as ``x[i]/max(norm(x, p), eps)``.""", | |
cumsum="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is | |
defined as ``sum(x[:i])``.""", | |
cumprod="""\ | |
Let ``x`` be a sequence of unmasked elements of one-dimensional slice | |
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is | |
defined as ``prod(x[:i])``.""", | |
) | |
reduction_names = dict( | |
sum="sum", | |
prod="product", | |
amax="maximum", | |
amin="minimum", | |
argmax="argmax", | |
argmin="argmin", | |
mean="mean", | |
median="median", | |
norm="norm", | |
var="variance", | |
std="standard_deviation", | |
logsumexp="logsumexp", | |
) | |
normalization_names = dict( | |
softmax="softmax", | |
log_softmax="log_softmax", | |
softmin="softmin", | |
normalize="normalize", | |
cumsum="cumulative_sum", | |
cumprod="cumulative_prod", | |
) | |
operation_names = {} | |
operation_names.update(reduction_names) | |
operation_names.update(normalization_names) | |
# Default example data: | |
example_dim = 1 | |
example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]]) | |
example_mask = torch.tensor([[True, False, True], [False, False, False]]) | |
example_args: Tuple[Any, ...] | |
if func.__name__ in {"norm", "normalize"}: | |
example_args = (2.0, example_dim) | |
example_input = example_input.to(dtype=torch.float32) | |
elif func.__name__ in {"var", "std"}: | |
example_args = (example_dim, False) | |
elif func.__name__ == "median": | |
example_args = (example_dim,) | |
example_input = example_input.to(dtype=torch.float32) | |
else: | |
example_args = (example_dim,) | |
operation_args: Tuple[str, ...] | |
operation_kwargs: Tuple[str, ...] | |
operation_args, operation_kwargs = args_and_kwargs[func.__name__] | |
arg_declarations = [ | |
"\n ".join( | |
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() | |
) | |
for a in operation_args | |
] | |
kwarg_declarations = [ | |
"\n ".join( | |
argument_declarations.get( | |
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' | |
) | |
.format(default=a.split("=", 1)[1]) | |
.splitlines() | |
) | |
for a in operation_kwargs | |
] | |
if func.__name__ in reduction_names: | |
op_kind = "reduction" | |
doc_sections = ["signature", "descr", "identity", "args", "example"] | |
elif func.__name__ in normalization_names: | |
op_kind = "normalization" | |
doc_sections = ["signature", "descr", "args", "example"] | |
example_input = example_input.to(dtype=torch.float32) | |
else: | |
assert 0 # add function name to operation names dictionaries | |
example_output = func(example_input, *example_args, mask=example_mask) | |
template_data = { | |
"function_name": func.__name__, | |
"full_function_name": func.__module__ + "." + func.__name__, | |
"operation name": operation_names[func.__name__], | |
"operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args), | |
"operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs), | |
# one-line representation of a tensor: | |
"example_input": " ".join(str(example_input).split()), | |
"example_args": ", ".join(map(str, example_args)), | |
"example_mask": " ".join(str(example_mask).split()), | |
# multi-line representation of a tensor with indent | |
"indent_example_input": ("\n ").join(str(example_input).splitlines()), | |
"indent_example_mask": ("\n ").join(str(example_mask).splitlines()), | |
"indent_example_output": ("\n ").join(str(example_output).splitlines()), | |
} | |
if func.__name__ in reduction_names: | |
template_data.update( | |
identity_uint8=_reduction_identity( | |
func.__name__, torch.tensor(0, dtype=torch.uint8) | |
), | |
identity_int32=_reduction_identity( | |
func.__name__, torch.tensor(0, dtype=torch.int32) | |
), | |
identity_float32=_reduction_identity( | |
func.__name__, torch.tensor(0, dtype=torch.float32) | |
), | |
) | |
if func.__name__ == "norm": | |
template_data.update( | |
identity_ord_ninf=_reduction_identity( | |
func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf") | |
) | |
) | |
elif func.__name__ in normalization_names: | |
template_data.update(definition=definitions[func.__name__]) | |
else: | |
assert 0 # add function name to operation names dictionaries | |
template_data.update( | |
args_declarations=("\n ".join(arg_declarations)).format_map(template_data) | |
) | |
template_data.update( | |
kwargs_declarations=("\n ".join(kwarg_declarations)).format_map( | |
template_data | |
) | |
) | |
# Apply function name info to docstring templates: | |
templates = { | |
k: v.format_map(template_data) | |
for k, v in docstring_templates.items() | |
if k.startswith(op_kind) | |
} | |
templates.update( | |
(k, v.format_map(template_data) if isinstance(v, str) else v) | |
for k, v in template_data.items() | |
) | |
# Apply docstring templates to function doctring: | |
if func.__doc__ is None: | |
doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections]) | |
else: | |
doc_template = func.__doc__ | |
return doc_template.format_map(templates) | |
def _reduction_identity(op_name: str, input: Tensor, *args): | |
"""Return identity value as scalar tensor of a reduction operation on | |
given input, or None, if the identity value cannot be uniquely | |
defined for the given input. | |
The identity value of the operation is defined as the initial | |
value to reduction operation that has a property ``op(op_identity, | |
value) == value`` for any value in the domain of the operation. | |
Or put it another way, including or excluding the identity value in | |
a list of operands will not change the reduction result. | |
See https://github.com/pytorch/rfcs/pull/27 for more information. | |
""" | |
dtype: DType = input.dtype | |
device = input.device | |
op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present | |
if op_name in {"sum", "cumsum"}: | |
return torch.tensor(0, dtype=dtype, device=device) | |
elif op_name in {"prod", "cumprod"}: | |
return torch.tensor(1, dtype=dtype, device=device) | |
elif op_name in {"amax", "argmax", "logsumexp"}: | |
if torch.is_floating_point(input): | |
return torch.tensor(-torch.inf, dtype=dtype, device=device) | |
elif torch.is_signed(input) or dtype == torch.uint8: | |
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) | |
elif op_name in {"amin", "argmin"}: | |
if torch.is_floating_point(input): | |
return torch.tensor(torch.inf, dtype=dtype, device=device) | |
elif torch.is_signed(input) or dtype == torch.uint8: | |
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) | |
elif op_name == "mean": | |
# Strictly speaking, the identity value of the mean operation | |
# is the mean of the input. Since the mean value depends on | |
# the dim argument and it may be a non-scalar tensor, we | |
# consider the identity value of the mean operation ambiguous. | |
# Moreover, the mean value of empty input is undefined. | |
return None | |
elif op_name == "norm": | |
ord = args[0] if args else 2 | |
if ord == float("-inf"): | |
assert torch.is_floating_point(input), input.dtype | |
return torch.tensor(torch.inf, dtype=dtype, device=device) | |
return torch.tensor(0, dtype=dtype, device=device) | |
elif op_name == "median": | |
# We use NaN for now because the implementation is currently using torch.nanmedian | |
# and NaN is the identity for that function since it gets ignored | |
dtype = input.dtype if torch.is_floating_point(input) else torch.float | |
return torch.tensor(torch.nan, dtype=dtype, device=device) | |
elif op_name in {"var", "std"}: | |
return None | |
raise NotImplementedError(f"identity of {op_name} on {dtype} input") | |
def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: | |
"""Return dim argument as a tuple of sorted dim values.""" | |
dims: List[int] = [] | |
if dim == (): | |
# Currently, `dim=()` in reductions operations means "reduce | |
# over all dimensions" while in future, it will read "no | |
# reduce". See https://github.com/pytorch/pytorch/issues/29137 | |
# When gh-29137 is resolved, this if-block must be deleted. | |
dim = None | |
if dim is None: | |
return tuple(range(ndim)) | |
ndim = max(ndim, 1) | |
dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim | |
for d in dim_: | |
if d in dims: | |
raise RuntimeError(f"dim={d} appears multiple times in the list of dims") | |
if d >= ndim or d < -ndim: | |
raise IndexError( | |
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})" | |
) | |
dims.append(d % ndim) | |
return tuple(sorted(dims)) | |
def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple): | |
# Flatted N-D indices to 1-D indices | |
flat_indices = indices.new_zeros(indices.size(1)) | |
for d, sz in enumerate(shape): | |
flat_indices.mul_(sz) | |
flat_indices.add_(indices[d]) | |
return flat_indices | |
def _any(input: Tensor, dim: tuple, keepdim: bool): | |
# Support torch.any with tuple dim argument. | |
# Workaround of https://github.com/pytorch/pytorch/issues/56586 | |
r = input | |
for d in reversed(dim): | |
r = r.any(dim=d, keepdim=keepdim) | |
return r | |
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Args:
      mask: sparse COO tensor; must have the same shape and the same
        number of dense dimensions as ``input`` (asserted below).
        NOTE(review): ``mask.indices()`` is called without coalescing
        the mask here, so the caller presumably passes a coalesced
        mask - confirm.
      input: sparse COO tensor; it is coalesced inside this function.
      fill_value: value used for masked-out elements in the dense part
        of hybrid tensors.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """
    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # Hybrid case: a sparse position counts as masked-in when any
        # element of its dense block is True.
        # NOTE(review): mask.values() has 1 + dense_dim() dimensions, yet the
        # reduced dims are derived from input.sparse_dim(); this looks like it
        # only lines up when sparse_dim() == dense_dim() - confirm.
        mask_values = _any(
            mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # Returns the sorted unique union of i1 and i2 together with the
        # positions (as produced by single-argument torch.where) of the
        # union entries that occur more than once, i.e. entries present in
        # both inputs provided neither input repeats a value itself.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # Entries counted exactly once belong to only one of i1/i2;
        # intersecting those with i1 leaves the entries of i1 not in i2.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        # Materialise an (object, positions) pair as the selected entries.
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    # w holds positions within the sorted unique union; they are used
    # directly as positions into input.indices()/input.values() below,
    # which relies on input being coalesced and on
    # maskin_input_flat_indices being a subset of input_flat_indices.
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        # (w1 is likewise used as positions into mask.values(); presumably
        # valid because the mask indices are sorted and unique - confirm)
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    # Append explicit zeros for masked-in positions the input does not specify.
    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
def _sparse_coo_scatter_reduction_helper( | |
op, | |
mask_input: Tensor, | |
dims: Tuple[int, ...], | |
keepdim: bool, | |
dtype: Optional[DType] = None, | |
) -> Tensor: | |
reduce = op.__name__ | |
valid_reductions = ["sum", "prod", "amax", "amin"] | |
if reduce not in valid_reductions: | |
raise ValueError( | |
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" | |
) | |
output_dtype = dtype | |
values, indices = mask_input._values(), mask_input._indices() | |
input_dims = mask_input.dim() | |
num_sparse_dims = mask_input.sparse_dim() | |
reduced_sparse_dims = [] | |
retained_sparse_dims = [] | |
reduced_dense_dims = [] | |
# promote dtype if specified | |
if values.dtype != output_dtype: | |
values = values.to(output_dtype) | |
if keepdim: | |
output_shape = tuple( | |
1 if i in dims else si for (i, si) in enumerate(mask_input.shape) | |
) | |
else: | |
output_shape = tuple( | |
si for (i, si) in enumerate(mask_input.shape) if i not in dims | |
) | |
for d in dims: | |
if d >= input_dims: | |
continue | |
if d < num_sparse_dims: | |
reduced_sparse_dims.append(d) | |
else: | |
reduced_dense_dims.append(d + 1 - num_sparse_dims) | |
# Reduce dense dimensions | |
if len(reduced_dense_dims) > 0: | |
if reduce == "sum": | |
new_values = values | |
new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim)) | |
else: | |
# FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities | |
return NotImplemented | |
else: | |
new_values = values.clone() | |
# Reduce sparse dimensions | |
if len(reduced_sparse_dims) == num_sparse_dims: | |
if reduce in {"amax", "amin"} and new_values.size(0) == 0: | |
# IndexError: amax(): Expected reduction dim 0 to have non-zero size. | |
# sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not | |
# See https://github.com/pytorch/pytorch/issues/61901 | |
new_values = _reduction_identity(reduce, new_values) | |
else: | |
new_values = op(new_values, dim=0) | |
if keepdim: | |
for _ in range(num_sparse_dims): | |
new_values = new_values.unsqueeze(0) | |
return new_values.to(dtype=output_dtype).to_sparse() | |
else: | |
new_indices = indices.clone() | |
if keepdim: | |
# zero out reduced sparse dimensions if keepdim = True | |
# ensures that the call to torch.unique folds duplicated indices together while preserving the dimension | |
new_indices[reduced_sparse_dims, :] = 0 | |
else: | |
# remove reduced sparse dimensions if keepdim = False | |
if len(reduced_sparse_dims) > 0: | |
retained_sparse_dims = [ | |
i | |
for i in range(num_sparse_dims) | |
if i not in set(reduced_sparse_dims) | |
] | |
new_indices = new_indices.index_select( | |
0, torch.tensor(retained_sparse_dims).to(mask_input.device) | |
) | |
# Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices | |
if new_indices.numel() > 0: | |
# lexsort indices and get index tensor for scatter reduction | |
new_indices, inverse_indices = torch.unique( | |
new_indices, return_inverse=True, dim=1 | |
) | |
out_shape = list(new_values.shape) | |
out_shape[0] = new_indices.shape[1] | |
for _ in range(new_values.ndim - 1): | |
inverse_indices = inverse_indices.unsqueeze(-1) | |
scatter_indices = inverse_indices.expand(new_values.shape) | |
# FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce | |
if output_dtype in {torch.bfloat16, torch.float16}: | |
new_values = new_values.to(torch.float) | |
out = new_values.new_empty(out_shape) | |
new_values = out.scatter_reduce_( | |
0, scatter_indices, new_values, reduce=reduce, include_self=False | |
) | |
new_values = new_values.to(dtype=output_dtype) | |
else: | |
out = new_values.new_empty(out_shape) | |
new_values = out.scatter_reduce_( | |
0, scatter_indices, new_values, reduce=reduce, include_self=False | |
) | |
return torch.sparse_coo_tensor( | |
new_indices, | |
new_values, | |
output_shape, | |
dtype=output_dtype, | |
device=mask_input.device, | |
) | |
def _sparse_csr_segment_reduction_helper( | |
op, | |
mask_input: Tensor, | |
dims: Tuple[int, ...], | |
keepdim: bool, | |
dtype: Optional[DType] = None, | |
) -> Tensor: | |
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True | |
# FIXME: when dense dimensions are implemented for CSR tensors | |
assert ( | |
keepdim | |
), "reduction operations on CSR tensors with keepdim=False is unsupported" | |
reduce = op.__name__ | |
valid_reductions = ["sum", "prod", "mean", "amax", "amin"] | |
if reduce not in valid_reductions: | |
raise ValueError( | |
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" | |
) | |
device = mask_input.device | |
output_dtype = dtype | |
values, crow_indices, col_indices = ( | |
mask_input.values(), | |
mask_input.crow_indices(), | |
mask_input.col_indices(), | |
) | |
# promote dtype if specified | |
if values.dtype != output_dtype: | |
values = values.to(output_dtype) | |
if len(dims) == 0: | |
return mask_input | |
if len(dims) == 1: | |
if dims[0] == 0: | |
new_col_indices, scatter_indices = torch.unique( | |
col_indices, return_inverse=True | |
) | |
new_nnz = new_col_indices.shape[0] | |
new_crow_indices = torch.tensor([0, new_nnz]) | |
new_values = values.new_empty(new_col_indices.shape) | |
new_values.scatter_reduce_( | |
0, scatter_indices, values, reduce, include_self=False | |
) | |
new_shape = [1, mask_input.size(1)] | |
else: | |
assert ( | |
dims[0] == 1 | |
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." | |
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 | |
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 | |
new_crow_indices = torch.cat( | |
( | |
crow_indices.new_zeros(1), | |
torch.cumsum(torch.diff(crow_indices) != 0, 0), | |
), | |
0, | |
) | |
new_nnz = new_crow_indices[-1] | |
new_col_indices = col_indices.new_zeros(new_nnz) | |
new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined] | |
new_shape = [mask_input.size(0), 1] | |
else: | |
assert len(dims) == 2 | |
nnz = min(1, values.numel()) | |
if nnz == 1: | |
op_kwargs = {"keepdim": True, "dtype": output_dtype} | |
# amax and amin do not support dtype kwarg | |
if reduce in ["amax", "amin"]: | |
del op_kwargs["dtype"] | |
new_values = op(values, 0, **op_kwargs) | |
else: | |
new_values = torch.empty(0, dtype=output_dtype) | |
new_col_indices = col_indices.new_zeros(nnz) | |
new_crow_indices = torch.tensor([0, nnz]) | |
new_shape = [1, nnz] | |
return torch.sparse_csr_tensor( | |
new_crow_indices, | |
new_col_indices, | |
new_values, | |
new_shape, | |
dtype=output_dtype, | |
device=device, | |
) | |
def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors.

    Both operands are converted to sparse COO, handed to
    ``_sparse_coo_where``, and the result is converted back to CSR.
    """
    # TODO: implement sparse CSR specific where operator for efficiency
    coo_mask = mask.to_sparse_coo()
    coo_input = input.to_sparse_coo()
    coo_result = _sparse_coo_where(coo_mask, coo_input, fill_value)
    return coo_result.to_sparse_csr()
def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: | |
"""torch.where with sparse inputs support. | |
_where implements the following invariant: | |
_where(mask, input, fill_value).to_dense(fill_value) == | |
torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) | |
where `a == b` means `assertEqual(a, b)`, mask is boolean sparse | |
tensor, and `to_dense(fill_value)` is like `to_dense()` except | |
that the unspecified elements are mapped to `fill_value` rather | |
than to `0`. | |
Returns a sparse tensor with the following features: | |
- all specified elements correspond to masked-in elements that | |
have the values of the input tensor. If there exists a masked-in | |
element (as specified by mask) that is not specified in the | |
input, in the result tensor, the corresponding element has value | |
0. In the dense part of the sparse tensor, the masked-out | |
elements are replaced with fill_value. | |
- all unspecified elements correspond to masked-out elements. | |
""" | |
if mask.layout == torch.strided: | |
return torch.where(mask, input, fill_value) | |
elif mask.layout == torch.sparse_coo: | |
return _sparse_coo_where(mask, input, fill_value) | |
elif mask.layout == torch.sparse_csr: | |
return _sparse_csr_where(mask, input, fill_value) | |
else: | |
raise ValueError( | |
f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}" | |
) | |
def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: | |
"""Return canonical input mask. | |
A canonical input mask is defined as a boolean mask tensor that | |
shape and layout matches with the shape and the layout of the | |
input. | |
The canonical input mask is computed from the :attr:`mask` tensor | |
content to meet the following criteria: | |
1. The shape of the canonical input mask is the same as the shape | |
of :attr:`input` tensor. If the mask tensor has a smaller shape | |
than the shape of the :attr:`input`, broadcasting rules will be | |
applied. Downcasting of mask is not supported. | |
2. The layout of the canonical input mask is the same as the | |
layout of the :attr:`input` tensor. If the mask has different | |
layout, it will be converted to the expected layout. In the | |
case of sparse COO layout, the canonical input mask will be | |
coalesced. | |
3. The dtype of the canonical input mask is torch.bool. If the | |
mask dtype is not bool then it will be converted to bool dtype | |
using `.to(dtype=bool)` method call. | |
4. The elements of the canonical input mask have boolean values | |
copied from the content of the :attr:`mask` tensor (after | |
possible broadcasting and dtype conversion transforms). In | |
general, the sparsity pattern of the sparse canonical input | |
mask need not to be the same as the sparsity pattern of the | |
sparse :attr:`input` tensor. | |
""" | |
if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: | |
raise ValueError( | |
f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}" | |
) | |
mask = kwargs.get("mask") | |
# default mask | |
if mask is None: | |
raise ValueError("_input_mask requires explicit mask") | |
# mask shape must match with input shape | |
if mask.shape != input.shape: | |
if mask.ndim > input.ndim: | |
raise IndexError( | |
"_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)" | |
) | |
if mask.layout == torch.strided: | |
mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool) | |
elif mask.layout == torch.sparse_coo: | |
mask = torch._sparse_broadcast_to(mask, input.shape) | |
else: | |
assert mask.layout == torch.sparse_csr | |
# Broadcasting of CSR tensors is not implemented. Working | |
# around by using COO layout. | |
mask = torch._sparse_broadcast_to( | |
mask.to_sparse(), input.shape | |
).to_sparse_csr() | |
# mask layout must match with input layout | |
if mask.layout != input.layout: | |
if input.layout == torch.strided: | |
mask = mask.to_dense() | |
elif input.layout == torch.sparse_coo: | |
if mask.layout == torch.strided: | |
mask = mask.to_sparse(input.sparse_dim()) | |
else: | |
mask = mask.to_sparse() | |
else: | |
assert input.layout == torch.sparse_csr | |
mask = mask.to_sparse_csr() | |
# sparse mask must be coalesced | |
if mask.layout == torch.sparse_coo: | |
mask = mask.coalesce() | |
# mask is a boolean tensor | |
mask = mask.to(dtype=torch.bool) | |
return mask | |
def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: | |
"""Return output mask of masked operation applied to given arguments.""" | |
if callable(op): | |
is_reduction = op.__name__ in { | |
"sum", | |
"prod", | |
"amax", | |
"amin", | |
"argmax", | |
"argmin", | |
"mean", | |
"median", | |
"norm", | |
"var", | |
"std", | |
"logsumexp", | |
} | |
is_normalization = op.__name__ in { | |
"softmax", | |
"log_softmax", | |
"softmin", | |
"normalize", | |
"cumsum", | |
"cumprod", | |
} | |
if is_reduction: | |
if op.__name__ == "norm": | |
if args: | |
args = args[1:] # lstrip ord argument | |
dim = args[0] if args else kwargs.get("dim") | |
outmask = _input_mask(input, *args, **kwargs) | |
keepdim = kwargs.get("keepdim", False) | |
dim_ = _canonical_dim(dim, input.ndim) | |
return _any(outmask, dim_, bool(keepdim)) | |
elif is_normalization: | |
return _input_mask(input, *args, **kwargs) | |
else: | |
raise ValueError( | |
f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})" | |
) | |
else: | |
raise ValueError( | |
f"_output_mask expected masked operation (got {type(op).__name__} object)" | |
) | |
def _combine_input_and_mask( | |
op, input: Union[MaskedTensor, Tensor], mask, *args | |
) -> Tensor: | |
def helper(input, mask): | |
if mask is None: | |
return input | |
canonical_mask = _input_mask(input, mask=mask) | |
if callable(op): | |
fill_value = _reduction_identity(op.__name__, input, *args) | |
return _where(canonical_mask, input, fill_value) | |
else: | |
raise ValueError( | |
f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)" | |
) | |
class Combine(torch.autograd.Function): | |
def forward(ctx, input, mask): | |
"""Return input with masked-out elements eliminated for the given operations.""" | |
ctx.save_for_backward(mask) | |
if mask is not None: | |
ctx.mark_non_differentiable(mask) | |
return helper(input, mask) | |
def backward(ctx, grad_output): | |
(mask,) = ctx.saved_tensors | |
grad_data = ( | |
grad_output.get_data() if is_masked_tensor(grad_output) else grad_output | |
) | |
result = as_masked_tensor(grad_data, mask) | |
return result, None | |
return ( | |
Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr] | |
if is_masked_tensor(input) | |
else helper(input, mask) | |
) | |
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        promote = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout != torch.sparse_csr:
            dtype = torch.int64 if promote else input.dtype
        elif promote:
            # csr.to(dtype=torch.int64) is not implemented, so the input
            # itself is promoted through COO; dtype stays None here.
            input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
        else:
            dtype = input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(sum, input, mask)

    # Dispatch on the layout of the combined input.
    layout = mask_input.layout
    if layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    if layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    if layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    raise ValueError(
        f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        promote = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout != torch.sparse_csr:
            dtype = torch.int64 if promote else input.dtype
        elif promote:
            # csr.to(dtype=torch.int64) is not implemented, so the input
            # itself is promoted through COO; dtype stays None here.
            input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
        else:
            dtype = input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(prod, input, mask)
    layout = mask_input.layout

    if layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # Reduce one dimension at a time, highest index first.
        result = mask_input.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )

    raise ValueError(
        f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `sum` operation's fill value.
    masked = _combine_input_and_mask(sum, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {masked.layout} tensor)"
        )
    running = torch.cumsum(masked, axis, dtype=out_dtype)
    return running.to(dtype=out_dtype)
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `prod` operation's fill value.
    masked = _combine_input_and_mask(prod, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {masked.layout} tensor)"
        )
    running = torch.cumprod(masked, axis, dtype=out_dtype)
    return running.to(dtype=out_dtype)
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype

    masked = _combine_input_and_mask(amax, input, mask)
    axes = _canonical_dim(dim, masked.ndim)
    layout = masked.layout

    if layout == torch.strided:
        return torch.amax(masked, axes, bool(keepdim)).to(dtype=out_dtype)

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, masked, axes, bool(keepdim), out_dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, masked, axes, bool(keepdim), out_dtype
        )

    raise ValueError(
        f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {masked.layout} tensor)"
    )
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The output dtype defaults to the dtype of the input.
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)

    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # (the message names amin; it previously said "amax" by mistake)
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    masked = _combine_input_and_mask(argmax, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked argmax expects strided tensor (got {masked.layout} tensor)"
        )
    indices = torch.argmax(masked, dim, bool(keepdim))
    return indices.to(dtype=out_dtype)
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    masked = _combine_input_and_mask(argmin, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked argmin expects strided tensor (got {masked.layout} tensor)"
        )
    indices = torch.argmin(masked, dim, bool(keepdim))
    return indices.to(dtype=out_dtype)
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    if input.layout == torch.strided:
        # The mean is the masked sum divided by the number of masked-in
        # elements; the count is the masked sum of a tensor of ones.
        # Passing mask=None to `sum` equals omitting the mask.
        inmask = None if mask is None else _input_mask(input, mask=mask)
        if inmask is None:
            # TODO: compute count analytically
            ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
        else:
            ones = inmask.new_ones(input.shape, dtype=torch.int64)
        count = sum(ones, dim, keepdim=keepdim, mask=inmask)
        total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count

    if input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )

    raise ValueError(
        f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
    )
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # The default output dtype is captured before any float conversion.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]

    # Non-floating-point inputs are computed in float.
    is_float = torch.is_floating_point(input)
    if not is_float:
        input = input.to(dtype=torch.float)

    masked = _combine_input_and_mask(median, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked median expects strided tensor (got {masked.layout} tensor)"
        )

    output = torch.nanmedian(masked, axis, keepdim).values
    if is_float:
        # Floating-point results are returned as computed.
        return output
    if torch.isnan(output).any():
        # A nan result cannot be converted back to a non-float dtype.
        raise ValueError(
            "masked median expects no fully masked out rows if dtype is not floating point"
        )
    return output.to(dtype=out_dtype)
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axes = _canonical_dim(dim, input.ndim)
    masked = _combine_input_and_mask(logsumexp, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {masked.layout} tensor)"
        )
    reduced = torch.logsumexp(masked, axes, keepdim=keepdim)
    return reduced.to(dtype=out_dtype)
# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations | |
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

    Returns logaddexp of all the elements in the :attr:`input` and the
    :attr:`other` tensor. The :attr:`input` elements are masked out
    according to the boolean tensor :attr:`input_mask` and the
    :attr:`other` elements are masked out according to the boolean tensor
    :attr:`other_mask`.

    The shapes of a mask tensor and the tensor to be masked
    don't need to match, but they must be :ref:`broadcastable
    <broadcasting-semantics>` and the dimensionality of the mask
    tensor must not be greater than of the tensor to be masked.

    Args:
        input (Tensor): the input tensor
        other (Tensor): the second input tensor

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type
          of returned tensor. If specified, the output tensor is
          casted to :attr:`dtype` after the operation is
          performed. Default: None.
        input_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`input` tensor elements.
          Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
        other_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`other` tensor elements.
          Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

    Example::

        >>> input = torch.tensor([-100.0, -200, -300])
        >>> input
        tensor([-100., -200., -300.])
        >>> other = torch.tensor([-1.0, -2, -3])
        >>> other
        tensor([-1., -2., -3.])
        >>> mask = torch.tensor([True, False, True])
        >>> mask
        tensor([ True, False,  True])
        >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
        tensor([-1., -inf, -3.])
    """
    out_dtype = input.dtype if dtype is None else dtype

    both_strided = input.layout == torch.strided and other.layout == torch.strided
    if not both_strided:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )

    # Both operands are filled using the `logsumexp` operation's fill value.
    masked_input = _combine_input_and_mask(logsumexp, input, input_mask)
    masked_other = _combine_input_and_mask(logsumexp, other, other_mask)
    return torch.logaddexp(masked_input, masked_other).to(dtype=out_dtype)
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.
{reduction_args}
{reduction_example}"""
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    # `ord` is forwarded because the fill value lookup receives it.
    masked = _combine_input_and_mask(norm, input, mask, ord)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked norm expects strided tensor (got {masked.layout} tensor)"
        )
    axes = _canonical_dim(dim, input.ndim)
    return torch.linalg.vector_norm(masked, ord, axes, bool(keepdim), dtype=out_dtype)
def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    """Shared implementation of the masked ``var`` and ``std`` operations.

    The result is the masked sum of squared deviations from the masked
    mean, divided by ``count - correction`` (clamped at zero), where
    ``count`` is the number of masked-in elements. When ``take_sqrt`` is
    true the square root of that quotient is returned.

    Args:
        input: the input tensor; only the strided layout is supported.
        dim: the dimension(s) to reduce.
        unbiased: if given, selects a correction of 1.0 (true) or 0.0
            (false). Must not be combined with ``correction_opt``.
        correction_opt: if given, the correction, converted with
            ``sym_float``. With neither option the correction is 1.0.
        keepdim: whether the reduced dimensions are kept in the output.
        dtype: the output dtype. Defaults to the input dtype, or to
            float32 when that is neither floating point nor complex.
        mask: the optional mask of valid input elements.
        take_sqrt: whether to return the square root of the variance.

    Raises:
        ValueError: if ``input`` is not a strided tensor.
    """
    assert (
        unbiased is None or correction_opt is None
    ), "Only one of unbiased and correction may be given"
    if correction_opt is not None:
        correction = sym_float(correction_opt)
    elif unbiased is not None:
        correction = 1.0 if unbiased else 0.0
    else:
        correction = 1.0

    def _is_inexact(dt):
        return dt.is_floating_point or dt.is_complex

    if dtype is None:
        dtype = input.dtype if _is_inexact(input.dtype) else torch.float32
    # The squared deviations are summed in a floating/complex dtype even
    # when the requested output dtype is not one.
    compute_dtype = dtype if _is_inexact(dtype) else torch.float32

    if input.layout != torch.strided:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )

    # Passing mask=None to `sum` equals omitting the mask, so the masked
    # and unmasked cases share the calls below.
    inmask = None if mask is None else _input_mask(input, mask=mask)
    if inmask is None:
        # TODO: compute count analytically
        ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
    else:
        ones = inmask.new_ones(input.shape, dtype=torch.int64)
    count = sum(ones, dim, keepdim=True, mask=inmask)
    sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)

    # TODO: replace torch.subtract/divide/square/maximum with
    # masked subtract/divide/square/maximum when these will be
    # available.
    sample_mean = torch.divide(sample_total, count)
    centered = torch.subtract(input, sample_mean)
    total = sum(
        centered * centered.conj(),
        dim,
        keepdim=keepdim,
        dtype=compute_dtype,
        mask=inmask,
    )

    if not keepdim:
        # count was computed with keepdim=True; align it with total.
        count = count.reshape(total.shape)
    if correction != 0:
        real_dtype = (
            corresponding_real_dtype(compute_dtype)
            if compute_dtype.is_complex
            else compute_dtype
        )
        count = count.to(real_dtype)
        count = torch.subtract(count, correction)
        # Keep the divisor from going negative.
        count = torch.maximum(count, count.new_zeros([]))

    output = torch.divide(total, count).to(dtype=dtype)
    if take_sqrt:
        output = torch.sqrt(output)
    return output
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Variance is the shared std/var computation without the final sqrt.
    return _std_var(
        input,
        dim,
        unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # `correction` is annotated like `var` and `_std_var`, which accept
    # int or float corrections; the value is forwarded unchanged.
    # Standard deviation is the shared std/var computation followed by
    # a square root (take_sqrt=True).
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `amax` operation's fill value.
    masked = _combine_input_and_mask(amax, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked softmax expects strided tensor (got {masked.layout} tensor)"
        )
    return torch.nn.functional.softmax(masked, axis, dtype=out_dtype)
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `amax` operation's fill value.
    masked = _combine_input_and_mask(amax, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {masked.layout} tensor)"
        )
    return torch.nn.functional.log_softmax(masked, axis, dtype=out_dtype)
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The output dtype defaults to the dtype of the input.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `amin` operation's fill value.
    masked = _combine_input_and_mask(amin, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked softmin expects strided tensor (got {masked.layout} tensor)"
        )
    return torch.nn.functional.softmin(masked, axis, dtype=out_dtype)
def normalize(
    input: Union[Tensor, MaskedTensor],
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # The dtype handed to `norm` defaults to the dtype of the input.
    norm_dtype = input.dtype if dtype is None else dtype
    # The canonical dim is not used below; the call is kept so that any
    # error it raises for `dim` is preserved.
    _canonical_dim(dim, input.ndim)[0]
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    masked = _combine_input_and_mask(sum, input, mask)
    if masked.layout != torch.strided:
        raise ValueError(
            f"masked normalize expects strided tensor (got {masked.layout} tensor)"
        )
    nrm_ = norm(input, ord, dim, keepdim=True, dtype=norm_dtype, mask=mask)
    # The denominator is bounded below by eps.
    # TODO: replace torch.maximum with masked maximum when available.
    floor = nrm_.new_full([], eps)
    denom = torch.maximum(nrm_, floor)
    # TODO: replace torch.divide with masked divide when available.
    return torch.divide(masked, denom)