from typing import Callable, ParamSpec, TypeVar | |
import os | |
from functools import lru_cache, wraps | |
import torch | |
# Whether the running PyTorch build targets AMD ROCm (HIP) rather than CUDA.
# ``torch.version.hip`` carries a version string on HIP builds and is None
# otherwise, so a non-None value identifies a ROCm build.
IS_ROCM: bool = torch.version.hip is not None
class CudaPlatform:
    """Device helpers used when PyTorch is built for NVIDIA CUDA."""

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of the CUDA device at index ``device_id``.

        Args:
            device_id: Index of the device to query; defaults to device 0.

        Returns:
            The name reported by ``torch.cuda.get_device_name`` for that device.
        """
        # Forward the caller's index. The previous body hard-coded 0, which
        # silently ignored ``device_id`` and mis-reported names on multi-GPU hosts.
        return torch.cuda.get_device_name(device_id)
class RocmPlatform:
    """Device helpers used when PyTorch is built for AMD ROCm (HIP)."""

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of the GPU at index ``device_id``.

        ROCm builds of PyTorch expose HIP devices through the ``torch.cuda``
        namespace, so the query goes through ``torch.cuda.get_device_name``.

        Args:
            device_id: Index of the device to query; defaults to device 0.

        Returns:
            The name reported by ``torch.cuda.get_device_name`` for that device.
        """
        return torch.cuda.get_device_name(device_id)
# Choose the platform implementation once, at import time, from the
# build flavour recorded in IS_ROCM.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()