Spaces:
Sleeping
Sleeping
import functools | |
import hashlib | |
import os | |
from torch._dynamo.device_interface import get_interface_for_device | |
@functools.lru_cache(None)
def has_triton_package() -> bool:
    """Return True if the ``triton`` package can be imported, else False.

    Only ImportError is treated as "not installed"; any other exception
    raised while importing triton propagates to the caller.

    The result is memoized: a failed import is not stored in
    ``sys.modules``, so without the cache every call on a machine without
    Triton would search ``sys.path`` again.
    """
    try:
        import triton
    except ImportError:
        return False
    return triton is not None
@functools.lru_cache(None)
def has_triton() -> bool:
    """Return True if a Triton-compatible device is available and Triton is installed.

    Only the "cuda" device type is considered. It counts as compatible when
    its device interface reports it available and the default device's
    properties have ``major >= 7``. The device probe runs first, so the
    ``triton`` package is only imported (via ``has_triton_package``) when a
    compatible device exists.

    The result is memoized so device properties are not re-queried on every
    call; this assumes device availability does not change during the
    process lifetime.
    """

    def cuda_extra_check(device_interface):
        # Gate on the major compute capability of the default CUDA device.
        return device_interface.Worker.get_device_properties().major >= 7

    # Maps a device type to the additional check it must pass beyond
    # ``is_available()``.
    triton_supported_devices = {"cuda": cuda_extra_check}

    def is_device_compatible_with_triton():
        """Return True if any supported device type is available and passes its check."""
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton() and has_triton_package()
@functools.lru_cache(None)
def triton_backend_hash():
    """Return a version key for Triton's CUDA backend, or None.

    Returns None when torch is a ROCm (HIP) build or when CUDA is not
    available. Otherwise, looks up the registered "cuda" backend via
    ``triton.common.backend.get_backend``. It returns that backend's
    ``get_version_key()``, or ``get_cuda_version_key()`` when no backend is
    registered.

    The Triton import is deliberately placed after the early returns, so the
    ROCm / no-CUDA paths return None without requiring
    ``triton.common.backend`` to be importable. The result is memoized.
    """
    import torch

    if torch.version.hip:
        # Does not work with ROCm
        return None

    if not torch.cuda.is_available():
        return None

    from triton.common.backend import get_backend, get_cuda_version_key

    backend = get_backend("cuda")
    if backend is None:
        return get_cuda_version_key()
    return backend.get_version_key()
def _sha256_file(path, chunk_size=1024**2):
    """Return the hex SHA-256 digest of the file at *path*, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(functools.partial(f.read, chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _module_file_hashes(search_paths):
    """Return SHA-256 digests of every top-level module pkgutil finds in *search_paths*.

    Digests are in ``pkgutil.iter_modules`` order; for a package the hashed
    file is the spec origin (its ``__init__`` file).
    """
    import pkgutil

    return [
        _sha256_file(info.module_finder.find_spec(info.name).origin)  # type: ignore[call-arg, union-attr, arg-type]
        for info in pkgutil.iter_modules(search_paths)
    ]


@functools.lru_cache(None)
def triton_key():
    """Return a string identifying the installed Triton build.

    The key is ``triton.__version__`` immediately followed by the
    ``"-"``-joined SHA-256 hex digests of, in order:

    1. ``compiler/compiler.py``;
    2. every top-level module under ``compiler/`` and ``compiler/backends/``;
    3. ``_C/libtriton.so``;
    4. every top-level module under ``language/``.

    There is no separator between the version and the first digest. Keep the
    format exactly as-is: any change alters every key this function produces.

    The result is memoized because computing it reads the whole
    ``libtriton.so`` shared library from disk.

    Raises:
        ImportError: if ``triton`` cannot be imported.
        OSError: if one of the hashed files cannot be opened or read.
    """
    import triton
    from triton import __version__

    triton_path = os.path.dirname(os.path.abspath(triton.__file__))
    compiler_path = os.path.join(triton_path, "compiler")

    # This is redundant. Doing it to be consistent with upstream.
    # frontend
    contents = [_sha256_file(os.path.join(compiler_path, "compiler.py"))]
    # compiler
    contents += _module_file_hashes(
        [compiler_path, os.path.join(compiler_path, "backends")]
    )
    # backend
    contents.append(_sha256_file(os.path.join(triton_path, "_C", "libtriton.so")))
    # language
    contents += _module_file_hashes([os.path.join(triton_path, "language")])

    return f"{__version__}" + "-".join(contents)
@functools.lru_cache(None)
def triton_hash_with_backend():
    """Return a SHA-256 hex digest combining the Triton build key and backend key.

    Returns None when torch is a ROCm (HIP) build. Otherwise hashes the UTF-8
    encoding of ``f"{triton_key()}-{triton_backend_hash()}"``. When the
    backend hash is None, the literal text ``"None"`` becomes part of the
    hashed string.

    The result is memoized because building the key involves hashing
    Triton's files from disk.
    """
    import torch

    if torch.version.hip:
        # Does not work with ROCm
        return None

    backend_hash = triton_backend_hash()
    key = f"{triton_key()}-{backend_hash}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()