"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn

import torch
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
    _calculate_scale,
    _input_requires_grad,
    _postprocess_flash_output,
    _validate_sdpa_input,
)
from torch.nn.functional import scaled_dot_product_attention

__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)


class CausalVariant(IntEnum):
    r"""

    Enum for causal variants used in attention mechanisms.



    Defines two types of causal biases:



    `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.

    The equivalent pytorch code for constructing this bias is:



    .. code-block:: python



        torch.tril(torch.ones(size, dtype=torch.bool))



    For instance, with `shape=(3,4)`, the materialized bias tensor will be:



    .. code-block:: text



        [[1, 0, 0, 0],

         [1, 1, 0, 0],

         [1, 1, 1, 0]]





    `LOWER_RIGHT`: Represents lower-right triangular bias, the include values are aligned to the lower

    right corner of the matrix.



    The equivalent pytorch code for constructing this bias is:



    .. code-block:: python



        diagonal_offset = size[1] - size[0]

        torch.tril(

            torch.ones(size, dtype=torch.bool),

            diagonal=diagonal_offset,

        )



    For instance, with `shape=(3,4)`, the materialized bias tensor will be:



    .. code-block:: text



        [[1, 1, 0, 0],

         [1, 1, 1, 0],

         [1, 1, 1, 1]]



    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value

    tensors are equal since the triangular matrix is square.


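    For a square shape the two variants coincide, since the diagonal offset above is
    ``size[1] - size[0] = 0``. A minimal check of that claim, using only the constructions
    shown in this docstring (illustrative, not part of the API):

    .. code-block:: python

        import torch

        size = (4, 4)
        upper_left = torch.tril(torch.ones(size, dtype=torch.bool))
        lower_right = torch.tril(torch.ones(size, dtype=torch.bool), diagonal=size[1] - size[0])
        assert torch.equal(upper_left, lower_right)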

    .. warning:: This enum is a prototype and subject to change.

    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()


class CausalBias(torch.Tensor):
    """

    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.



    This class is used for defining causal (triangular) attention biases. For construing the bias, there exist

    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.



    Example:



    .. code-block:: python



        from torch.nn.attention.bias import causal_lower_right



        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8



        # Create a lower-right causal bias

        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)



        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)

        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)



        out = F.scaled_dot_product_attention(q, k, v, attn_bias)



    .. warning:: This class is a prototype and subject to change.

    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """

        Initializes the CausalBias instance with a specified variant and sequence lengths.



        Args:

            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).

            seq_len_q (int): The sequence length of the query tensor.

            seq_len_kv (int): The sequence length of the key/value tensor.



        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.

        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """

        Materializes the causal bias into a tensor form.



        Depending on the variant, this method generates either an upper-left or lower-right

        triangular matrix to represent the causal bias.



        Args:

            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.



        Returns:

            torch.Tensor: The materialized bias tensor.

        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        r"""

        Handles the logic for computing attention with the specified causal bias.



        Args:

            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.

            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.

            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.

            attn_mask (CausalBias): The type of causal attention to apply.

                A boolean mask where a value of True indicates that the element *should* take part in attention.

                A float mask of the same type as query, key, value that is added to the attention score.

            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied

            is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal

                are set.

            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set

                to :math:`\frac{1}{\sqrt{E}}`.



        Returns:

            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.



        Raises:

            ValueError: If the causal bias variant is not a CausalVariant type.

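        Example (an illustrative sketch; users normally reach this path through
        ``scaled_dot_product_attention`` rather than by calling ``_dispatch`` directly, and the
        shapes below are arbitrary):

        .. code-block:: python

            import torch
            import torch.nn.functional as F
            from torch.nn.attention.bias import causal_upper_left

            q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
            bias = causal_upper_left(8, 8)

            # With equal sequence lengths (or the UPPER_LEFT variant), dispatch reduces
            # to the regular is_causal=True code path.
            out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
            out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            torch.testing.assert_close(out_bias, out_causal)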
        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
            if can_use_flash_attention(sdpa_params):
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    causal_diagonal=None,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We can't use efficient attention; the only support for lower right is via materialization
                return scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        return self._materialize().__repr__()


def causal_upper_left(*size) -> CausalBias:
    """

    Creates an upper-left triangular causal bias.



    This function generates a upper-left triangular matrix to represent causal attention bias with a

    diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix.

    This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.



    The equivalent pytorch code for constructing this bias is:



    .. code-block:: python



        torch.tril(torch.ones(size, dtype=torch.bool))



    For instance, with `shape=(3,4)`, the materialized bias tensor will be:



    .. code-block:: text



        [[1, 0, 0, 0],

         [1, 1, 0, 0],

         [1, 1, 1, 0]]



    Args:

        size: The size of the bias matrix.



    Returns:

        CausalBias: The UPPER_LEFT triangular causal bias variant.

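    Example (a minimal sketch; printing the bias materializes it on CPU via
    :meth:`CausalBias._materialize`, and the commented output is illustrative):

    .. code-block:: python

        from torch.nn.attention.bias import causal_upper_left

        # Bias for 3 query positions attending over 4 key/value positions
        attn_bias = causal_upper_left(3, 4)
        print(attn_bias)
        # tensor([[ True, False, False, False],
        #         [ True,  True, False, False],
        #         [ True,  True,  True, False]])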
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)


def causal_lower_right(*size) -> CausalBias:
    """

    Creates a lower-right triangular causal bias.



    This function generates a lower-right triangular matrix to represent causal attention bias with a

    diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix.



    The equivalent pytorch code for constructing this bias is:



    .. code-block:: python



        diagonal_offset = size[1] - size[0]

        torch.tril(

            torch.ones(size, dtype=torch.bool),

            diagonal=diagonal_offset,

        )



    For instance, with `shape=(3,4)`, the materialized bias tensor will be:



    .. code-block:: text



        [[1, 1, 0, 0],

         [1, 1, 1, 0],

         [1, 1, 1, 1]]



    Args:

        size: The size of the bias matrix.



    Returns:

        CausalBias: The LOWER_RIGHT triangular causal bias variant.

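    Example (a minimal sketch; printing the bias materializes it on CPU, and the commented
    output is illustrative):

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        # Bias for 3 query positions attending over 4 key/value positions,
        # aligned to the lower-right corner
        attn_bias = causal_lower_right(3, 4)
        print(attn_bias)
        # tensor([[ True,  True, False, False],
        #         [ True,  True,  True, False],
        #         [ True,  True,  True,  True]])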
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)