"""Defines bias subclasses that work with scaled_dot_product_attention""" | |
from enum import auto, IntEnum | |
from typing import Optional | |
from warnings import warn | |
import torch | |
from torch.backends.cuda import ( | |
can_use_efficient_attention, | |
can_use_flash_attention, | |
SDPAParams, | |
) | |
from torch.nn.attention import _raise_kernel_warnings | |
from torch.nn.attention._utils import ( | |
_calculate_scale, | |
_input_requires_grad, | |
_postprocess_flash_output, | |
_validate_sdpa_input, | |
) | |
from torch.nn.functional import scaled_dot_product_attention | |
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"] | |
torch._dynamo.allow_in_graph(can_use_flash_attention) | |
torch._dynamo.allow_in_graph(can_use_efficient_attention) | |
torch._dynamo.allow_in_graph(SDPAParams) | |


class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Defines two types of causal biases:

    `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    `LOWER_RIGHT`: Represents lower-right triangular bias, where the included values are aligned to the lower
    right corner of the matrix.
    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal, since the triangular matrix is square.

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()


class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        import torch
        import torch.nn.functional as F
        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )
    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)
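
    # Hedged usage note (internal helper, shown for illustration only): the dense
    # mask a bias would apply can be inspected directly, e.g.
    #
    #     causal_lower_right(2, 5)._materialize(torch.device("cpu"))
    #
    # returns a (2, 5) boolean tensor whose True entries are aligned to the
    # lower-right corner of the matrix.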

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
            is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
                are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.
        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
            if can_use_flash_attention(sdpa_params):
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    causal_diagonal=None,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We can't use flash or efficient attention; the only support for lower right is via materialization
                return scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is a CausalBias"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        return self._materialize().__repr__()


def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    This function generates an upper-left triangular matrix to represent causal attention bias, with a
    diagonal offset set so that the included values are aligned to the upper left corner of the matrix.
    This is equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)
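

# A minimal sketch of the equivalence noted in the docstring above (shapes are
# illustrative only): for a square (n, n) size, the UPPER_LEFT bias materializes
# to the same lower-triangular mask that ``is_causal=True`` implies, e.g.
#
#     torch.equal(
#         causal_upper_left(4, 4)._materialize(),
#         torch.tril(torch.ones(4, 4, dtype=torch.bool)),
#     )  # -> True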


def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    This function generates a lower-right triangular matrix to represent causal attention bias, with a
    diagonal offset set so that the included values are aligned to the lower right corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
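

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module API): a minimal, CPU-only demonstration
# of the two factory functions above. Shapes, dtypes, and variable names are
# illustrative assumptions; the __main__ guard keeps it from running on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F

    bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 2, 4, 3, 5, 8

    # Lower-right bias: the last query token attends to every key/value token.
    attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
    print(attn_bias)  # materialized (3, 5) boolean mask

    q = torch.randn(bsz, num_heads, seqlen_q, head_dim)
    k = torch.randn(bsz, num_heads, seqlen_kv, head_dim)
    v = torch.randn(bsz, num_heads, seqlen_kv, head_dim)

    out = F.scaled_dot_product_attention(q, k, v, attn_bias)
    print(out.shape)  # torch.Size([2, 4, 3, 8])

    # For square shapes, passing the upper-left bias matches is_causal=True.
    square_bias = causal_upper_left(seqlen_kv, seqlen_kv)
    q_sq = torch.randn(bsz, num_heads, seqlen_kv, head_dim)
    ref = F.scaled_dot_product_attention(q_sq, k, v, is_causal=True)
    via_bias = F.scaled_dot_product_attention(q_sq, k, v, square_bias)
    torch.testing.assert_close(ref, via_bias)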