File size: 559 Bytes
29e93ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from typing import Callable, ParamSpec, TypeVar
import os
from functools import lru_cache, wraps
import torch
# True when this PyTorch build targets AMD ROCm/HIP: ``torch.version.hip``
# holds a version value on such builds and is None otherwise.
IS_ROCM = not (torch.version.hip is None)
class CudaPlatform:
    """Platform helpers used when PyTorch is not a ROCm/HIP build."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of GPU ``device_id`` as reported by ``torch.cuda``.

        Results are memoized by ``lru_cache`` (at most 8 entries, keyed on
        ``(cls, device_id)``). Repeated calls with the same arguments
        therefore return the cached string without calling into torch again.

        Args:
            device_id: Index of the device to query. Defaults to 0.

        Returns:
            The device name string from ``torch.cuda.get_device_name``.
        """
        # Forward the requested index. Hard-coding 0 here would return device
        # 0's name for every id, and the cache would then store that wrong
        # name under each id.
        return torch.cuda.get_device_name(device_id)
class RocmPlatform:
    """Platform helpers used when PyTorch is a ROCm/HIP build."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Look up the name of GPU ``device_id`` through ``torch.cuda``.

        ROCm builds of PyTorch expose HIP devices via the ``torch.cuda``
        namespace, so the same call is used as on CUDA. Results are memoized
        by ``lru_cache`` (at most 8 entries, keyed on ``(cls, device_id)``).

        Args:
            device_id: Index of the device to query. Defaults to 0.

        Returns:
            The device name string from ``torch.cuda.get_device_name``.
        """
        name = torch.cuda.get_device_name(device_id)
        return name
# Select the platform helper that matches the installed PyTorch build.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()
|