kernel
moe / torch-ext /moe /platforms.py
danieldk's picture
danieldk HF Staff
Vendor `w8a8_block_fp8_matmul` and `per_token_group_quant_fp8`
b41d28a
raw
history blame
1.77 kB
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import NamedTuple
import torch
# True when this PyTorch build targets ROCm/HIP; selects the platform singleton below.
IS_ROCM = torch.version.hip is not None
class DeviceCapability(NamedTuple):
    """A GPU compute capability expressed as a (major, minor) pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as a dotted version string, e.g. ``"8.6"``."""
        return ".".join((str(self.major), str(self.minor)))

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        # Guard the single-digit assumption so 10 * major + minor stays unambiguous.
        assert 0 <= self.minor < 10
        return 10 * self.major + self.minor
class Platform(ABC):
    """Abstract base class describing the accelerator platform in use."""

    # Backend name handed to torch.compile for simple compilation paths.
    simple_compile_backend: str = "inductor"

    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable name for the given device index."""

    @abstractmethod
    def is_rocm(self):
        """Return True when this platform is ROCm/HIP, False otherwise."""
class CudaPlatform(Platform):
    """Platform implementation for NVIDIA CUDA devices."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the CUDA compute capability of ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for ``device_id``, cached per device.

        Fix: previously this always queried device 0 regardless of the
        requested ``device_id``, so the per-device cache could return the
        wrong name on multi-GPU hosts with heterogeneous devices.
        """
        return torch.cuda.get_device_name(device_id)

    def is_rocm(self):
        return False
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm (HIP) devices.

    ROCm reuses the ``torch.cuda`` namespace, so the queries below are
    identical to the CUDA ones.
    """

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` (cached per device)."""
        capability = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=capability[0], minor=capability[1])

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for ``device_id`` (cached per device)."""
        return torch.cuda.get_device_name(device_id)

    def is_rocm(self):
        return True
# Module-level singleton: pick the platform once, based on the torch build.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()