"""Attention backend selection.

Holds the package-level ``BACKEND`` and ``DEBUG`` settings, initialises them
from the ``ATTN_BACKEND`` / ``ATTN_DEBUG`` environment variables at import
time, and re-exports the public names of ``.full_attn`` and ``.modules``.
"""

from typing import *

# Single source of truth for the backend names accepted both from the
# ATTN_BACKEND environment variable and by set_backend().
_SUPPORTED_BACKENDS = ("xformers", "flash_attn", "sdpa", "naive")

BACKEND = "flash_attn"
DEBUG = False


def __from_env():
    """Override BACKEND / DEBUG from the environment and report the backend.

    - ``ATTN_BACKEND``: if set to a name in ``_SUPPORTED_BACKENDS`` it
      replaces ``BACKEND``; any other value is ignored (with a warning) and
      the current ``BACKEND`` is kept.
    - ``ATTN_DEBUG``: if set, ``DEBUG`` becomes True only when the value is
      exactly ``"1"``; any other value (e.g. ``"true"``) sets it to False.

    Always prints the backend that ends up selected.
    """
    import os
    import warnings

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get("ATTN_BACKEND")
    env_attn_debug = os.environ.get("ATTN_DEBUG")

    if env_attn_backend is not None:
        if env_attn_backend in _SUPPORTED_BACKENDS:
            BACKEND = env_attn_backend
        else:
            # Keep the default backend, but tell the user that their setting
            # had no effect instead of ignoring it silently.
            warnings.warn(
                f"[ATTENTION] Ignoring unsupported ATTN_BACKEND={env_attn_backend!r}; "
                f"expected one of {_SUPPORTED_BACKENDS}."
            )

    if env_attn_debug is not None:
        DEBUG = env_attn_debug == "1"

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()


def set_backend(backend: Literal["xformers", "flash_attn", "sdpa", "naive"]) -> None:
    """Select the attention backend at runtime.

    Args:
        backend: One of ``"xformers"``, ``"flash_attn"``, ``"sdpa"`` or
            ``"naive"`` (the same names accepted via ``ATTN_BACKEND``).

    Raises:
        ValueError: If ``backend`` is not a supported backend name.

    NOTE(review): this rebinds only this package's ``BACKEND`` global. A
    module that did ``from <this package> import BACKEND`` holds its own
    binding of the old value and will not see the change. Confirm how
    ``.full_attn`` reads the setting.
    """
    global BACKEND
    if backend not in _SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unsupported attention backend {backend!r}; "
            f"expected one of {_SUPPORTED_BACKENDS}."
        )
    BACKEND = backend


def set_debug(debug: bool) -> None:
    """Set the package-level ``DEBUG`` flag.

    Args:
        debug: New value for ``DEBUG``.

    NOTE(review): the same by-name-import caveat as ``set_backend`` applies
    to modules that copied ``DEBUG`` at import time.
    """
    global DEBUG
    DEBUG = debug


# Imported last, after BACKEND/DEBUG are initialised. Presumably this order is
# required because these submodules read the settings at import time; verify
# before reordering.
from .full_attn import *
from .modules import *