import copyreg | |
import enum | |
import functools | |
import warnings | |
from collections import OrderedDict | |
from copy import deepcopy | |
from numbers import Number | |
from typing import Any, Dict, Optional, Tuple, Union | |
import torch | |
import torch._C as _C | |
import torch.utils.hooks as hooks | |
from torch._namedtensor_internals import ( | |
check_serializing_named_tensor, | |
is_ellipsis, | |
resolve_ellipsis, | |
single_ellipsis_index, | |
unzip_namedshape, | |
update_names, | |
) | |
from torch.overrides import ( | |
get_default_nowrap_functions, | |
handle_torch_function, | |
has_torch_function, | |
has_torch_function_unary, | |
has_torch_function_variadic, | |
) | |
from torch.utils.dlpack import DLDeviceType | |
def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): | |
assigned = functools.WRAPPER_ASSIGNMENTS | |
def wrapped(*args, **kwargs): | |
try: | |
# See https://github.com/pytorch/pytorch/issues/75462 | |
if has_torch_function(args): | |
return handle_torch_function(wrapped, args, *args, **kwargs) | |
return f(*args, **kwargs) | |
except TypeError: | |
return NotImplemented | |
return wrapped | |
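# For example, ``Tensor.__pow__`` below is wrapped with this decorator: if ``t ** other``
# raises a TypeError for an unsupported ``other``, returning NotImplemented lets Python
# fall back to ``other.__rpow__(t)`` instead of raising immediately.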
# Should not be used; this is kept only for BC when loading old serialized Tensor subclasses
def _rebuild_from_type(func, type, args, dict): | |
if type is Tensor: | |
return func(*args) | |
ret = func(*args).as_subclass(type) | |
ret.__dict__ = dict | |
return ret | |
def _rebuild_from_type_v2(func, new_type, args, state): | |
ret = func(*args) | |
if type(ret) is not new_type: | |
ret = ret.as_subclass(new_type) | |
# Tensor does define __setstate__ even though it doesn't define | |
# __getstate__. So only use __setstate__ if it is NOT the one defined | |
# on Tensor | |
if ( | |
getattr(ret.__class__, "__setstate__", Tensor.__setstate__) | |
is not Tensor.__setstate__ | |
): | |
ret.__setstate__(state) | |
else: | |
ret = torch._utils._set_obj_state(ret, state) | |
return ret | |
# NB: If you subclass Tensor, and want to share the subclassed class | |
# across processes, you must also update torch/multiprocessing/reductions.py | |
# to define a ForkingPickler serialization mode for the class. | |
# | |
# NB: If you add a new method to Tensor, you must update | |
# torch/_C/__init__.pyi.in to add a type annotation for your method; | |
# otherwise, it will not show up in autocomplete. | |
class Tensor(torch._C.TensorBase): | |
def __deepcopy__(self, memo): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) | |
if not self.is_leaf: | |
raise RuntimeError( | |
"Only Tensors created explicitly by the user " | |
"(graph leaves) support the deepcopy protocol at the moment. " | |
"If you were attempting to deepcopy a module, this may be because " | |
"of a torch.nn.utils.weight_norm usage, " | |
"see https://github.com/pytorch/pytorch/pull/103001" | |
) | |
if id(self) in memo: | |
return memo[id(self)] | |
with torch.no_grad(): | |
# TODO: skipping storage copy is wrong for meta, as meta | |
# does accurate alias tracking; however, the code below | |
# doesn't work because of | |
# https://github.com/pytorch/pytorch/issues/47442 | |
# Update the test in test_serialization if you remove 'meta' from here | |
if ( | |
self.is_sparse | |
or self.device.type | |
in ["lazy", "xla", "mtia", "mps", "ort", "meta", "ipu"] | |
or ( | |
not torch._C._has_storage(self) | |
and self.device.type == torch._C._get_privateuse1_backend_name() | |
) | |
or (type(self) is not Tensor and self.data_ptr() == 0) | |
): | |
new_tensor = self.clone() | |
if type(new_tensor) is not type(self): | |
raise RuntimeError( | |
"The default implementation of __deepcopy__() for wrapper subclasses " | |
"only works for subclass types that implement clone() and for which " | |
"cloning returns another instance of the same subclass. You should either " | |
"properly implement clone() for your subclass or override __deepcopy__() " | |
"if it is intended behavior for clone() to return an instance of a " | |
"different type." | |
) | |
else: | |
new_storage = self._typed_storage()._deepcopy(memo) | |
if self.is_quantized: | |
# quantizer_params can be different type based on torch attribute | |
quantizer_params: Union[ | |
Tuple[torch.qscheme, float, int], | |
Tuple[torch.qscheme, Tensor, Tensor, int], | |
] | |
if self.qscheme() == torch.per_tensor_affine: | |
quantizer_params = ( | |
self.qscheme(), | |
self.q_scale(), | |
self.q_zero_point(), | |
) | |
elif self.qscheme() in ( | |
torch.per_channel_affine, | |
torch.per_channel_affine_float_qparams, | |
): | |
quantizer_params = ( | |
self.qscheme(), | |
self.q_per_channel_scales(), | |
self.q_per_channel_zero_points(), | |
self.q_per_channel_axis(), | |
) | |
else: | |
raise RuntimeError( | |
f"Unsupported qscheme {self.qscheme()} in deepcopy" | |
) | |
# TODO: Once we decide to break serialization FC, no longer | |
# need to wrap with TypedStorage | |
new_tensor = torch._utils._rebuild_qtensor( | |
torch.storage.TypedStorage( | |
wrap_storage=new_storage._untyped_storage, | |
dtype=self.dtype, | |
_internal=True, | |
), | |
self.storage_offset(), | |
self.size(), | |
self.stride(), | |
quantizer_params, | |
self.requires_grad, | |
self._backward_hooks, | |
) | |
if type(new_tensor) is not type(self): | |
raise RuntimeError( | |
"The default implementation of __deepcopy__() for quantized tensors " | |
"expects the tensor returned by torch._utils._rebuild_qtensor() to " | |
"match the type of the instance being copied. If you encounter this, " | |
"please open an issue on PyTorch's GitHub." | |
) | |
else: | |
new_tensor = self.new_empty([]) | |
if type(new_tensor) is not type(self): | |
raise RuntimeError( | |
"The default implementation of __deepcopy__() for non-wrapper subclasses " | |
"only works for subclass types that implement new_empty() and for which " | |
"that function returns another instance of the same subclass. You should " | |
"either properly implement new_empty() for your subclass or override " | |
"__deepcopy__() if it is intended behavior for new_empty() to return " | |
"an instance of a different type." | |
) | |
new_tensor.set_( | |
new_storage, self.storage_offset(), self.size(), self.stride() | |
) | |
if self.is_conj(): | |
new_tensor = new_tensor.conj_physical() | |
if self.is_neg(): | |
new_tensor = new_tensor.neg() | |
if self.requires_grad: | |
new_tensor.requires_grad_() | |
if self.grad is not None: | |
new_tensor.grad = self.grad.__deepcopy__(memo) | |
if type(self) is not Tensor: | |
if type(new_tensor) is not type(self): | |
raise RuntimeError( | |
"Type of deepcopy result does not match the type of the source tensor. " | |
"If you encounter this, please open an issue on PyTorch's GitHub." | |
) | |
# Plain Tensors don't have slots | |
slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] | |
for slot in slots_to_save: | |
if hasattr(self, slot): | |
setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo)) | |
new_tensor.__dict__ = deepcopy(self.__dict__, memo) | |
memo[id(self)] = new_tensor | |
return new_tensor | |
def __reduce_ex__(self, proto): | |
state = torch._utils._get_obj_state(self) | |
if type(self) is Tensor and not state: | |
# Fast path for regular tensor without Python state. | |
return self._reduce_ex_internal(proto) | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) | |
func, args = self._reduce_ex_internal(proto) | |
return (_rebuild_from_type_v2, (func, type(self), args, state)) | |
def storage(self): | |
r""" | |
storage() -> torch.TypedStorage | |
Returns the underlying :class:`TypedStorage`. | |
.. warning:: | |
:class:`TypedStorage` is deprecated. It will be removed in the future, and | |
:class:`UntypedStorage` will be the only storage class. To access the | |
:class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`. | |
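        Example (illustrative; :meth:`Tensor.untyped_storage` is the recommended replacement)::
            >>> t = torch.arange(4, dtype=torch.float32)
            >>> t.untyped_storage().nbytes()  # 4 elements * 4 bytes each
            16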
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.storage, (self,), self) | |
torch.storage._warn_typed_storage_removal(stacklevel=2) | |
return self._typed_storage() | |
# For internal use only, to avoid raising deprecation warning | |
def _typed_storage(self): | |
untyped_storage = self.untyped_storage() | |
return torch.TypedStorage( | |
wrap_storage=untyped_storage, dtype=self.dtype, _internal=True | |
) | |
def _reduce_ex_internal(self, proto): | |
check_serializing_named_tensor(self) | |
# See Note [Don't serialize hooks] | |
torch.utils.hooks.warn_if_has_hooks(self) | |
backward_hooks: Dict[Any, Any] = OrderedDict() | |
# Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, ORT Tensors. | |
# We considered a few options: | |
# 1. CPU tensor can't be used here. | |
# Otherwise in torch.load CPU storage is reconstructed with randomly | |
# initialized data, moved onto backend device, and then storage is updated | |
# to the serialized content. This works perfectly for CPU/CUDA but not these backends; | |
        #    their tensors are disconnected from storage, so they don't get the update.
        # 2. A Python list is not a good fit for performance reasons.
        #    `tolist()` converts every single element in the tensor into a Python object
        #    and serializes them one by one.
if self.device.type in ["xla", "mtia", "ort"] or ( | |
not torch._C._has_storage(self) | |
and self.device.type == torch._C._get_privateuse1_backend_name() | |
): | |
            # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
            # support BFloat16. The rebuild from numpy takes in the original self.dtype, so
            # the BFloat16 tensor is reconstructed correctly from the numpy array.
numpy_tensor = ( | |
self.cpu().numpy() | |
if self.dtype != torch.bfloat16 | |
else self.cpu().to(torch.float32).numpy() | |
) | |
return ( | |
torch._utils._rebuild_device_tensor_from_numpy, | |
(numpy_tensor, self.dtype, str(self.device), self.requires_grad), | |
) | |
if self.device.type == "meta": | |
# NB: This implementation BREAKS storage sharing. Current | |
# hypothesis is that no one cares for meta tensors. | |
arg_meta = ( | |
self.dtype, | |
tuple(self.size()), | |
self.stride(), | |
self.requires_grad, | |
) | |
return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) | |
if self.is_quantized: | |
# quantizer_params can be different type based on torch attribute | |
quantizer_params: Union[ | |
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] | |
] | |
if self.qscheme() == torch.per_tensor_affine: | |
quantizer_params = ( | |
torch.per_tensor_affine, | |
self.q_scale(), | |
self.q_zero_point(), | |
) | |
elif self.qscheme() in ( | |
torch.per_channel_affine, | |
torch.per_channel_affine_float_qparams, | |
): | |
# convert scales and zero points to tuple to avoid recursive calls | |
# when/if we get multi-axis quantized tensors in the future, the shape | |
# is recoverable from the main tensor shape | |
quantizer_params = ( | |
torch.per_channel_affine, | |
self.q_per_channel_scales(), | |
self.q_per_channel_zero_points(), | |
self.q_per_channel_axis(), | |
) | |
else: | |
raise RuntimeError( | |
f"Serialization is not supported for tensors of type {self.qscheme()}" | |
) | |
# TODO: Once we decide to break serialization FC, no longer | |
# need to wrap with TypedStorage | |
args_qtensor = ( | |
torch.storage.TypedStorage( | |
wrap_storage=self._typed_storage()._untyped_storage, | |
dtype=self.dtype, | |
_internal=True, | |
), | |
self.storage_offset(), | |
tuple(self.size()), | |
self.stride(), | |
quantizer_params, | |
self.requires_grad, | |
backward_hooks, | |
) | |
return (torch._utils._rebuild_qtensor, args_qtensor) | |
elif self.is_sparse: | |
if self.layout == torch.sparse_coo: | |
args_sparse = ( | |
self.layout, | |
(self._indices(), self._values(), self.size(), self.is_coalesced()), | |
) | |
else: | |
raise NotImplementedError( | |
f"sparse tensor __reduce_ex__ for layout `{self.layout}`" | |
) | |
return (torch._utils._rebuild_sparse_tensor, args_sparse) | |
elif self.layout in { | |
torch.sparse_csr, | |
torch.sparse_csc, | |
torch.sparse_bsr, | |
torch.sparse_bsc, | |
}: | |
if self.layout in {torch.sparse_csr, torch.sparse_bsr}: | |
compressed_indices, plain_indices = ( | |
self.crow_indices(), | |
self.col_indices(), | |
) | |
else: | |
compressed_indices, plain_indices = ( | |
self.ccol_indices(), | |
self.row_indices(), | |
) | |
args_sparse_compressed = ( | |
self.layout, | |
( | |
compressed_indices, | |
plain_indices, | |
self.values(), | |
self.size(), | |
), | |
) | |
return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) | |
elif self.is_nested: | |
args_nested = ( | |
# NB: values() currently returns the storage as a buffer in an unsafe way. | |
# Ideally, we'd use a private API for this instead. TODO: Switch to this if | |
# we ever get around to adding it. | |
self.values(), | |
self._nested_tensor_size(), | |
self._nested_tensor_strides(), | |
self._nested_tensor_storage_offsets(), | |
) | |
return (torch._utils._rebuild_nested_tensor, args_nested) | |
elif ( | |
self.data_ptr() == 0 | |
and type(self) is not torch.Tensor | |
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ | |
): | |
arg_wrapper_subclass = ( | |
type(self), | |
self.dtype, | |
tuple(self.size()), | |
self.stride(), | |
self.storage_offset(), | |
self.layout, | |
self.device, | |
self.requires_grad, | |
) | |
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) | |
else: | |
v3_dtypes = [ | |
torch.float8_e5m2, | |
torch.float8_e4m3fn, | |
torch.float8_e5m2fnuz, | |
torch.float8_e4m3fnuz, | |
torch.bits8, | |
torch.bits16, | |
torch.bits1x8, | |
torch.bits2x4, | |
torch.bits4x2, | |
torch.complex32, | |
] | |
if self.dtype in v3_dtypes: | |
rebuild_func = torch._utils._rebuild_tensor_v3 | |
storage = self.untyped_storage() | |
else: | |
# TODO: Once we decide to break serialization FC, no longer | |
# need to wrap with TypedStorage | |
rebuild_func = torch._utils._rebuild_tensor_v2 # type: ignore[assignment] | |
storage = torch.storage.TypedStorage( | |
wrap_storage=self._typed_storage()._untyped_storage, | |
dtype=self.dtype, | |
_internal=True, | |
) # type: ignore[assignment] | |
args = ( | |
storage, | |
self.storage_offset(), | |
tuple(self.size()), | |
self.stride(), | |
self.requires_grad, | |
backward_hooks, | |
) # previously was self._backward_hooks | |
if isinstance(storage, torch.storage.UntypedStorage): | |
args = args + (self.dtype,) # type: ignore[assignment] | |
metadata = torch._utils.get_tensor_metadata(self) | |
if metadata: | |
args = args + (metadata,) # type: ignore[assignment] | |
return (rebuild_func, args) | |
def __setstate__(self, state): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__setstate__, (self,), self, state) | |
# Warning: this method is NOT called when you torch.load() a tensor; | |
# that is managed by _rebuild_tensor_v2 | |
if not self.is_leaf: | |
raise RuntimeError("__setstate__ can be only called on leaf Tensors") | |
if len(state) == 4: | |
# legacy serialization of Tensor | |
self.set_(*state) | |
return | |
elif len(state) == 5: | |
# legacy serialization of Variable | |
self.data = state[0] | |
state = (state[3], state[4], state[2]) | |
# The setting of _backward_hooks is expected to be a no-op. | |
# See Note [Don't serialize hooks] | |
self.requires_grad, _, self._backward_hooks = state | |
def __repr__(self, *, tensor_contents=None): | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.__repr__, (self,), self, tensor_contents=tensor_contents | |
) | |
# All strings are unicode in Python 3. | |
return torch._tensor_str._str(self, tensor_contents=tensor_contents) | |
def backward( | |
self, gradient=None, retain_graph=None, create_graph=False, inputs=None | |
): | |
r"""Computes the gradient of current tensor wrt graph leaves. | |
The graph is differentiated using the chain rule. If the tensor is | |
non-scalar (i.e. its data has more than one element) and requires | |
gradient, the function additionally requires specifying ``gradient``. | |
It should be a tensor of matching type and location, that contains | |
the gradient of the differentiated function w.r.t. ``self``. | |
This function accumulates gradients in the leaves - you might need to zero | |
``.grad`` attributes or set them to ``None`` before calling it. | |
See :ref:`Default gradient layouts<default-grad-layouts>` | |
for details on the memory layout of accumulated gradients. | |
.. note:: | |
If you run any forward ops, create ``gradient``, and/or call ``backward`` | |
in a user-specified CUDA stream context, see | |
:ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`. | |
.. note:: | |
When ``inputs`` are provided and a given input is not a leaf, | |
            the current implementation will call its grad_fn (though it is not strictly needed to get these gradients).
It is an implementation detail on which the user should not rely. | |
See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. | |
Args: | |
gradient (Tensor or None): Gradient w.r.t. the | |
tensor. If it is a tensor, it will be automatically converted | |
to a Tensor that does not require grad unless ``create_graph`` is True. | |
None values can be specified for scalar Tensors or ones that | |
don't require grad. If a None value would be acceptable then | |
this argument is optional. | |
retain_graph (bool, optional): If ``False``, the graph used to compute | |
the grads will be freed. Note that in nearly all cases setting | |
this option to True is not needed and often can be worked around | |
in a much more efficient way. Defaults to the value of | |
``create_graph``. | |
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing higher order derivative
                products to be computed. Defaults to ``False``.
inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be | |
accumulated into ``.grad``. All other Tensors will be ignored. If not | |
provided, the gradient is accumulated into all the leaf Tensors that were | |
                used to compute the :attr:`tensors`.
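        Example::
            >>> # Minimal case: a scalar output, so no ``gradient`` argument is needed
            >>> x = torch.tensor([2.0, 3.0], requires_grad=True)
            >>> y = (x * x).sum()
            >>> y.backward()
            >>> x.grad  # d(sum(x**2))/dx = 2*x
            tensor([4., 6.])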
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.backward, | |
(self,), | |
self, | |
gradient=gradient, | |
retain_graph=retain_graph, | |
create_graph=create_graph, | |
inputs=inputs, | |
) | |
torch.autograd.backward( | |
self, gradient, retain_graph, create_graph, inputs=inputs | |
) | |
def register_hook(self, hook): | |
r"""Registers a backward hook. | |
The hook will be called every time a gradient with respect to the | |
Tensor is computed. The hook should have the following signature:: | |
hook(grad) -> Tensor or None | |
The hook should not modify its argument, but it can optionally return | |
a new gradient which will be used in place of :attr:`grad`. | |
This function returns a handle with a method ``handle.remove()`` | |
        that removes the hook from the tensor.
.. note:: | |
            See :ref:`backward-hooks-execution` for more information on how and when this hook
            is executed, and how its execution is ordered relative to other hooks.
Example:: | |
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) | |
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient | |
>>> v.backward(torch.tensor([1., 2., 3.])) | |
>>> v.grad | |
            tensor([2., 4., 6.])
>>> h.remove() # removes the hook | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.register_hook, (self,), self, hook) | |
if not self.requires_grad: | |
raise RuntimeError( | |
"cannot register a hook on a tensor that doesn't require gradient" | |
) | |
if self._backward_hooks is None: | |
self._backward_hooks = OrderedDict() | |
if self.grad_fn is not None: | |
self.grad_fn._register_hook_dict(self) | |
handle = hooks.RemovableHandle(self._backward_hooks) | |
self._backward_hooks[handle.id] = hook | |
return handle | |
def register_post_accumulate_grad_hook(self, hook): | |
r"""Registers a backward hook that runs after grad accumulation. | |
The hook will be called after all gradients for a tensor have been accumulated, | |
meaning that the .grad field has been updated on that tensor. The post | |
accumulate grad hook is ONLY applicable for leaf tensors (tensors without a | |
.grad_fn field). Registering this hook on a non-leaf tensor will error! | |
The hook should have the following signature:: | |
hook(param: Tensor) -> None | |
Note that, unlike other autograd hooks, this hook operates on the tensor | |
that requires grad and not the grad itself. The hook can in-place modify | |
and access its Tensor argument, including its .grad field. | |
This function returns a handle with a method ``handle.remove()`` | |
        that removes the hook from the tensor.
.. note:: | |
            See :ref:`backward-hooks-execution` for more information on how and when this hook
            is executed, and how its execution is ordered relative to other hooks. Since
this hook runs during the backward pass, it will run in no_grad mode (unless | |
create_graph is True). You can use torch.enable_grad() to re-enable autograd | |
within the hook if you need it. | |
Example:: | |
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) | |
>>> lr = 0.01 | |
>>> # simulate a simple SGD update | |
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) | |
>>> v.backward(torch.tensor([1., 2., 3.])) | |
>>> v | |
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) | |
>>> h.remove() # removes the hook | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.register_post_accumulate_grad_hook, (self,), self, hook | |
) | |
if not self.requires_grad: | |
raise RuntimeError( | |
"cannot register a hook on a tensor that doesn't require gradient" | |
) | |
if self.grad_fn is not None: | |
raise RuntimeError( | |
"post accumulate grad hooks cannot be registered on non-leaf tensors" | |
) | |
if self._post_accumulate_grad_hooks is None: | |
self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict() | |
handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks) | |
self._post_accumulate_grad_hooks[handle.id] = hook | |
return handle | |
def reinforce(self, reward): | |
def trim(str): | |
return "\n".join([line.strip() for line in str.split("\n")]) | |
raise RuntimeError( | |
trim( | |
r"""reinforce() was removed. | |
Use torch.distributions instead. | |
See https://pytorch.org/docs/master/distributions.html | |
Instead of: | |
probs = policy_network(state) | |
action = probs.multinomial() | |
next_state, reward = env.step(action) | |
action.reinforce(reward) | |
action.backward() | |
Use: | |
probs = policy_network(state) | |
# NOTE: categorical is equivalent to what used to be called multinomial | |
m = torch.distributions.Categorical(probs) | |
action = m.sample() | |
next_state, reward = env.step(action) | |
loss = -m.log_prob(action) * reward | |
loss.backward() | |
""" | |
) | |
) | |
detach = _C._add_docstr( | |
_C.TensorBase.detach, | |
r""" | |
Returns a new Tensor, detached from the current graph. | |
The result will never require gradient. | |
This method also affects forward mode AD gradients and the result will never | |
have forward mode AD gradients. | |
.. note:: | |
Returned Tensor shares the same storage with the original one. | |
In-place modifications on either of them will be seen, and may trigger | |
errors in correctness checks. | |
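    Example::
        >>> a = torch.tensor([1.0, 2.0], requires_grad=True)
        >>> b = a.detach()
        >>> b.requires_grad
        False
        >>> b[0] = 7.0  # visible through ``a`` as well, since the storage is shared
        >>> a[0].item()
        7.0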
""", | |
) | |
detach_ = _C._add_docstr( | |
_C.TensorBase.detach_, | |
r""" | |
Detaches the Tensor from the graph that created it, making it a leaf. | |
Views cannot be detached in-place. | |
This method also affects forward mode AD gradients and the result will never | |
have forward mode AD gradients. | |
""", | |
) | |
def is_shared(self): | |
r"""Checks if tensor is in shared memory. | |
This is always ``True`` for CUDA tensors. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.is_shared, (self,), self) | |
return self._typed_storage()._is_shared() | |
def share_memory_(self): | |
r"""Moves the underlying storage to shared memory. | |
This is a no-op if the underlying storage is already in shared memory | |
and for CUDA tensors. Tensors in shared memory cannot be resized. | |
See :meth:`torch.UntypedStorage.share_memory_` for more details. | |
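        Example::
            >>> t = torch.zeros(3)
            >>> t.is_shared()
            False
            >>> t.share_memory_().is_shared()
            True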
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.share_memory_, (self,), self) | |
self._typed_storage()._share_memory_() | |
return self | |
def module_load(self, other, assign=False): | |
r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`. | |
Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. | |
It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the | |
        value in the state dictionary with the corresponding key; this method defines
        how ``other`` is remapped before being swapped with ``self`` via
        :func:`~torch.utils.swap_tensors` in ``module.load_state_dict()``.
.. note:: | |
This method should always return a new object that is not ``self`` or ``other``. | |
For example, the default implementation returns ``self.copy_(other).detach()`` | |
if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``. | |
Args: | |
other (Tensor): value in state dict with key corresponding to ``self`` | |
assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict` | |
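        Example (a standalone sketch; in practice this is called for you inside ``load_state_dict``)::
            >>> buf = torch.zeros(3)     # e.g. a buffer registered on a module
            >>> loaded = torch.ones(3)   # the corresponding value from the state dict
            >>> buf.module_load(loaded)  # default path: ``self.copy_(other).detach()``
            tensor([1., 1., 1.])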
""" | |
if has_torch_function_variadic(self, other): | |
return handle_torch_function( | |
Tensor.module_load, (self, other), self, other, assign=assign | |
) | |
if assign: | |
return other.detach() | |
else: | |
return self.copy_(other).detach() | |
def __reversed__(self): | |
r"""Reverses the tensor along dimension 0.""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__reversed__, (self,), self) | |
if self.dim() == 0: | |
return self | |
else: | |
return self.flip(0) | |
def norm( | |
self, | |
p: Optional[Union[float, str]] = "fro", | |
dim=None, | |
keepdim=False, | |
dtype=None, | |
): | |
r"""See :func:`torch.norm`""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype | |
) | |
return torch.norm(self, p, dim, keepdim, dtype=dtype) | |
def solve(self, other): | |
from ._linalg_utils import solve | |
return solve(self, other) | |
def lstsq(self, other): | |
from ._linalg_utils import lstsq | |
return lstsq(self, other) | |
def eig(self, eigenvectors=False): | |
from ._linalg_utils import eig | |
return eig(self, eigenvectors=eigenvectors) | |
def symeig(self, eigenvectors=False): | |
from ._linalg_utils import _symeig | |
return _symeig(self, eigenvectors=eigenvectors) | |
def lu(self, pivot=True, get_infos=False): | |
r"""See :func:`torch.lu`""" | |
# If get_infos is True, then we don't need to check for errors and vice versa | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos | |
) | |
LU, pivots, infos = torch._lu_with_info( | |
self, pivot=pivot, check_errors=(not get_infos) | |
) | |
if get_infos: | |
return LU, pivots, infos | |
else: | |
return LU, pivots | |
def stft( | |
self, | |
n_fft: int, | |
hop_length: Optional[int] = None, | |
win_length: Optional[int] = None, | |
window: "Optional[Tensor]" = None, | |
center: bool = True, | |
pad_mode: str = "reflect", | |
normalized: bool = False, | |
onesided: Optional[bool] = None, | |
return_complex: Optional[bool] = None, | |
): | |
r"""See :func:`torch.stft` | |
.. warning:: | |
This function changed signature at version 0.4.1. Calling with | |
            the previous signature may cause an error or return an incorrect result.
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.stft, | |
(self,), | |
self, | |
n_fft, | |
hop_length=hop_length, | |
win_length=win_length, | |
window=window, | |
center=center, | |
pad_mode=pad_mode, | |
normalized=normalized, | |
onesided=onesided, | |
return_complex=return_complex, | |
) | |
return torch.stft( | |
self, | |
n_fft, | |
hop_length, | |
win_length, | |
window, | |
center, | |
pad_mode, | |
normalized, | |
onesided, | |
return_complex=return_complex, | |
) | |
def istft( | |
self, | |
n_fft: int, | |
hop_length: Optional[int] = None, | |
win_length: Optional[int] = None, | |
window: "Optional[Tensor]" = None, | |
center: bool = True, | |
normalized: bool = False, | |
onesided: Optional[bool] = None, | |
length: Optional[int] = None, | |
return_complex: bool = False, | |
): | |
r"""See :func:`torch.istft`""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.istft, | |
(self,), | |
self, | |
n_fft, | |
hop_length=hop_length, | |
win_length=win_length, | |
window=window, | |
center=center, | |
normalized=normalized, | |
onesided=onesided, | |
length=length, | |
return_complex=return_complex, | |
) | |
return torch.istft( | |
self, | |
n_fft, | |
hop_length, | |
win_length, | |
window, | |
center, | |
normalized, | |
onesided, | |
length, | |
return_complex=return_complex, | |
) | |
def resize(self, *sizes): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.resize, (self,), self, *sizes) | |
warnings.warn("non-inplace resize is deprecated") | |
from torch.autograd._functions import Resize | |
return Resize.apply(self, sizes) | |
def resize_as(self, tensor): | |
if has_torch_function_variadic(self, tensor): | |
return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) | |
warnings.warn("non-inplace resize_as is deprecated") | |
from torch.autograd._functions import Resize | |
return Resize.apply(self, tensor.size()) | |
def split(self, split_size, dim=0): | |
r"""See :func:`torch.split`""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.split, (self,), self, split_size, dim=dim | |
) | |
if isinstance(split_size, Tensor): | |
try: | |
split_size = int(split_size) | |
except ValueError: | |
pass | |
if isinstance(split_size, (int, torch.SymInt)): | |
return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined] | |
else: | |
return torch._VF.split_with_sizes(self, split_size, dim) | |
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): | |
r"""Returns the unique elements of the input tensor. | |
See :func:`torch.unique` | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.unique, | |
(self,), | |
self, | |
sorted=sorted, | |
return_inverse=return_inverse, | |
return_counts=return_counts, | |
dim=dim, | |
) | |
return torch.unique( | |
self, | |
sorted=sorted, | |
return_inverse=return_inverse, | |
return_counts=return_counts, | |
dim=dim, | |
) | |
def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): | |
r"""Eliminates all but the first element from every consecutive group of equivalent elements. | |
See :func:`torch.unique_consecutive` | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.unique_consecutive, | |
(self,), | |
self, | |
return_inverse=return_inverse, | |
return_counts=return_counts, | |
dim=dim, | |
) | |
return torch.unique_consecutive( | |
self, return_inverse=return_inverse, return_counts=return_counts, dim=dim | |
) | |
def __rsub__(self, other): | |
return _C._VariableFunctions.rsub(self, other) | |
def __rdiv__(self, other): | |
return self.reciprocal() * other | |
__rtruediv__ = __rdiv__ | |
__itruediv__ = _C.TensorBase.__idiv__ | |
__pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( | |
_C.TensorBase.pow | |
) | |
__ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( | |
_C.TensorBase.pow_ | |
) | |
def __rmod__(self, other): | |
return torch.remainder(other, self) | |
def __format__(self, format_spec): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__format__, (self,), self, format_spec) | |
if self.dim() == 0 and not self.is_meta and type(self) is Tensor: | |
return self.item().__format__(format_spec) | |
return object.__format__(self, format_spec) | |
def __rpow__(self, other): | |
return torch.pow(other, self) | |
def __floordiv__(self, other): | |
return torch.floor_divide(self, other) | |
def __rfloordiv__(self, other): | |
return torch.floor_divide(other, self) | |
def __rlshift__(self, other): | |
return torch.bitwise_left_shift(other, self) | |
def __rrshift__(self, other): | |
return torch.bitwise_right_shift(other, self) | |
def __rmatmul__(self, other): | |
return torch.matmul(other, self) | |
__pos__ = _C.TensorBase.positive | |
__neg__ = _C.TensorBase.neg | |
__abs__ = _C.TensorBase.abs | |
def __len__(self): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__len__, (self,), self) | |
if self.dim() == 0: | |
raise TypeError("len() of a 0-d tensor") | |
if torch._C._get_tracing_state(): | |
warnings.warn( | |
"Using len to get tensor shape might cause the trace to be incorrect. " | |
"Recommended usage would be tensor.shape[0]. " | |
"Passing a tensor of different shape might lead to errors or silently give " | |
"incorrect results.", | |
category=torch.jit.TracerWarning, | |
stacklevel=2, | |
) | |
return self.shape[0] | |
def __iter__(self): | |
# NB: we use 'imap' and not 'map' here, so that in Python 2 we get a | |
# generator and don't eagerly perform all the indexes. This could | |
# save us work, and also helps keep trace ordering deterministic | |
# (e.g., if you zip(*hiddens), the eager map will force all the | |
# indexes of hiddens[0] before hiddens[1], while the generator | |
# map will interleave them.) | |
# NB: We have intentionally skipped __torch_function__ dispatch here. | |
# See gh-54457 | |
if self.dim() == 0: | |
raise TypeError("iteration over a 0-d tensor") | |
if torch._C._get_tracing_state(): | |
warnings.warn( | |
"Iterating over a tensor might cause the trace to be incorrect. " | |
"Passing a tensor of different shape won't change the number of " | |
"iterations executed (and might lead to errors or silently give " | |
"incorrect results).", | |
category=torch.jit.TracerWarning, | |
stacklevel=2, | |
) | |
return iter(self.unbind(0)) | |
def __hash__(self): | |
        # Do NOT handle __torch_function__ here as a user's default
        # implementation that handles most functions will most likely do it wrong.
# It can be easily overridden by defining this method on the user | |
# subclass if needed. | |
return id(self) | |
def __dir__(self): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__dir__, (self,), self) | |
tensor_methods = dir(self.__class__) | |
tensor_methods.remove("volatile") # deprecated | |
attrs = list(self.__dict__.keys()) | |
keys = tensor_methods + attrs | |
        # the __cuda_array_interface__ property is only available on dense CUDA tensors
if (not self.is_cuda) or self.is_sparse: | |
keys.remove("__cuda_array_interface__") | |
return sorted(keys) | |
# Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` | |
__array_priority__ = 1000 # prefer Tensor ops over numpy ones | |
def __array__(self, dtype=None): | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype) | |
if dtype is None: | |
return self.numpy() | |
else: | |
return self.numpy().astype(dtype, copy=False) | |
# Wrap Numpy array again in a suitable tensor when done, to support e.g. | |
# `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` | |
def __array_wrap__(self, array): | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.__array_wrap__, (self,), self, array=array | |
) | |
if array.dtype == bool: | |
# Workaround, torch has no built-in bool tensor | |
array = array.astype("uint8") | |
return torch.from_numpy(array) | |
def __contains__(self, element): | |
r"""Check if `element` is present in tensor | |
Args: | |
element (Tensor or scalar): element to be checked | |
                for presence in current tensor
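        Example::
            >>> t = torch.tensor([1, 2, 3])
            >>> 2 in t
            True
            >>> 5 in t
            False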
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__contains__, (self,), self, element) | |
if isinstance( | |
element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool) | |
): | |
# type hint doesn't understand the __contains__ result array | |
return (element == self).any().item() # type: ignore[union-attr] | |
raise RuntimeError( | |
f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}." | |
) | |
    @property
    def __cuda_array_interface__(self):
"""Array view description for cuda tensors. | |
See: | |
https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html | |
""" | |
if has_torch_function_unary(self): | |
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 | |
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, (self,), self) # type: ignore[attr-defined] | |
# raise AttributeError for unsupported tensors, so that | |
# hasattr(cpu_tensor, "__cuda_array_interface__") is False. | |
if not self.is_cuda: | |
raise AttributeError( | |
"Can't get __cuda_array_interface__ on non-CUDA tensor type: %s " | |
"If CUDA data is required use tensor.cuda() to copy tensor to device memory." | |
% self.type() | |
) | |
if self.is_sparse: | |
raise AttributeError( | |
"Can't get __cuda_array_interface__ on sparse type: %s " | |
"Use Tensor.to_dense() to convert to a dense tensor first." | |
% self.type() | |
) | |
# RuntimeError, matching tensor.__array__() behavior. | |
if self.requires_grad: | |
raise RuntimeError( | |
"Can't get __cuda_array_interface__ on Variable that requires grad. " | |
"If gradients aren't required, use var.detach() to get Variable that doesn't require grad." | |
) | |
# CUDA devices are little-endian and tensors are stored in native byte | |
# order. 1-byte entries are endian-agnostic. | |
typestr = { | |
torch.complex64: "<c8", | |
torch.complex128: "<c16", | |
torch.float16: "<f2", | |
torch.float32: "<f4", | |
torch.float64: "<f8", | |
torch.uint8: "|u1", | |
torch.int8: "|i1", | |
torch.int16: "<i2", | |
torch.int32: "<i4", | |
torch.int64: "<i8", | |
}[self.dtype] | |
itemsize = self.element_size() | |
shape = tuple(self.shape) | |
if self.is_contiguous(): | |
# __cuda_array_interface__ v2 requires the strides to be omitted | |
# (either not set or set to None) for C-contiguous arrays. | |
strides = None | |
else: | |
strides = tuple(s * itemsize for s in self.stride()) | |
data_ptr = self.data_ptr() if self.numel() > 0 else 0 | |
data = (data_ptr, False) # read-only is false | |
return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2) | |
def storage_type(self): | |
r"""storage_type() -> type | |
Returns the type of the underlying storage. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.storage_type, (self,), self) | |
torch.storage._warn_typed_storage_removal() | |
return self._typed_storage()._get_legacy_storage_class() | |
def refine_names(self, *names): | |
r"""Refines the dimension names of :attr:`self` according to :attr:`names`. | |
Refining is a special case of renaming that "lifts" unnamed dimensions. | |
A ``None`` dim can be refined to have any name; a named dim can only be | |
refined to have the same name. | |
Because named tensors can coexist with unnamed tensors, refining names | |
gives a nice way to write named-tensor-aware code that works with both | |
named and unnamed tensors. | |
:attr:`names` may contain up to one Ellipsis (``...``). | |
The Ellipsis is expanded greedily; it is expanded in-place to fill | |
:attr:`names` to the same length as ``self.dim()`` using names from the | |
corresponding indices of ``self.names``. | |
Python 2 does not support Ellipsis but one may use a string literal | |
instead (``'...'``). | |
Args: | |
names (iterable of str): The desired names of the output tensor. May | |
contain up to one Ellipsis. | |
Examples:: | |
>>> imgs = torch.randn(32, 3, 128, 128) | |
>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') | |
>>> named_imgs.names | |
('N', 'C', 'H', 'W') | |
>>> tensor = torch.randn(2, 3, 5, 7, 11) | |
>>> tensor = tensor.refine_names('A', ..., 'B', 'C') | |
>>> tensor.names | |
('A', None, None, 'B', 'C') | |
.. warning:: | |
The named tensor API is experimental and subject to change. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.refine_names, (self,), self, *names) | |
names = resolve_ellipsis(names, self.names, "refine_names") | |
return super().refine_names(names) | |
def align_to(self, *names): | |
r"""Permutes the dimensions of the :attr:`self` tensor to match the order | |
specified in :attr:`names`, adding size-one dims for any new names. | |
All of the dims of :attr:`self` must be named in order to use this method. | |
The resulting tensor is a view on the original tensor. | |
All dimension names of :attr:`self` must be present in :attr:`names`. | |
:attr:`names` may contain additional names that are not in ``self.names``; | |
the output tensor has a size-one dimension for each of those new names. | |
:attr:`names` may contain up to one Ellipsis (``...``). | |
The Ellipsis is expanded to be equal to all dimension names of :attr:`self` | |
that are not mentioned in :attr:`names`, in the order that they appear | |
in :attr:`self`. | |
Python 2 does not support Ellipsis but one may use a string literal | |
instead (``'...'``). | |
Args: | |
names (iterable of str): The desired dimension ordering of the | |
output tensor. May contain up to one Ellipsis that is expanded | |
to all unmentioned dim names of :attr:`self`. | |
Examples:: | |
>>> tensor = torch.randn(2, 2, 2, 2, 2, 2) | |
>>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') | |
# Move the F and E dims to the front while keeping the rest in order | |
>>> named_tensor.align_to('F', 'E', ...) | |
.. warning:: | |
The named tensor API is experimental and subject to change. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.align_to, (self,), self, *names) | |
ellipsis_idx = single_ellipsis_index(names, "align_to") | |
if ellipsis_idx is None: | |
return super().align_to(names) | |
return super().align_to( | |
[name for name in names if not is_ellipsis(name)], ellipsis_idx | |
) | |
def unflatten(self, dim, sizes): | |
r""" | |
unflatten(dim, sizes) -> Tensor | |
See :func:`torch.unflatten`. | |
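        Example::
            >>> torch.arange(6).unflatten(0, (2, 3))
            tensor([[0, 1, 2],
                    [3, 4, 5]])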
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes) | |
if not sizes: | |
raise RuntimeError("unflatten: sizes must be non-empty") | |
names = None | |
if isinstance(sizes, OrderedDict) or ( | |
isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list)) | |
): | |
names, sizes = unzip_namedshape(sizes) | |
return super().unflatten(dim, sizes, names) | |
else: | |
return super().unflatten(dim, sizes) | |
def rename_(self, *names, **rename_map): | |
"""In-place version of :meth:`~Tensor.rename`.""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.rename_, (self,), self, *names, **rename_map | |
) | |
# Note [rename_ / rename API] | |
# The Python API for these is different from the C++ API. In Python: | |
# 1) tensor.rename(*names) takes a vararglist of names | |
# 2) tensor.rename(**rename_map) takes a map of names to rename. | |
# C++ is static, making it difficult to implement similar behavior. | |
return update_names(self, names, rename_map, inplace=True) | |
def rename(self, *names, **rename_map): | |
"""Renames dimension names of :attr:`self`. | |
There are two main usages: | |
``self.rename(**rename_map)`` returns a view on tensor that has dims | |
renamed as specified in the mapping :attr:`rename_map`. | |
``self.rename(*names)`` returns a view on tensor, renaming all | |
dimensions positionally using :attr:`names`. | |
Use ``self.rename(None)`` to drop names on a tensor. | |
One cannot specify both positional args :attr:`names` and keyword args | |
:attr:`rename_map`. | |
Examples:: | |
>>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) | |
>>> renamed_imgs = imgs.rename(N='batch', C='channels') | |
>>> renamed_imgs.names | |
('batch', 'channels', 'H', 'W') | |
>>> renamed_imgs = imgs.rename(None) | |
>>> renamed_imgs.names | |
(None, None, None, None) | |
>>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width') | |
>>> renamed_imgs.names | |
('batch', 'channel', 'height', 'width') | |
.. warning:: | |
The named tensor API is experimental and subject to change. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor.rename, (self,), self, *names, **rename_map | |
) | |
# See Note [rename_ / rename API] | |
return update_names(self, names, rename_map, inplace=False) | |
def to_sparse_coo(self): | |
"""Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`. | |
Examples:: | |
>>> dense = torch.randn(5, 5) | |
>>> sparse = dense.to_sparse_coo() | |
>>> sparse._nnz() | |
25 | |
""" | |
return self.to_sparse() | |
def dim_order(self): | |
""" | |
dim_order() -> tuple | |
Returns a tuple of int describing the dim order or physical layout of :attr:`self`. | |
Args: | |
None | |
Dim order represents how dimensions are laid out in memory, | |
starting from the outermost to the innermost dimension. | |
Example:: | |
>>> torch.empty((2, 3, 5, 7)).dim_order() | |
(0, 1, 2, 3) | |
>>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order() | |
(0, 2, 3, 1) | |
.. warning:: | |
The dim_order tensor API is experimental and subject to change. | |
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.dim_order, (self,), self) | |
import torch._prims_common as utils | |
return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self)) | |
def _update_names(self, names, inplace): | |
if has_torch_function_unary(self): | |
return handle_torch_function( | |
Tensor._update_names, (self,), self, names, inplace | |
) | |
# See Note [rename_ / rename API] | |
if inplace: | |
return super().rename_(names) | |
else: | |
return super().rename(names) | |
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
""" | |
This __torch_function__ implementation wraps subclasses such that | |
methods called on subclasses return a subclass instance instead of | |
a ``torch.Tensor`` instance. | |
One corollary to this is that you need coverage for torch.Tensor | |
methods if implementing __torch_function__ for subclasses. | |
We recommend always calling ``super().__torch_function__`` as the base | |
case when doing the above. | |
While not mandatory, we recommend making `__torch_function__` a classmethod. | |
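        Example (``MyTensor`` is an arbitrary user-defined subclass)::
            >>> class MyTensor(torch.Tensor):
            ...     pass
            >>> t = MyTensor([1.0, 2.0])
            >>> type(t + 1) is MyTensor  # the result is re-wrapped into the subclass
            True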
""" | |
if kwargs is None: | |
kwargs = {} | |
if not all(issubclass(cls, t) for t in types): | |
return NotImplemented | |
with _C.DisableTorchFunctionSubclass(): | |
ret = func(*args, **kwargs) | |
if func in get_default_nowrap_functions(): | |
return ret | |
else: | |
return _convert(ret, cls) | |
__torch_dispatch__ = _C._disabled_torch_dispatch_impl | |
def __dlpack__(self, stream=None): | |
""" | |
        Creates a DLPack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
        of the current tensor to be exported to other libraries.
This function will be called from the `from_dlpack` method | |
of the library that will consume the capsule. `from_dlpack` passes the current | |
stream to this method as part of the specification. | |
Args: | |
stream (integer or None): An optional Python integer representing a | |
pointer to a CUDA stream. The current stream is synchronized with | |
this stream before the capsule is created, and since the capsule | |
                shares its storage with the tensor, this makes it safe to access from
both streams. If None or -1 is passed then no synchronization is performed. | |
If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for | |
synchronization. | |
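        Example (assumes a DLPack-aware consumer, e.g. NumPy >= 1.22)::
            >>> import numpy as np
            >>> t = torch.arange(4)
            >>> np.from_dlpack(t)  # the consumer calls ``t.__dlpack__()`` internally
            array([0, 1, 2, 3])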
""" | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) | |
# DLPack capsules can't capture all of PyTorch's semantics, | |
# so we prohibit exporting tensors that would lose their properties like | |
# requires_grad and having the conjugate bit set. | |
if self.requires_grad: | |
raise RuntimeError( | |
"Can't export tensors that require gradient, use tensor.detach()" | |
) | |
if self.is_conj(): | |
raise RuntimeError("Can't export tensors with the conjugate bit set") | |
if self.layout != torch.strided: | |
raise RuntimeError( | |
"Can't export tensors with layout other than torch.strided" | |
) | |
if stream is not None and type(stream) is not int: | |
# Stream pointers in CUDA/ROCm are uniquely numbered and can | |
# be retrieved from their integer value. | |
            raise TypeError("stream must be ``int`` or ``None``")
elif stream is not None and stream != -1: | |
if self.device.type == "cuda": | |
# NB: This logic handles the special case values for default | |
# streams and must be kept in sync with from_dlpack in | |
# torch/utils/dlpack.py | |
if stream == 1 and torch.version.hip is None: | |
stream = torch.cuda.default_stream() | |
elif stream == 0 and torch.version.hip is not None: | |
stream = torch.cuda.default_stream() | |
else: | |
stream = torch.cuda.ExternalStream(stream) | |
# Only synchronize on different streams | |
sync_stream = torch.cuda.current_stream() | |
if stream != sync_stream: | |
event = torch.cuda.Event() | |
event.record(sync_stream) | |
stream.wait_event(event) | |
return torch.to_dlpack(self) | |
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: | |
if has_torch_function_unary(self): | |
return handle_torch_function(Tensor.__dlpack_device__, (self,), self) | |
device = self.device | |
idx = device.index if device.index is not None else 0 | |
torch_device_type = device.type | |
if torch_device_type == "cuda" and torch.version.hip is not None: | |
device_type = DLDeviceType.kDLROCM | |
elif torch_device_type == "cpu" and self.is_pinned(): | |
device_type = DLDeviceType.kDLCPUPinned | |
elif torch_device_type == "cuda": | |
device_type = DLDeviceType.kDLGPU | |
elif torch_device_type == "cpu": | |
device_type = DLDeviceType.kDLCPU | |
elif self.device.type == "xpu": | |
device_type = DLDeviceType.kDLOneAPI | |
else: | |
raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") | |
return (device_type, idx) | |
__module__ = "torch" | |
def _convert(ret, cls): | |
if cls is Tensor: | |
return ret | |
if isinstance(ret, Tensor) and not isinstance(ret, cls): | |
ret = ret.as_subclass(cls) | |
if isinstance(ret, (tuple, list)): | |
# Also handles things like namedtuples | |
ret = type(ret)(_convert(r, cls) for r in ret) | |
return ret | |