moe/ext-torch/platforms.py
Add MoE kernels from vLLM (29e93ec)
from functools import lru_cache

import torch
# True when PyTorch was built with ROCm (HIP) support.
IS_ROCM = torch.version.hip is not None


class CudaPlatform:
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        # Cached per (cls, device_id).
        return torch.cuda.get_device_name(device_id)


class RocmPlatform:
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        # torch.cuda also covers ROCm devices when PyTorch is built with HIP.
        return torch.cuda.get_device_name(device_id)


# Platform singleton used by the MoE kernels to query device information.
current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
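
A minimal usage sketch (not part of the original file), assuming a CUDA or ROCm device is visible to PyTorch:

if __name__ == "__main__":
    # Report which backend was selected and the name of the first device.
    backend = "ROCm" if IS_ROCM else "CUDA"
    print(f"Detected platform: {backend}")
    print(f"Device 0: {current_platform.get_device_name(0)}")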