"""Cross-attention processor with a selectable attention backend.

By default this module uses ``torch.nn.functional.scaled_dot_product_attention``.
To switch the backend to ``sageattention.sageattn``, set the environment variable
``CA_USE_SAGEATTN=1`` *before this module is imported*. The choice is made once,
at import time.
"""

import os

import torch  # noqa: F401  # NOTE(review): not used in this file's visible code; kept as-is
import torch.nn.functional as F

# Environment variable that opts in to the SageAttention backend ("1" = enabled).
_SAGEATTN_ENV_VAR = 'CA_USE_SAGEATTN'

# Attention kernel used by CrossAttentionProcessor; resolved once at import time.
scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get(_SAGEATTN_ENV_VAR, '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original error so the underlying reason the import failed
        # stays visible in the traceback. Name the env var that was actually read.
        raise ImportError(
            f'Please install the package "sageattention" to use SageAttention '
            f'(enabled by {_SAGEATTN_ENV_VAR}=1), or unset {_SAGEATTN_ENV_VAR}.'
        ) from err
    # NOTE(review): sageattn is called with its default tensor layout; assumes the
    # q/k/v passed in match it (e.g. (batch, heads, seq, head_dim)) — confirm with callers.
    scaled_dot_product_attention = sageattn


class CrossAttentionProcessor:
    """Callable that computes attention with the backend selected at import time."""

    def __call__(self, attn, q, k, v):
        """Return the attention of queries ``q`` over keys ``k`` and values ``v``.

        Args:
            attn: Accepted for call-signature compatibility; not used here.
                (Presumably the owning attention module — confirm with callers.)
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.

        Returns:
            The result of ``scaled_dot_product_attention(q, k, v)``, called with
            the selected backend's default mask / causal / scale settings.
        """
        return scaled_dot_product_attention(q, k, v)