import copyreg
import functools
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, DefaultDict, List, Optional

import torch

def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)

def _hpu(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in HPU memory.

    If this object is already in HPU memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination HPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    if self.is_hpu:
        if device is None:
            device = hpu.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with hpu.device(device):
        assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
        untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in CUDA memory.

    If this object is already in CUDA memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with torch.cuda.device(device):
        if self.is_sparse:
            new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
            indices = torch.Tensor._indices(self).cuda(device, non_blocking)
            values = torch.Tensor._values(self).cuda(device, non_blocking)
            return new_type(indices, values, self.size())
        else:
            untyped_storage = torch.UntypedStorage(
                self.size(), device=torch.device("cuda")
            )
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage

def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]

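# Illustrative sketch of how the deprecated ``async`` kwarg is resolved (kept as
# comments so nothing runs at import; the argument values below are made up):
#
#   _get_async_or_non_blocking("cuda", False, {})               # -> False
#   _get_async_or_non_blocking("cuda", False, {"async": True})  # -> True, plus a warning
#   _get_async_or_non_blocking("cuda", False, {"foo": True})    # -> TypeError
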
# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables.  This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
# - It's fragile.  If you serialize a backward hook into a saved
#   model, and then you rename the function associated with the hook,
#   now your saved model is broken and you can't load it anymore.
#
# - It's not actually used.  The standard recommendation is to
#   serialize the *state_dict* of a model, not the model itself
#   (since this is more stable to code changes affecting the model
#   serialization), and the state dict saves "data" only, thus
#   stripping the backward hooks.  In some cases, hooks are
#   essential to the well-functioning of a model (e.g., DDP),
#   but DDP already manages re-adding the hooks!
#
# - We didn't serialize them in many cases.  Prior to #10220, we
#   were dropping backward hooks in ForkingPickler.  We "fixed" this
#   to be consistent with other serialization sites, but lack of
#   serializing backward hooks wasn't actually the root cause of
#   the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a
# hook has the property _torch_serialize_ignore, we will not emit a warning if
# we attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass None you'll run afoul of #12219.

# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)

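# Illustrative round trip (a comments-only sketch; ``src`` is a made-up name, and
# any TypedStorage with a matching dtype would do):
#
#   src = torch.arange(12, dtype=torch.float32)
#   rebuilt = _rebuild_tensor(src.storage(), 0, (3, 4), (4, 1))
#   assert torch.equal(rebuilt, src.reshape(3, 4))
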
def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]

def _rebuild_tensor_v2(
    storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor

def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t

_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because,
# due to Pickler semantics, we have to use the same (non-validating) function
# for unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )
    finally:
        _sparse_tensors_to_validate.clear()

def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (str): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")

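# Illustrative sketch of the deferred-validation flow (comments only; the
# indices/values below are made up, and the real callers are the load paths in
# serialization.py):
#
#   i = torch.tensor([[0, 1, 1], [2, 0, 2]])
#   v = torch.tensor([3.0, 4.0, 5.0])
#   coo = _rebuild_sparse_tensor(torch.sparse_coo, (i, v, (2, 3)))  # legacy 3-tuple form
#   _validate_loaded_sparse_tensors()  # validates `coo` and clears the pending list
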
def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)

def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older
# versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy

def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )

def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )

# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor

def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param

def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore the Python attribute state on the Parameter.
    param = _set_obj_state(param, state)
    return param

def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensor does
    # not inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state

def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj

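# Illustrative round trip (comments only; ``Box`` and ``payload`` are made-up
# names, and any plain Python object without __slots__ behaves the same way):
#
#   class Box:
#       pass
#
#   b = Box()
#   b.payload = 123
#   state = _get_obj_state(b)        # {"payload": 123} -- no slots, so a plain dict
#   restored = _set_obj_state(Box(), state)
#   assert restored.payload == 123
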
def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj

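# Illustrative examples (comments only; both dotted names resolve against the
# already-imported ``torch`` package):
#
#   _import_dotted_name("torch.nn.functional")  # -> the torch.nn.functional module
#   _import_dotted_name("torch.FloatTensor")    # -> the torch.FloatTensor class
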
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)

def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values

def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)

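# Illustrative round trip (comments only; ``params`` is a made-up name and both
# tensors share the default float32 dtype, as the docstrings above assume):
#
#   params = [torch.randn(2, 2), torch.randn(3)]
#   flat = _flatten_dense_tensors(params)            # 1D buffer of 7 elements
#   chunks = _unflatten_dense_tensors(flat, params)  # views shaped like the inputs
#   for p, c in zip(params, chunks):
#       assert torch.equal(p, c)
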
def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)

def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)

def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields chunks of tensors of the
    same type, each totaling at most ``size_limit`` bytes (a single tensor larger
    than the limit is yielded in its own chunk).

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf

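# Illustrative sketch (comments only; the tensors and the 500-byte limit are made
# up). Two 400-byte float32 tensors cannot share a 500-byte chunk, and the
# float64 tensor is grouped separately because its type differs:
#
#   ts = [torch.zeros(100), torch.zeros(100), torch.zeros(100, dtype=torch.float64)]
#   chunks = list(_take_tensors(ts, 500))
#   assert len(chunks) == 3
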
# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec

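# Illustrative usage (comments only; ``topk_values`` is a made-up function):
#
#   @annotate(torch.Tensor, x=torch.Tensor, k=int)
#   def topk_values(x, k):
#       return x.topk(k).values
#
#   # topk_values.__annotations__ ->
#   #   {"x": torch.Tensor, "k": int, "return": torch.Tensor}
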
def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r

# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the objects in its temporary
# scope) holding reference to the traceback.
class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self

class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception

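# Illustrative sketch of the intended thread hand-off (comments only; ``worker``
# and ``box`` are made-up names):
#
#   import threading
#
#   def worker(box):
#       try:
#           raise ValueError("bad batch")
#       except Exception:
#           box.append(ExceptionWrapper(where="in worker thread 0"))
#
#   box = []
#   t = threading.Thread(target=worker, args=(box,))
#   t.start(); t.join()
#   box[0].reraise()  # re-raises ValueError here, with the worker's traceback in the message
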
def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None

def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None

def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]

def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1

def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is
    ``True``, CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if the CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
    return device_idx

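# Illustrative calls (comments only; the last line assumes a CUDA runtime is
# actually available):
#
#   _get_device_index(torch.device("cuda", 1))               # -> 1
#   _get_device_index(3)                                      # -> 3
#   _get_device_index(torch.device("cpu"), allow_cpu=True)    # -> -1
#   _get_device_index(torch.device("cuda"), optional=True)    # -> current default CUDA index
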
def _handle_complex(tensor):
    """
    Returns a real view of a tensor if it has a complex dtype, else the tensor
    itself. We need to check for UninitializedParameter explicitly because
    calling is_complex on it would raise an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )

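# Illustrative behavior (comments only; shapes follow from torch.view_as_real):
#
#   _handle_complex(torch.zeros(2, dtype=torch.complex64)).shape  # torch.Size([2, 2])
#   _handle_complex(torch.zeros(2)).shape                          # torch.Size([2])
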
def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3

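# Illustrative values (comments only):
#
#   _element_size(torch.float32)    # 4
#   _element_size(torch.complex64)  # 8  (finfo reports the 32-bit component, hence >> 2)
#   _element_size(torch.int64)      # 8
#   _element_size(torch.bool)       # 1
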
class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)

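# Illustrative usage (comments only; ``Config`` is a made-up class):
#
#   class Config:
#       _name = "default"
#
#       @classproperty
#       def name(cls):
#           return cls._name
#
#   Config.name    # "default" -- computed on the class, no instance required
#   Config().name  # "default" as well
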
def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()

def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]

def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module

def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )

class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on current device. We track the order of **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order
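
# Illustrative ordering behavior (comments only; the callbacks and traceback
# strings below are made up):
#
#   tracker = _LazySeedTracker()
#   tracker.queue_seed_all(lambda: torch.cuda.manual_seed_all(0), "tb0")
#   tracker.queue_seed(lambda: torch.cuda.manual_seed(1), "tb1")
#   tracker.get_calls()  # [(seed_all cb, "tb0"), (seed cb, "tb1")] -- latest call last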