import functools
import math
import operator
from typing import *  # noqa: F403

import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention

from .nested_tensor import NestedTensor

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


# Simplifying assumption: we assume that the batch dim is always the left-most
# dim, and the ragged dim is always the second dim.
def _outer_to_inner_dim(ndim, dim):
    assert 0 <= dim < ndim
    return 0 if dim < 2 else dim - 1
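

# Illustrative sketch (comments only, not used by dispatch): for an NT of shape
# (B, j0, D) whose packed values have shape (sum(j0), D), outer dims map to
# inner (values) dims as:
#   _outer_to_inner_dim(3, 0) == 0  # batch dim collapses into the packed dim
#   _outer_to_inner_dim(3, 1) == 0  # ragged dim also maps to the packed dim
#   _outer_to_inner_dim(3, 2) == 1  # trailing dense dims shift left by one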


def _wrap_jagged_dim(
    ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False
):
    from torch._prims_common import canonicalize_dims

    wrapped = canonicalize_dims(ndim, dim)
    if wrapped == 1:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1")
    elif wrapped == 0 and not allow_batch_dim:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
    return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
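

# Illustrative sketch: for a 3D NT (B, j0, D), _wrap_jagged_dim(3, -1, "op")
# canonicalizes -1 -> 2 and returns inner dim 1; dim=1 (the ragged dim) always
# raises, and dim=0 (the batch dim) raises unless allow_batch_dim=True.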


def _wrap_jagged_dims(ndim, dims, op_name):
    # ex: (2, 3, 4) -> (1, 2, 3)
    # ex: (0, 1, 4) -> (0, 3)
    from torch._prims_common import canonicalize_dims

    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
    # This logic needs to be done after we canonicalize dims but before we
    # map to inner dims so we can print a nicer error message.
    zero_in_dims = 0 in wrapped_dims
    one_in_dims = 1 in wrapped_dims
    if zero_in_dims ^ one_in_dims:
        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
        raise RuntimeError(
            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
            " dimension is not supported for NestedTensor"
        )
    return (
        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
        zero_in_dims,
    )


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = sum(x.endswith("?") for x in named_arg_types)
    min_args = len(named_arg_types) - num_optional_args

    # special case: an ellipsis allows any number of unchecked args at the end
    if named_arg_types[-1] == "...":
        named_arg_types = named_arg_types[:-1]
    else:
        if not (min_args <= len(args) <= len(named_arg_types)):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
                f"arguments and at most {len(named_arg_types)} arguments, but got "
                f"{len(args)} arguments"
            )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor)
        and x._lengths is None
        and x._ragged_idx == 1,  # ops with "jt" require a contiguous JT only
        "jt_all": lambda x: isinstance(
            x, NestedTensor
        ),  # ops with "jt_all" can accept all kinds of JT
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns:
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        _check_fn = arg_type_check_fns[normalized_arg_type]

        # bind loop variables as defaults to avoid late-binding surprises
        def check_fn(x, is_optional=is_optional, _check_fn=_check_fn):
            if is_optional:
                return x is None or _check_fn(x)
            else:
                return _check_fn(x)

        if not check_fn(args[i]):
            type_to_desc = {
                "t": "tensor",
                "t?": "optional tensor",
                "jt": "contiguous jagged layout NestedTensor",
                "jt_all": "jagged layout NestedTensor",
                "any": "<any type>",
            }
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
                f"{type_to_desc[arg_type]}"
            )
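

# Illustrative usage of check_schema, assuming a hypothetical op `my_func`, a
# contiguous jagged NT `nt`, and a dense tensor `t`:
#   check_schema("self: jt, other: t?, ...", my_func, nt, t)   # passes
#   check_schema("self: jt, other: t?, ...", my_func, t, nt)   # raises ValueError
# The failure names the offending argument and its expected type description.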


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Calling into .shape here
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


# returns True if the raggedness-relevant portions of the NT shape
# match those of the specified size
def raggedness_matches(nt, size):
    end = nt._ragged_idx + 1
    nt_ragged = nt._size[:end]
    size_ragged = size[:end]
    return len(nt_ragged) == len(size_ragged) and (
        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
    )


def squeeze_leading_ones(t):
    # Note: [ Squeezing leading ones ]
    #
    # Squeeze leading ones from t.
    #
    # We want:
    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
    #
    # 1) Squeeze extra ones and grab values from NT
    #   (1, 1, ?, ?) -> (?, ?)   and   (B, j0, ?, ?) -> (sum(*), ?, ?)
    # 2) Do dense broadcasting:
    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
    # 3) Construct nested tensor
    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
    #
    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
    # at step (4) and we would need to update this function to record how
    # many ones we unsqueezed.
    while t.shape[0] == 1:
        t = t.squeeze(0)
    return t
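

# Example: squeeze_leading_ones(torch.ones(1, 1, 3, 4)) has shape (3, 4), which
# can then broadcast against packed NT values of shape (sum(j0), 3, 4).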


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
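

# Illustrative sketch of a registration (a hypothetical entry; the real table
# entries are registered throughout this file):
#
#   @register_jagged_func(torch.ops.aten.clone.default, "self: jt_all, ...")
#   def clone_default(func, *args, **kwargs):
#       return jagged_unary_pointwise(func, *args, **kwargs)
#
# The generated wrapper first validates the args against the schema string,
# then calls the implementation with the resolved aten op prepended.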


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
    if dispatch_func is not None:
        return dispatch_func

    # Handle pointwise fallbacks
    if torch.Tag.pointwise in func.tags:
        # Assume the only tensor args are the "unary/binary" operands
        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
        if num_tensor_args == 1:
            check_schema("self: jt_all, ...", func, *args, **kwargs)
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
        "_metadata_cache": arg._metadata_cache,
        "_ragged_idx": arg._ragged_idx,
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(
        func(args[0]._values, *args[1:], **kwargs), **extract_kwargs(args[0])
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )

    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))

    # exactly one of a or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    # === Handle broadcasting across the batch / ragged dims ===

    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Harder case: do manual broadcasting over unbound components
    # when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should be
        # (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        # need to use offsets to broadcast across the ragged dim properly
        # NB: inefficient fallback here; Triton codegen can help this
        # TODO: Make this work with autograd
        outputs = []
        for a_comp, b_comp in zip(a.unbind(), b.unbind()):
            outputs.append(func(a_comp, b_comp, *args[2:], **kwargs))
        new_values = torch.cat(outputs, dim=0)
        return NestedTensor(new_values, **extracted_kwargs)

    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks
    # the invariant that the ragged dim is wrt the left-most batch dim
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
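

# Illustrative shape behavior of the binary pointwise fallback (shapes only,
# assuming the standard jagged layout with ragged dim 1):
#   (B, j0, D) + (B, j0, D) -> (B, j0, D)    # matching raggedness: op on values
#   (B, j0, D) + (D,)       -> (B, j0, D)    # dense broadcast on packed values
#   (B, j0, D) + (B, 1, D)  -> (B, j0, D)    # manual loop over unbound components
#   (B, j0, D) + (B, k0, D) -> RuntimeError  # mismatched raggedness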


def jagged_torch_function(func, *args, **kwargs):
    # SDPA has special kernels that handle nested tensors.
    # Dispatch to the correct implementation here.
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    # Handle flatten() here because it's CompositeImplicit.
    if func.__name__ == "flatten":

        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # NB: stay in outer dim space because we're going to redispatch on a NT input
        start_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
        )

        if start_dim == end_dim:
            return inp

        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
        return inp.reshape(*new_shape)

    raise NotImplementedError(func)
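

# Example: for an NT of shape (B, j0, D0, D1), flatten(start_dim=2) computes the
# new shape (B, j0, D0 * D1) and redispatches via reshape; start/end dims that
# touch the batch or ragged dims are rejected by _wrap_jagged_dim above.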


def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func == torch.ops.aten.sym_numel.default:
        if args[0]._lengths is not None:
            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return args[0]._values.storage_offset()


def prim_layout_default(func, *args, **kwargs):
    return torch.jagged


def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )


def is_contiguous_general(func, *args, **kwargs):
    from torch._prims_common import is_contiguous_for_memory_format

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # If created from narrow(), check for lengths
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )
    if new_kwargs["memory_format"] == torch.preserve_format:
        return True
    return is_contiguous_for_memory_format(inp._values, **new_kwargs)


register_jagged_func(
    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)
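

# Example: an NT produced via narrow() carries per-component lengths and is
# always reported non-contiguous; for an NT without holes (lengths is None),
# contiguity is that of the packed values under the requested memory format.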


def linear_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def linear_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    ds = NestedTensor(
        torch.mm(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.mm(grad_output._values.T, inp._values)
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)
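

# Shape sketch for the backward above, assuming grad_output values of shape
# (sum(j0), out_features) and weight of shape (out_features, in_features):
#   ds = grad_output.values() @ weight    # (sum(j0), in_features), rewrapped as NT
#   dw = grad_output.values().T @ values  # (out_features, in_features)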


def to_copy_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    # NB: Purposefully keep offsets on the old device.
    return NestedTensor(new_values, **extract_kwargs(inp))


register_jagged_func(
    [
        torch.ops.aten.empty_like.default,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
        torch.ops.aten.detach.default,
    ],
    "self: jt_all",
)(jagged_unary_pointwise)


register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
)(jagged_unary_pointwise)


def native_dropout_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    out1, out2 = func(inp._values, **new_kwargs)
    return (
        NestedTensor(out1, **extract_kwargs(inp)),
        NestedTensor(out2, **extract_kwargs(inp)),
    )


def native_dropout_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_output = new_kwargs.pop("grad_output")
    mask = new_kwargs.pop("mask")

    return NestedTensor(
        func(grad_output._values, mask._values, **new_kwargs),
        **extract_kwargs(grad_output),
    )


def prod_dim_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # TODO: Figure out how to handle this better
    # keepdim is required to keep the result in jagged format
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def split_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split")

    return tuple(
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    )


def split_with_sizes_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "split_with_sizes"
    )

    return [
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    ]


def chunk_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True
    )

    if new_kwargs["dim"] == 0:
        chunks = new_kwargs["chunks"]

        # get _offsets of the chunks
        lengths = inp._offsets.diff()
        chunked_lengths = lengths.chunk(chunks)
        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]
        nested_kwargs = [
            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
            for per_offsets in chunked_offsets
        ]

        # get _values of the chunks
        split_sizes = [x.sum().item() for x in chunked_lengths]
        chunk_values = inp._values.split(split_sizes)

        # NB: iterate over the number of chunks actually produced; chunk() may
        # return fewer than the requested number of chunks
        return [
            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
            for i in range(len(chunk_values))
        ]
    else:
        return [
            NestedTensor(values=x, **extract_kwargs(inp))
            for x in func(inp._values, **new_kwargs)
        ]


def unbind_int(func, *args, **kwargs):
    # Note that this specializes on the length of the offsets
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

    inp = new_kwargs.pop("input")
    values = inp.values()
    offsets = inp.offsets()
    lengths = inp.lengths()

    if inp._ragged_idx != 1:
        raise RuntimeError(
            "unbind(): only supported for NestedTensor when jagged dimension is 1"
        )

    if lengths is None:
        return torch.split(values, offsets.diff().tolist())
    return [
        values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0])
    ]
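

# Example: values of shape (8, D) with offsets [0, 3, 5, 8] unbind into dense
# components of shapes (3, D), (2, D), (3, D); when lengths is present (NT with
# holes), each component is values[offsets[i] : offsets[i] + lengths[i]].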


def squeeze_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    values = inp._values

    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


def unsqueeze_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    values = inp._values

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


def cat_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    tensors = new_kwargs.pop("tensors")

    # Convert any non-nested to nested
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


def matmul_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if inp.is_nested and not other.is_nested:
        return NestedTensor(
            func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
        )
    elif inp.is_nested and other.is_nested:
        # BMM with equivalent ragged dims between the two inputs
        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))

    raise RuntimeError(
        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
    )


def expand_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))

    if not raggedness_matches(inp, size):
        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")

    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))


def expand_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


def where_self(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition._size == other._size == inp._size

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


def _pin_memory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def is_pinned_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)


def is_same_size_default(func, *args, **kwargs):
    return args[0]._size == args[1]._size


def sum_dim_IntList(func, *args, **kwargs):
    # sum_dim_IntList can produce a NT or a T depending on whether the ragged
    # dims are reduced away.
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    assert inp._ragged_idx == 1
    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
        inp.dim(), new_kwargs["dim"], "sum"
    )

    if not ragged_reduced_away:
        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
    else:
        # Don't wrap because we reduced away the raggedness
        out = func(inp._values, **new_kwargs)
        if new_kwargs["keepdim"]:
            out = out.unsqueeze(0)
        return out
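

# Example: for an NT of shape (B, j0, D), sum over dim=(0, 1) reduces away the
# raggedness and returns a dense (D,) tensor ((1, 1, D) with keepdim=True),
# while sum over dim=2 stays jagged and returns an NT of shape (B, j0).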


def transpose_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    from torch._prims_common import canonicalize_dims

    inp = new_kwargs.pop("input")
    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))

    if inp._lengths is not None:
        raise ValueError(
            "transpose(): not supported on jagged layout nested tensor with holes"
        )

    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
    # instead of 1, although the internal Flash and memory-efficient implementations
    # will use the inputs with raggedness in dim 1.
    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
        if dim0 == 0 or dim1 == 0:
            raise ValueError(
                "Transpose is not supported on the batch dimension for jagged NT"
            )
        if dim0 == inp._ragged_idx:
            to_dim = dim1
        else:
            to_dim = dim0
        inp_kwargs = extract_kwargs(inp)
        inp_kwargs["_ragged_idx"] = to_dim
        return NestedTensor(
            inp.values().transpose(
                _outer_to_inner_dim(len(inp._size), dim0),
                _outer_to_inner_dim(len(inp._size), dim1),
            ),
            **inp_kwargs,
        )

    new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
    new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")
    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
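

# Example: for an NT of shape (B, j0, H, D) with _ragged_idx == 1, transpose(1, 2)
# advertises shape (B, H, j0, D) with _ragged_idx == 2 while transposing inner
# values dims 0 and 1; this is the form the SDPA entry point expects.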


def view_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    size = new_kwargs.pop("size")

    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
        raise RuntimeError(
            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
            f"inp._size is ({inp._size}) and size is ({size})."
        )

    # Ensure the specified size still includes the batch and ragged dims
    if len(size) < 3 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")

    # outer size: the size of the NT, e.g. [3, j0, 10]
    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
    # this function gets inner_size[inner_idx] for a given inner_idx.
    #
    # example: for outer size [a, b, c, j0, d, e, f]
    #          assume that j0 is ragged, the others are concrete integers,
    #          and ragged_idx=3
    # inner size will be [b, c, inp._values.size(ragged_idx - 1), d, e, f]
    # therefore:
    #     inner_size[0] = outer_size[1]
    #     inner_size[1] = outer_size[2]
    #     inner_size[2] = inp._values.size(ragged_idx - 1)
    #     inner_size[3] = outer_size[4]
    #     inner_size[4] = outer_size[5]
    def get_inner_size(inner_idx):
        nonlocal inp, size
        if inner_idx == inp._ragged_idx - 1:
            return inp._values.size(inner_idx)
        else:
            return size[inner_idx + 1]

    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]

    return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
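

# Example: viewing an NT of shape (B, j0, 10) as (B, j0, 2, 5) computes the
# inner size [values.size(0), 2, 5] and views the packed values to that shape;
# the batch and ragged dims of the requested size must match the input's
# (see raggedness_matches above).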


def native_layer_norm_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    normalized_shape = new_kwargs["normalized_shape"]

    # Ensure we're not trying to normalize over the ragged dim
    if inp.dim() < 3 or (inp.dim() - len(normalized_shape)) < 2:
        raise RuntimeError(
            "layer_norm(): normalizing over ragged dim not supported for nested tensors"
        )

    output, mean, std = func(inp._values, **new_kwargs)
    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)


def native_layer_norm_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_out")
    inp = new_kwargs.pop("input")
    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
    if d_input is None:
        return (None, d_gamma, d_beta)

    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)


def select_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "select")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def slice_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def convolution_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def mean_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # NB: mean expects dim as a single item list of ints for some reason
    new_kwargs["dim"] = [_wrap_jagged_dim(inp.dim(), new_kwargs["dim"][0], "mean")]

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


def stack_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    tensors = new_kwargs.pop("tensors")
    for t in tensors:
        if not isinstance(t, NestedTensor):
            raise RuntimeError("stack(): expected all inputs to be nested tensors")
        if t.dim() != tensors[0].dim():
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same dim"
            )
        if not raggedness_matches(t, tensors[0].shape):
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same nested structure"
            )

    new_kwargs["dim"] = _wrap_jagged_dim(
        tensors[0].dim() + 1, new_kwargs["dim"], "stack"
    )

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


def embedding_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    indices = new_kwargs.pop("indices")
    weight = new_kwargs.pop("weight")

    return NestedTensor(
        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
    )


def values_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # TODO: Handle inference mode properly.
    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
    return inp._values.detach()


def _nested_view_from_jagged_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    values, offsets, lengths = (
        new_kwargs["input"],
        new_kwargs["offsets"],
        new_kwargs["lengths"],
    )
    ragged_idx = new_kwargs["ragged_idx"]

    return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx)


def _nested_get_offsets(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    return inp._offsets


def _nested_get_lengths(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    return inp._lengths


def _nested_get_ragged_idx(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    return inp._ragged_idx


# Make the dummy available on the C++ side.
def _nested_get_jagged_dummy(func, *args, **kwargs):
    from torch.nested._internal.nested_tensor import _nt_view_dummy

    return _nt_view_dummy


with torch.library._scoped_library("aten", "IMPL") as aten:
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")