Spaces:
Sleeping
Sleeping
from typing import Optional | |
import torch | |
from torch.overrides import TorchFunctionMode | |
from torch.utils._contextlib import context_decorator | |
import functools | |
# Device published by the innermost active DeviceContext (set in its
# __enter__, restored in its __exit__); None when no DeviceContext is active.
CURRENT_DEVICE: Optional[torch.device]
CURRENT_DEVICE = None
def _device_constructors(): | |
return { | |
# standard ones | |
torch.empty, | |
torch.empty_permuted, | |
torch.empty_strided, | |
torch.empty_quantized, | |
torch.ones, | |
torch.arange, | |
torch.bartlett_window, | |
torch.blackman_window, | |
torch.eye, | |
torch.fft.fftfreq, | |
torch.fft.rfftfreq, | |
torch.full, | |
torch.fill, | |
torch.hamming_window, | |
torch.hann_window, | |
torch.kaiser_window, | |
torch.linspace, | |
torch.logspace, | |
torch.nested.nested_tensor, | |
# This function doesn't actually take a device argument | |
# torch.normal, | |
torch.ones, | |
torch.rand, | |
torch.randn, | |
torch.randint, | |
torch.randperm, | |
torch.range, | |
torch.sparse_coo_tensor, | |
torch.sparse_compressed_tensor, | |
torch.sparse_csr_tensor, | |
torch.sparse_csc_tensor, | |
torch.sparse_bsr_tensor, | |
torch.sparse_bsc_tensor, | |
torch.tril_indices, | |
torch.triu_indices, | |
torch.vander, | |
torch.zeros, | |
torch.asarray, | |
# weird ones | |
torch.tensor, | |
torch.as_tensor, | |
torch.scalar_tensor, | |
torch.asarray, | |
} | |
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    """Torch function mode that makes ``device`` the default for tensor factories.

    While the mode is active, any call to a function in
    ``_device_constructors()`` that omits ``device`` (or passes
    ``device=None``) receives ``device=self.device``. Entering the mode also
    stores the device in the module-level ``CURRENT_DEVICE``, and exiting
    restores the value that was there before.
    """

    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        # One saved CURRENT_DEVICE per active __enter__. A stack (rather than
        # a single slot) lets the same instance be re-entered in nested
        # ``with`` blocks and still restore the right value on each exit.
        self._saved_devices = []

    def __enter__(self):
        global CURRENT_DEVICE
        # Kept for backward compatibility; reflects the most recent __enter__.
        self.old_device = CURRENT_DEVICE
        self._saved_devices.append(CURRENT_DEVICE)
        CURRENT_DEVICE = self.device
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        global CURRENT_DEVICE
        CURRENT_DEVICE = self._saved_devices.pop()
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Work on a copy so the caller's kwargs dict is never mutated.
        kwargs = dict(kwargs) if kwargs else {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    """Wrap ``func`` so that ``device`` serves as its context manager on each call.

    ``context_decorator`` receives a zero-argument factory that always returns
    ``device``. It is expected to enter the factory's result around every
    invocation of ``func``.
    """
    def _device_factory():
        return device

    return context_decorator(_device_factory, func)
def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it with this function.

    ``device`` is converted with ``torch.device`` each time the returned
    decorator is applied to a function, not when ``set_device`` is called.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.
    """
    def _decorate(func):
        resolved = torch.device(device)
        return device_decorator(resolved, func)

    return _decorate