Spaces:
Runtime error
Runtime error
File size: 833 Bytes
47c60f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os
from enum import Enum
from .device_id import DeviceId
# NOTE: This module must be imported (and the device selected) before anything else imports torch; otherwise the CUDA_VISIBLE_DEVICES value written below may not take effect.
class DeviceException(Exception):
    """Error type reserved for device-selection problems in this module."""
class _Device:
    """Track the compute device selected through ``CUDA_VISIBLE_DEVICES``.

    Device selection works by rewriting the ``CUDA_VISIBLE_DEVICES``
    environment variable. CUDA reads that variable only when it
    initializes, so the device should be chosen before any CUDA work runs
    in this process.
    """

    def __init__(self):
        # Start on the CPU (all GPUs hidden) until a caller selects otherwise.
        self.set(DeviceId.CPU)

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU, `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        """Return the ``DeviceId`` most recently passed to ``set``."""
        return self._current_device

    def set(self, device: DeviceId):
        """Select ``device`` by rewriting ``CUDA_VISIBLE_DEVICES``.

        Args:
            device: ``DeviceId.CPU`` sets the variable to an empty string,
                hiding every GPU. Any other member sets it to
                ``str(device.value)`` (presumably a GPU index; confirm
                against ``DeviceId``).

        Returns:
            The ``device`` argument, unchanged.

        Warns:
            RuntimeWarning: if CUDA was already initialized in this process
                and ``device`` differs from the previously selected device.
                In that case the environment change cannot take effect, so
                ``current()`` would not match the device torch actually uses.
        """
        import warnings

        # During __init__ no device has been recorded yet.
        previous = getattr(self, '_current_device', None)

        if device == DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)

        # Imported lazily so the environment variable is in place before
        # torch (and CUDA) load.
        import torch

        if device != previous and torch.cuda.is_initialized():
            # CUDA has already read CUDA_VISIBLE_DEVICES; the new value is
            # ignored for the rest of this process.
            warnings.warn(
                f"CUDA is already initialized; switching to {device!r} via "
                "CUDA_VISIBLE_DEVICES will not take effect in this process.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Keep cuDNN's benchmark (autotuning) mode off; presumably for
        # reproducible algorithm selection. TODO confirm intent.
        torch.backends.cudnn.benchmark = False
        self._current_device = device
        return device