"""Defines utilities for interacting with scaled_dot_product_attention""" | |
import math | |
from typing import List, Optional | |
import torch | |
__all__: List[str] = [] | |


def _input_requires_grad(*tensors: torch.Tensor) -> bool:
    """Returns True if any of the tensors requires grad"""
    return any(t.requires_grad for t in tensors)
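
# Example (a minimal sketch), e.g. to check whether a backward pass will be
# needed for any of the attention inputs:
#
#   q = torch.randn(2, 3, requires_grad=True)
#   k = torch.randn(2, 3)
#   _input_requires_grad(q, k)   # True -- q tracks gradients
#   _input_requires_grad(k)      # False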


def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
    """Strips the flash-attention padding from the last dimension, restoring the original head size"""
    if inpt_tensor.size(-1) != og_size:
        return inpt_tensor[..., :og_size]
    return inpt_tensor
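
# Example (a minimal sketch, assuming a kernel padded head_dim from 60 to 64):
#
#   padded = torch.randn(2, 8, 128, 64)                # head_dim padded to 64
#   out = _postprocess_flash_output(padded, og_size=60)
#   out.size(-1)                                       # 60 -- padding stripped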


def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
    """
    For FlashAttention the head dimension is padded to a multiple of 8, so the
    softmax scale must be derived from the original head size, not the padded
    one. An explicitly supplied scale is returned unchanged.
    """
    if scale is not None:
        return scale
    return 1.0 / math.sqrt(head_dim_size)
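
# Example (a minimal sketch): with an original head dim of 60 padded up to 64,
# the correct scale is 1/sqrt(60), not 1/sqrt(64).
#
#   _calculate_scale(60, None)    # ~0.1291 == 1.0 / math.sqrt(60)
#   _calculate_scale(60, 0.125)   # 0.125 -- an explicit scale wins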


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )