from typing import Callable, ParamSpec, TypeVar import os from functools import lru_cache, wraps import torch IS_ROCM = torch.version.hip is not None class CudaPlatform: @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(0) class RocmPlatform: @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()