kernel
File size: 559 Bytes
29e93ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from typing import Callable, ParamSpec, TypeVar
import os
from functools import lru_cache, wraps

import torch

# True when PyTorch reports a HIP version, i.e. this is a ROCm build of
# PyTorch; False when ``torch.version.hip`` is None (CUDA or CPU-only builds).
IS_ROCM: bool = torch.version.hip is not None

class CudaPlatform:
    """Platform helpers used when PyTorch is not a ROCm (HIP) build."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for GPU ``device_id``.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            The device name string from ``torch.cuda.get_device_name``.

        Results are memoized by ``lru_cache`` (at most 8 entries), keyed on
        ``cls`` and the call arguments, so repeated lookups skip the torch call.
        """
        # Forward the requested index. Hard-coding 0 here returned device 0's
        # name for every id and then cached that wrong answer per id.
        return torch.cuda.get_device_name(device_id)

class RocmPlatform:
    """Platform helpers used when PyTorch is a ROCm (HIP) build."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Look up the name torch reports for GPU ``device_id``.

        Args:
            device_id: Index of the device to query (defaults to 0).

        Returns:
            The device name string from ``torch.cuda.get_device_name``.

        Memoized by ``lru_cache`` (up to 8 entries, keyed on ``cls`` and the
        call arguments).
        """
        reported_name = torch.cuda.get_device_name(device_id)
        return reported_name


# Module-wide platform helper: the ROCm variant when PyTorch reports a HIP
# version (IS_ROCM), otherwise the CUDA variant.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()