Spaces:
Running
Running
File size: 11,491 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
from typing import List, Optional, Union, Sequence
import torch
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
from torch.types import _device as Device, _dtype as DType
# Public names of this module: the ones exported by a star-import and
# attached to the torch.nested Python namespace (see the note before
# ``to_padded_tensor``). Kept as a literal list so static tooling can read it.
__all__ = [
    "to_padded_tensor",
    "as_nested_tensor",
    "nested_tensor",
    "narrow",
]
# Nested Tensor constructor functions
def as_nested_tensor(
tensor_list: Sequence[Tensor],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
layout=None
) -> Tensor:
r"""
Constructs a nested tensor preserving autograd history from :attr:`tensor_list` a list of tensors.
.. note::
Tensors within the list are always copied by this function due to current nested tensor semantics.
Args:
tensor_list (List[Tensor]): a list of tensors with the same ndim
Keyword arguments:
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
Default: if None, same :class:`torch.device` as leftmost tensor in the list
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b])
>>> nt.is_leaf
False
>>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
>>> nt.backward(fake_grad)
>>> a.grad
tensor([1., 1., 1.])
>>> b.grad
tensor([0., 0., 0., 0., 0.])
"""
if not isinstance(tensor_list, list) or any(
not isinstance(t, Tensor) for t in tensor_list
):
raise TypeError(
"as_nested_tensor(): Expected first argument to be a list of tensors "
)
if layout is None:
layout = torch.strided
if layout == torch.strided:
return torch._nested_tensor_from_tensor_list(tensor_list, dtype, None, device, None)
elif layout == torch.jagged:
from torch.nested._internal.nested_tensor import jagged_from_list
nt, _ = jagged_from_list(tensor_list, offsets=None, device=device, dtype=dtype)
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
# Note: This not only adds doc strings for the nested ops, but
# also connects the torch.nested Python namespace to the torch._C._nested builtins.
# ``_add_docstr`` attaches the text below to the builtin and returns it, so
# ``to_padded_tensor`` is the builtin itself with documentation attached.

to_padded_tensor = _add_docstr(
    _nested.nested_to_padded_tensor,
    r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor

Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.

.. warning::

    :func:`to_padded_tensor` always copies the underlying data,
    since the nested and the non-nested tensors differ in memory layout.

Args:
    input (Tensor): the nested tensor to pad.
    padding (float): The padding value for the trailing entries.
    output_size (Tuple[int], optional): The size of the output tensor.
        If given, it must be large enough to contain all nested data;
        else, will infer by taking the max size of each nested sub-tensor along each dimension.

Keyword args:
    out (Tensor, optional): the output tensor.

Example::

    >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
    nested_tensor([
      tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
              [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
      tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
              [ 0.2773,  0.8793, -0.5183, -0.6447],
              [ 1.8009,  1.8468, -0.9832, -1.5272]])
    ])
    >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
    >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
    >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
    RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.

""",
)
def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor:
r"""
Constructs a nested tensor with no autograd history (also known as a “leaf tensor”, see
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list` a list of tensors.
Args:
tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor,
where each element of the list has the same dimensionality.
Keyword arguments:
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
Default: if None, same :class:`torch.device` as leftmost tensor in the list
requires_grad (bool, optional): If autograd should record operations on the
returned nested tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned nested tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
Example::
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
>>> nt.is_leaf
True
"""
if layout is None:
layout = torch.strided
if layout == torch.strided:
return _nested.nested_tensor(
tensor_list,
dtype=dtype,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory)
elif layout == torch.jagged:
# Need to wrap lists of scalars as tensors
list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list]
from torch.nested._internal.nested_tensor import jagged_from_list
with torch.no_grad():
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
nt.requires_grad_(requires_grad)
if pin_memory:
nt = nt.pin_memory() # type: ignore[assignment]
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
shows only the elements in the interval `[start, start+length)`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
can also be tensors of shape `tensor.shape[0]`.
There's some differences depending on the layout you use for the nested tensor. If using strided layout,
torch.narrow will do a copy of the narrowed data into a contiguous NT with strided layout, while
jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular
representation is really useful for representing kv-caches in Transformer models, as specialized
SDPA kernels can deal with format easily, resulting in performance improvements.
Args:
tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
for the nested tensor if using the jagged layout or will be copied for the strided layout.
dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
jagged layout, while strided supports all dim
start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
Keyword arguments:
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
>>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
>>> narrow_base = torch.randn(5, 10, 20)
>>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
>>> nt_narrowed.is_contiguous()
False
"""
if not isinstance(start, (int, SymInt, Tensor)):
raise RuntimeError("start must be an integer or a tensor")
if not isinstance(length, (int, SymInt, Tensor)):
raise RuntimeError("length must be an integer or a tensor")
if layout == torch.strided:
if isinstance(start, Tensor) or isinstance(length, Tensor):
raise RuntimeError("start and length must be integers for the strided layout NT impl")
# TODO: switch to as_nested_tensor(tensor) when it is available
nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
elif layout == torch.jagged:
if dim != 1:
raise RuntimeError("jagged layout only supports dim=1")
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
if isinstance(start, (int, SymInt)):
start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
if isinstance(length, (int, SymInt)):
length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
else:
raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")
return nt
|