diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
deleted file mode 100755
index b2ea513b646d37a7af78f853e23b3153e86adf98..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9974a8b63b96c905a0d703052e362c92a083c04459bffaecfc73dc35032dbc02
-size 445281152
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py
deleted file mode 100644
index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch
-from . import _flash_attn_876ac68_dirty
-ops = torch.ops._flash_attn_876ac68_dirty
-
-def add_op_namespace_prefix(op_name: str):
-    """
-    Prefix op by namespace.
-    """
-    return f"_flash_attn_876ac68_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
deleted file mode 100755
index 3e675bff4732ce084f7e0af974503100b955467c..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bce17db4cca0b2d8894c084418e2e737e40df9e0e06746fe41270cd742471d42
-size 447262136
diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py
deleted file mode 100644
index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch
-from . import _flash_attn_876ac68_dirty
-ops = torch.ops._flash_attn_876ac68_dirty
-
-def add_op_namespace_prefix(op_name: str):
-    """
-    Prefix op by namespace.
-    """
-    return f"_flash_attn_876ac68_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
deleted file mode 100755
index cece22edde52e2053aab7e63e4f403004d2e8f1f..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8d5b6687e0f71e4af02344467f5f60c12f307d892f80d444651de4a7ba92e2c1
-size 448651728
diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py
deleted file mode 100644
index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000
--- a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch
-from . import _flash_attn_876ac68_dirty
-ops = torch.ops._flash_attn_876ac68_dirty
-
-def add_op_namespace_prefix(op_name: str):
-    """
-    Prefix op by namespace.
- """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py deleted file mode 100644 index ecc2f9d896b6c93f90b0a1499856dc0612177422..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py +++ /dev/null @@ -1,393 +0,0 @@ -from typing import Optional, List -import torch -from ._ops import ops as flash_attn_ops -from .flash_attn_interface import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, -) - - -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention. - - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Optional output tensor, same shape as q - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd( - q, - k, - v, - out, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def varlen_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - out: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with variable sequence lengths. 
- - Args: - q: Query tensor of shape [total_q, num_heads, head_size] - k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - out: Optional output tensor of shape [total_q, num_heads, head_size] - seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size] - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_fwd( - q, - k, - v, - out, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention. 
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def varlen_bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention with variable sequence lengths. 
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def fwd_kvcache( - q: torch.Tensor, - kcache: torch.Tensor, - vcache: torch.Tensor, - k: Optional[torch.Tensor] = None, - v: Optional[torch.Tensor] = None, - seqlens_k: Optional[torch.Tensor] = None, - rotary_cos: Optional[torch.Tensor] = None, - rotary_sin: Optional[torch.Tensor] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - is_rotary_interleaved: bool = False, - num_splits: int = 1, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with KV cache. 
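As an aside on the `cu_seqlens_q` / `cu_seqlens_k` arguments taken by the `varlen_fwd` and `varlen_bwd` wrappers above: they are cumulative sequence lengths with a leading zero, of shape `[batch_size + 1]`. Below is a small sketch of how they could be built from per-sequence lengths (the values are made up); the `fwd_kvcache` argument list continues after it.

```
import torch
import torch.nn.functional as F

# three sequences of lengths 3, 5 and 2 packed into a single total_q = 10 tensor
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# -> tensor([0, 3, 8, 10], dtype=torch.int32), i.e. batch_size + 1 entries
max_seqlen = int(seqlens.max())  # 5, passed as max_seqlen_q / max_seqlen_k
```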
- - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - seqlens_k: Optional sequence lengths for keys of shape [batch_size] - rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2] - rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2] - cache_batch_idx: Optional indices to index into the KV cache - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - out: Optional output tensor, same shape as q - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - is_rotary_interleaved: Whether rotary embeddings are interleaved - num_splits: Number of splits for computation - - Returns: - List of tensors: [output, softmax_lse] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd_kvcache( - q, - kcache, - vcache, - k, - v, - seqlens_k, - rotary_cos, - rotary_sin, - cache_batch_idx, - leftpad_k, - block_table, - alibi_slopes, - out, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - is_rotary_interleaved, - num_splits, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index 25e9c6b38de6f5e85da7cc7443fe0ff81da7c6e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cbfa3ad9f76e08999d944eabc46279ba3e8368d3b4d6180d8cff1da136c2ad34 -size 445272824 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py deleted file mode 100644 index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
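A hedged sketch of one decode step with the `fwd_kvcache` wrapper documented above; the import path is assumed, and the sizes, dtype, and cache lengths are illustrative, while the argument shapes follow the docstring.

```
import torch
from flash_attn import fwd_kvcache  # assumed import path for the wrapper above

batch, nheads, nheads_k, head_size, max_cache_len = 2, 8, 8, 64, 1024
q = torch.randn(batch, 1, nheads, head_size, device="cuda", dtype=torch.float16)
kcache = torch.zeros(batch, max_cache_len, nheads_k, head_size, device="cuda", dtype=torch.float16)
vcache = torch.zeros_like(kcache)
k_new = torch.randn(batch, 1, nheads_k, head_size, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
seqlens_k = torch.tensor([17, 42], dtype=torch.int32, device="cuda")  # tokens already in the cache

# new keys/values are written into the cache starting at seqlens_k, then attention
# runs over the cached plus newly appended tokens
out, softmax_lse = fwd_kvcache(q, kcache, vcache, k=k_new, v=v_new,
                               seqlens_k=seqlens_k, is_causal=True)
```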
- """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/bert_padding.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/bert_padding.py deleted file mode 100644 index 3c2d35159a014a9d03aabead9e52e009168696ea..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/bert_padding.py +++ /dev/null @@ -1,218 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - - -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, - dtype=grad_output.dtype, - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - output = input[indices] - # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last - # memory format to channel_first. In other words, input might not be contiguous. 
- # If we don't detach, Pytorch complains about output being a view and is being modified inplace - return output, input.detach() - - @staticmethod - def backward(ctx, grad_output, grad_residual): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - assert grad_residual.shape[1:] == other_shape - grad_input = grad_residual - # grad_input[indices] += grad_output - indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) - indices = indices.expand_as(grad_output) - grad_input.scatter_add_(0, indices, grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis_residual = IndexFirstAxisResidual.apply - - -def unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): - """ - Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). - The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
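To make the `unpad_input` / `pad_input` pair above concrete, here is a small round-trip sketch using the shapes from their docstrings; the import path assumes the package layout being deleted here, and the worked `attention_mask_in_length` example for the concatenated-sequence variant continues below.

```
import torch
from flash_attn.bert_padding import unpad_input, pad_input  # assumed import path

batch, seqlen, dim = 2, 4, 16
hidden_states = torch.randn(batch, seqlen, dim, dtype=torch.float16)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# unpadded: (total_nnz, dim) with total_nnz = 5; cu_seqlens = [0, 3, 5]; max_seqlen = 3
unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)

# scatter the kept tokens back into a zero-padded (batch, seqlen, dim) tensor
repadded = pad_input(unpadded, indices, batch, seqlen)
```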
- - For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: - ``` - [ - [2, 3, 0, 0, 0, 0], - [3, 2, 0, 0, 0, 0], - [6, 0, 0, 0, 0, 0] - ] - ``` - , which refers to the 3D-attention mask: - ``` - [ - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1] - ] - ] - ```. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - """ - length = attention_mask_in_length.sum(dim=-1) - seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) - real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() - seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] - indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... 
-> b s ...", b=batch) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/flash_attn_interface.py deleted file mode 100644 index 690d644f0a1c3d6ccfd26acbf8b22376a47cfff0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/flash_attn_interface.py +++ /dev/null @@ -1,1609 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Sequence, Tuple, Union - -import torch -import torch.nn as nn -import os - -# # isort: off -# # We need to import the CUDA kernels after importing torch -# USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -# if USE_TRITON_ROCM: -# from .flash_attn_triton_amd import interface_fa as flash_attn_gpu -# else: -# import flash_attn_2_cuda as flash_attn_gpu - - -from ._ops import ops as flash_attn_gpu - -# # isort: on - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _get_block_size_n(device, head_dim, is_dropout, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128 - if head_dim <= 64: - return 128 if not is_dropout else 64 - elif head_dim <= 96: - return 64 - elif head_dim <= 128: - if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 - else: - return 64 if not is_dropout else 32 - elif head_dim <= 192: - return 64 - elif head_dim <= 224: - return 64 - elif head_dim <= 256: - return 64 - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -# torch.compile() support is only enabled for pytorch >= 2.4 -# The reason for this is that we are using the new custom_op and register_fake -# APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": - _torch_custom_op_wrapper = torch.library.custom_op - _torch_register_fake_wrapper = torch.library.register_fake -else: - def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - if fn is None: - return wrap - return fn - def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): - def wrap(func): - return func - if fn is None: - return wrap - return fn - _torch_custom_op_wrapper = noop_custom_op_wrapper - _torch_register_fake_wrapper = noop_register_fake_wrapper - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q, - k, - v, - None, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward") -def _flash_attn_forward_fake( - q: torch.Tensor, - k: torch.Tensor, 
- v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - batch_size, seqlen_q, num_heads, head_size = q.shape - seqlen_k = k.shape[1] - out = torch.empty_like(q) - softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward -else: - _wrapped_flash_attn_forward = _flash_attn_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_attn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q, - k, - v, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward") -def _flash_attn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = 
round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward -else: - _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") -def _flash_attn_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward -else: - _wrapped_flash_attn_backward = _flash_attn_backward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, 
v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): - # breakpoint() - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward") -def _flash_attn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward -else: - _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(3) - dout_padded = dout 
- if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = 
torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, 0].detach(), kv[:, 1].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = 
ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = 
q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=block_table, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. 
- deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnQKVPackedFunc.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - kv: (batch_size, seqlen, 2, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. 
The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnKVPackedFunc.apply( - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). 
It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
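Example (a minimal sketch, assuming fp16 inputs on a CUDA device; the shapes and sequence lengths are illustrative):

    import torch

    seqlens = [3, 5]                                   # two sequences packed into one batch
    total, nheads, headdim = sum(seqlens), 8, 64
    qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")
    # cu_seqlens holds the prefix sums of the sequence lengths: [0, 3, 8]
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen=max(seqlens), causal=True)
    # out: (total, nheads, headdim)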
- """ - return FlashAttnVarlenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenKVPackedFunc.apply( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. 
This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) - - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, - alibi_slopes=None, - num_splits=0, - return_softmax_lse=False, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. 
Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. - - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). 
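Example (a minimal single-decoding-step sketch, assuming fp16 tensors on a CUDA device and a pre-allocated cache; all sizes are illustrative):

    import torch

    batch, nheads, nheads_k, headdim, max_seqlen = 2, 8, 8, 64, 1024
    k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.full((batch,), 10, dtype=torch.int32, device="cuda")  # tokens already cached

    q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")        # current query token
    k_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")  # new key
    v_new = torch.randn_like(k_new)                                                       # new value

    # k_cache / v_cache are updated in place at the positions given by cache_seqlens,
    # and attention is computed against the updated cache, all in one kernel.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                                  cache_seqlens=cache_seqlens, causal=True)
    # out: (batch, 1, nheads, headdim)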
- """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - cache_batch_idx = maybe_contiguous(cache_batch_idx) - block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_gpu.fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - None, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - rotary_interleaved, - num_splits, - ) - return (out, softmax_lse) if return_softmax_lse else out diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/patch_embed.py deleted file mode 100644 index 05562f8e8bcdb58e947c6f402a49eacd2d031871..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/patch_embed.py +++ /dev/null @@ -1,67 +0,0 @@ -# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py -# But we use nn.Linear instead of Conv2d and it's about 8x faster. 
- -from functools import partial - -import torch.nn as nn -from einops import rearrange -from torch import _assert -from torch.nn.modules.utils import _pair - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, - fused_bias_fc=False, - ): - super().__init__() - img_size = _pair(img_size) - patch_size = _pair(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - - linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense - self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - _, _, H, W = x.shape - _assert( - H == self.img_size[0], - f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", - ) - _assert( - W == self.img_size[1], - f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", - ) - x = self.proj( - rearrange( - x, - "b c (h p1) (w p2) -> b h w (c p1 p2)", - p1=self.patch_size[0], - p2=self.patch_size[1], - ) - ) - if self.flatten: - x = rearrange(x, "b h w c -> b (h w) c") - x = self.norm(x) - return x diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/rotary.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/rotary.py deleted file mode 100644 index d1bfc21fc7de1dd287e8f382847b194a48075981..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/layers/rotary.py +++ /dev/null @@ -1,483 +0,0 @@ -# Copyright (c) 2025, Tri Dao - -import math -from functools import partial -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor - -from einops import rearrange, repeat -# from flash_attn.ops.triton.rotary import apply_rotary -from ..ops.triton.rotary import apply_rotary - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, - ): - out = apply_rotary( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - dx = apply_rotary( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, -): - """ - Arguments: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - return ApplyRotaryEmb.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -# For backward compatibility -apply_rotary_emb_func = apply_rotary_emb - - -def _apply_rotary_emb_qkv( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - inplace=False, - conjugate=False, - seqlen_offsets: Union[int, Tensor] = 0, - num_heads_q: Optional[int] = None, -): - apply_rotary_fn = partial( - apply_rotary, - interleaved=interleaved, - inplace=inplace, - conjugate=conjugate, - seqlen_offsets=seqlen_offsets - ) - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - qk = apply_rotary_fn(qk, cos, sin) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - qk = qkv[:, :, :num_heads_q + num_heads_k] - qk = apply_rotary_fn(qk, cos, sin) - if not inplace: - if qkv.dim() == 5: - qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) - else: - qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - q, k = qkv[:, :, 0], qkv[:, :, 1] - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] - q = apply_rotary_fn(q, cos, sin) - k = apply_rotary_fn(k, cos_k, sin_k) - if not inplace: - if qkv.dim() == 5: - qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) - else: - qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) - return qkv - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, - ): - # apply_rotary_emb_qkv_inplace( - qkv = _apply_rotary_emb_qkv( - qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.num_heads_q = num_heads_q - return qkv - - @staticmethod - def backward(ctx, dqkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cos_k, sin_k = ctx.saved_tensors - dqkv = _apply_rotary_emb_qkv( - dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, - ) - return dqkv, None, None, None, None, None, None, None - - -def apply_rotary_emb_qkv_( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - 
seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, -): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim). - If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of Q and K. - """ - return ApplyRotaryEmbQKV_.apply( - qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q - ) - - -class ApplyRotaryEmbKV_(torch.autograd.Function): - - @staticmethod - def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): - batch, seqlen, two, nheads, headdim = kv.shape - assert two == 2 - k = kv[:, :, 0] - apply_rotary( - k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - return kv - - @staticmethod - def backward(ctx, dkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, seqlen_offsets = ctx.saved_tensors - else: - cos, sin = ctx.saved_tensors - apply_rotary( - dkv[:, :, 0], - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - return dkv, None, None, None, None - - -apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply - - -def apply_rotary_emb_kv_( - kv, - cos, - sin, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -): - """ - Arguments: - kv: (batch_size, seqlen, 2, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - kv: (batch_size, seqlen, 2, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of K. - """ - return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
- A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 - Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py - """ - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - """ - super().__init__() - self.dim = dim - self.base = float(base) - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
- # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to bf16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - num_heads_q: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) - if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim). - If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. 
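Example (a minimal sketch, assuming fp16 tensors on a CUDA device; the module rotates q and k of the packed qkv in place and leaves v untouched):

    import torch

    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")

    qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
    qkv = rotary(qkv)                                   # prefill: positions 0 .. seqlen - 1

    # Incremental decoding: shift positions by the number of tokens already processed.
    qkv_step = torch.randn(batch, 1, 3, nheads, headdim, dtype=torch.float16, device="cuda")
    qkv_step = rotary(qkv_step, seqlen_offset=seqlen)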
- """ - seqlen = qkv.shape[1] - if max_seqlen is not None: - self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - if kv is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached if self.scale is not None else None, - self._sin_k_cached if self.scale is not None else None, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) - else: - q = qkv - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - inplace=True, - seqlen_offsets=seqlen_offset, - ) - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached if self.scale is None else self._cos_k_cached, - self._sin_cached if self.scale is None else self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - return q, kv diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/activations.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/activations.py deleted file mode 100644 index 7c09649fc41e12d5a360c5672825d8380bc7ec80..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/activations.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def bias_gelu(y, bias): - x = bias + y - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, y, bias): - """Assume that y has shape (B, D) and bias has shape (D)""" - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - grad_y = ff * g - return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(input, bias) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, input, bias) - return tmp, tmp - - -bias_gelu_impl = GeLUFunction.apply - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def gelu_fwd(x): - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def gelu_bwd(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - return (ff * g).to(dtype=x.dtype) - - -class FastGeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input): - ctx.save_for_backward(input) - return gelu_fwd(input) - - @staticmethod - def backward(ctx, grad_output): - (input,) = ctx.saved_tensors - tmp = gelu_bwd(grad_output, input) - return tmp - - -fast_gelu_impl = FastGeLUFunction.apply - - -@torch.jit.script -def relu_bwd(g, x): - return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_fwd(x): - r = F.relu(x) - return (r * r).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_bwd(g, x): - return (2.0 * g * F.relu(x)).to(dtype=x.dtype) - - -swiglu_fwd_codestring = """ -template T swiglu_fwd(T x, T y) { - return float(x) * float(y) / (1.0f + ::exp(-float(x))); -} -""" -swiglu_bwd_codestring = """ -template void swiglu_bwd(T x, T y, T g, T& dx, T& dy) { - float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); - dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); - dy = float(x) * x_sigmoid * float(g); -} -""" -swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) -swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) - - -class SwiGLUFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - return swiglu_fwd(x, y) - - @staticmethod - def backward(ctx, dout): - x, y = ctx.saved_tensors - return swiglu_bwd(x, y, dout) - -swiglu = SwiGLUFunction.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/fused_dense.py deleted file mode 100644 index 6b4033d134e4093fe278f7b3f8c7d3128ce9f36d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/fused_dense.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py -# We make it work with pytorch amp and with bfloat16. -# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py -from functools import partial -from typing import Optional - -# import fused_dense_cuda # from apex -import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd -from flash_attn.utils.distributed import ( - all_gather_raw, - all_reduce, - all_reduce_raw, - reduce_scatter, - reduce_scatter_raw, -) - - -class FusedDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if 
process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel: - handle_x.wait() - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply( - x, weight, bias, return_residual, process_group, sequence_parallel - ) - else: - assert process_group is None - out = F.linear(x, weight, bias) - return out if not return_residual else (out, x) - - -class FusedDense(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - return_residual: bool = False, - device=None, - dtype=None, - ) -> None: - super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - self.return_residual = return_residual - - def forward(self, x, process_group=None): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - """ - return fused_dense_func( - x, - self.weight, - self.bias, - return_residual=self.return_residual, - process_group=process_group, - ) - - -class ColumnParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__( - in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
- return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - - -class RowParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = fused_dense_func(x, self.weight, self.bias) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) - - -class FusedMLPFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight1, - bias1, - weight2, - bias2, - activation="gelu_approx", - save_pre_act=True, - return_residual=False, - checkpoint_lvl=0, - heuristic=0, - process_group=None, - sequence_parallel=True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather of x before doing the matmul. - If sequence_parallel=False, then the input is already gathered. 
- - checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out / relu_out in the bwd - 2: recompute pre_act and gelu_out / relu_out in the bwd - """ - assert -1 <= heuristic <= 4 - assert activation in ["gelu_approx", "relu", "sqrelu"] - if activation == "sqrelu": - assert heuristic == -1 - if not save_pre_act: - checkpoint_lvl = 2 - assert checkpoint_lvl in [0, 1, 2] - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.checkpoint_lvl = checkpoint_lvl - ctx.activation = activation - ctx.heuristic = heuristic - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] - bias1 = bias1.to(dtype=dtype) if bias1 is not None else None - bias2 = bias2.to(dtype=dtype) if bias2 is not None else None - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() if bias1 is not None else None - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() if bias2 is not None else None - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - if heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - # This is before adding bias1 - # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) - # with torch.jit.fuser('fuser2'): - # output1 = bias_gelu(pre_act, bias1) - else: - is_gelu = activation == "gelu_approx" - output1, *rest = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic - ) - if save_pre_act: - pre_act = rest[0] - output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - # For RELU the pre_act is very small (just a bit-mask) so we just save it - ctx.save_for_backward(x, weight1, weight2, pre_act, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, weight2, pre_act) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, weight2, bias1) - output2 = output2.reshape(*batch_shape, output2.shape[-1]) - return output2 if not return_residual else (output2, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - activation = ctx.activation - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - x, 
weight1, weight2, *rest = ctx.saved_tensors - if process_group is None or not sequence_parallel: - total_x = x - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - if checkpoint_lvl in [0, 1]: - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - pre_act, output1 = rest - elif checkpoint_lvl == 1: - (pre_act,) = rest - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - elif checkpoint_lvl == 2: - (bias1,) = rest - if process_group is not None and sequence_parallel: - total_x, _ = all_gather_raw(x, process_group) - if ctx.heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - else: - output1, pre_act = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), - weight1, - bias1, - activation == "gelu_approx", - True, - ctx.heuristic, - ) - - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - output1 = output1.reshape(batch_dim, output1.shape[-1]) - pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) - if ctx.needs_input_grad[3]: - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( - output1, grad_output, ctx.needs_input_grad[4] - ) - else: - grad_weight2 = None - grad_bias2 = grad_output if ctx.needs_input_grad[4] else None - if ctx.heuristic == -1: - # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) - grad_output1 = F.linear(grad_output, weight2.t()) - activation_grad_fn = ( - gelu_bwd - if activation == "gelu_approx" - else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) - ) - with torch.jit.fuser("fuser2"): - grad_pre_act = activation_grad_fn(grad_output1, pre_act) - else: - # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't - # just compute gelu/relu grad - grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( - weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic - ) - if not ctx.needs_input_grad[2]: - grad_bias1 = None - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_pre_act, weight1.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.heuristic == -1: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_pre_act, - ctx.needs_input_grad[2], - ) - else: - grad_weight1 = None - grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None - else: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1 = F.linear( - grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() - ) - else: - grad_weight1 = None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return ( - grad_input, - grad_weight1, - grad_bias1, - grad_weight2, - grad_bias2, - None, - None, - 
None, - None, - None, - None, - None, - ) - - -def fused_mlp_func( - x: Tensor, - weight1: Tensor, - weight2: Tensor, - bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, - activation: str = "gelu_approx", - save_pre_act: bool = True, - return_residual: bool = False, - checkpoint_lvl: int = 0, - heuristic: int = 0, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - assert activation in ["gelu_approx", "relu", "sqrelu"] - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) - dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) - if ( - x.is_cuda - and weight1.is_cuda - and weight2.is_cuda - and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) - and dtype_eligible - and dim_eligible - ): - return FusedMLPFunc.apply( - x, - weight1, - bias1, - weight2, - bias2, - activation, - save_pre_act, - return_residual, - checkpoint_lvl, - heuristic, - process_group, - sequence_parallel, - ) - else: - assert process_group is None - pre_act = F.linear(x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else partial(F.relu, inplace=True) - ) - output1 = activation_fn(pre_act) - output2 = F.linear(output1, weight2, bias2) - return output2 if not return_residual else (output2, x) - - -class FusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - activation="gelu_approx", - return_residual=False, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. - For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation - is slower than the unfused version. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. 
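    Example (a minimal sketch; assumes a CUDA device, fp16 tensors, and the
    compiled fused_dense_cuda extension -- fused_mlp_func falls back to plain
    F.linear + activation when the dtype/shape is not eligible; sizes are
    illustrative):
        mlp = FusedMLP(1024, 4096, activation="gelu_approx",
                       device="cuda", dtype=torch.float16)
        x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
        y = mlp(x)  # out_features defaults to in_features, so y is (2, 128, 1024)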
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.return_residual = return_residual - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x, process_group=None): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - if torch.cuda.get_device_capability("cuda") == (9, 0): - heuristic = -1 - else: - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=process_group, - ) - if self.return_residual: - out, x = out - if process_group is not None: - out = reduce_scatter(out, process_group) - return out if not self.return_residual else (out, x) - - -class ParallelFusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation="gelu_approx", - process_group: ProcessGroup = None, - bias1=True, - bias2=True, - sequence_parallel=True, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - process_group is required. We're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
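    Example (a sketch; assumes torch.distributed has already been initialized
    with a NCCL backend, that hidden_features divides evenly across the ranks,
    and that the fused kernels are available):
        pg = torch.distributed.group.WORLD
        mlp = ParallelFusedMLP(1024, 4096, process_group=pg,
                               sequence_parallel=True,
                               device="cuda", dtype=torch.float16)
        # With sequence_parallel=True, x is this rank's shard along the first
        # (sequence) dimension; the output keeps that sharding after the
        # reduce_scatter.
        y = mlp(x)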
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - assert process_group is not None - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = ColumnParallelLinear( - in_features, hidden_features, process_group, bias=bias1, **factory_kwargs - ) - self.fc2 = RowParallelLinear( - hidden_features, out_features, process_group, bias=bias2, **factory_kwargs - ) - - def forward(self, x): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/layer_norm.py deleted file mode 100644 index 4b6cd798fd02844ef9cd3897f8ab95e490e638bf..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/layer_norm.py +++ /dev/null @@ -1,800 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import dropout_layer_norm -import torch -from torch.nn import init - - -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 - return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
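    When colscale is not None, the gradient of colscale (dcolscale) is returned
    as an extra output after (dx0mat, dresidualmat, dgamma, dbeta).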
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). 
- """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
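        # Gradients coming from autograd are not guaranteed to be 16-byte
        # aligned; maybe_align clones them when needed since the fused kernel
        # assumes contiguous, 16-byte-aligned inputs.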
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) - else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) - - -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
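    Example (a sketch; assumes the compiled dropout_layer_norm CUDA extension
    and uses illustrative shapes):
        hidden = 1024
        x0 = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
        x1 = torch.randn_like(x0)
        w0 = torch.ones(hidden, device="cuda", dtype=torch.float16)
        b0 = torch.zeros_like(w0)
        w1, b1 = w0.clone(), b0.clone()
        # The two branches x0 and x1 are dropped out, summed (plus the residual
        # if given), then normalized twice with (w0, b0) and (w1, b1).
        out0, out1 = dropout_add_layer_norm_parallel_residual(
            x0, x1, None, w0, b0, w1, b1, dropout_p=0.1, epsilon=1e-5
        )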
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x0, residual=None): - return dropout_add_layer_norm( - x0, - residual, - self.weight, - self.bias, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/rms_norm.py deleted file mode 100644 index 068348d61290e3839dd082b540d898578ba1e8e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/rms_norm.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import torch -from torch.nn import init - -from flash_attn.ops.layer_norm import ( - DropoutAddLayerNormFn, - DropoutAddLayerNormParallelResidualFn, - DropoutAddLayerNormSubsetFn, -) - - -def rms_norm(x, weight, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, None, None, None, 0.0, epsilon, False, False, True - ) - - -def dropout_add_rms_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
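    Usage mirrors dropout_add_layer_norm_parallel_residual, except that the
    normalization is RMSNorm (no mean subtraction) and the bias arguments are
    typically passed as None, e.g. (a sketch, reusing tensors shaped like
    x0/x1/w0/w1 in the layer_norm example):
        out0, out1 = dropout_add_rms_norm_parallel_residual(
            x0, x1, None, w0, None, w1, None, dropout_p=0.0, epsilon=1e-6
        )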
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) - - -class DropoutAddRMSNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x0, residual=None): - return dropout_add_rms_norm( - x0, - residual, - self.weight, - None, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/cross_entropy.py deleted file mode 100644 index 1b5a415b73f236f3e05fb14b9141959559e18526..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/cross_entropy.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Tuple, Optional, Union - -import torch -import torch.nn.functional as F - -import triton -import triton.language as tl - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_fwd_kernel( - loss_ptr, # data ptrs - lse_ptr, - z_loss_ptr, - logits_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, - # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE - SPLIT: tl.constexpr, - PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) -): - row_idx = tl.program_id(0) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - sum_logits = 0.0 # For smoothing - if not PRECOMPUTED_LSE: - # Statistics for online softmax - m_i = -float("inf") - l_i = 0.0 - for col_offset in range(0, n_cols, BLOCK_SIZE): - cols = col_offset + tl.arange(0, BLOCK_SIZE) - logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - if HAS_SMOOTHING: - sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) - m_i_new = tl.maximum(m_i, tl.max(logits)) - l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) - m_i = m_i_new - lse = tl.log(l_i) + m_i - tl.store(lse_ptr + row_idx, lse) - else: - lse = tl.load(lse_ptr + row_idx) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx == ignore_index: - loss = 0.0 - z_loss = 0.0 - else: - label_idx -= class_start_idx - if label_idx >= 0 and label_idx < n_cols: - logits_label = tl.load(logits_ptr + label_idx) * logit_scale - if HAS_SMOOTHING: - loss = ( - (lse if not SPLIT else 0.0) - - smoothing * sum_logits / total_classes - - (1 - smoothing) * logits_label - ) - else: - loss = (lse if not SPLIT else 0.0) - logits_label - else: - # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss - if HAS_SMOOTHING: - loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) - else: - loss = 0.0 - if not SPLIT: - z_loss = lse_square_scale * lse * lse - loss += z_loss - else: - z_loss = 0.0 - tl.store(loss_ptr + row_idx, loss) - if not SPLIT: - tl.store(z_loss_ptr + row_idx, z_loss) - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_bwd_kernel( - dlogits_ptr, # data ptrs - dloss_ptr, - logits_ptr, - lse_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - dlogits_row_stride, - dloss_row_stride, - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, -): - row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx != ignore_index: - dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) - else: - dloss = 0.0 - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - lse = tl.load(lse_ptr + row_idx) - probs = tl.exp(logits - lse) - probs += 2.0 * lse_square_scale * lse * probs - label_idx -= class_start_idx - if HAS_SMOOTHING: - smooth_positive = 1.0 - smoothing - smooth_negative = smoothing / total_classes - probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative - else: - probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) - tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) - - -class CrossEntropyLoss(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - logits, - 
labels, - precomputed_lse=None, - smoothing=0.0, - logit_scale=1.0, - lse_square_scale=0.0, - ignore_index=-100, - inplace_backward=False, - process_group=None, - ): - # For some reason Triton generates wrong code when labels has dtype long and its address - # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. - if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: - labels = F.pad(labels, (0, 1))[..., :-1] - assert labels.data_ptr() % 16 == 0 - assert logit_scale > 0.0 - n_rows, n_cols = logits.shape - assert labels.shape == (n_rows,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - total_classes = world_size * n_cols - rank = 0 if process_group is None else torch.distributed.get_rank(process_group) - class_start_idx = rank * n_cols - use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 - - if logits.stride(-1) != 1: - logits = logits.contiguous() - MAX_BLOCK_SIZE = 16 * 1024 - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) - num_warps = ( - 4 - if BLOCK_SIZE < 2048 - else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) - ) - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if use_precomputed_lse: - assert precomputed_lse.shape == (n_rows,) - lse = precomputed_lse.contiguous() - else: - lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) - z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_fwd_kernel[(n_rows,)]( - losses, # data ptrs - lse, - z_losses, - logits, - labels, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - BLOCK_SIZE=BLOCK_SIZE, # constants - SPLIT=world_size > 1, - PRECOMPUTED_LSE=use_precomputed_lse, - num_warps=num_warps, - ) - - if world_size > 1: - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # -0.9 * predicted logit - 0.1 * sum logit / total_classes. - # For labels not in the vocab of this partition, losses contains - # -0.1 * sum logit / total_classes. - if world_size > 1: - lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - handle_losses.wait() - # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, - # we just have to add the (global) lse. - # If there's smoothing=0.1, the total losses are - # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. - # Again, we just have to add the (global) lse. 
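            # With no smoothing, per row this gives loss = global_lse - logit[label],
            # where the global LSE is the logsumexp of the per-rank LSEs gathered
            # above; the smoothing terms are already folded into losses, and the
            # z-loss term is added just below.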
- losses += lse - if lse_square_scale != 0.0: - z_losses = lse_square_scale * lse.square() - z_losses.masked_fill_(labels == ignore_index, 0.0) - losses += z_losses - else: - z_losses = torch.zeros_like(losses) - losses.masked_fill_(labels == ignore_index, 0.0) - - ctx.save_for_backward(logits, lse, labels) - ctx.mark_non_differentiable(z_losses) - ctx.smoothing = smoothing - ctx.logit_scale = logit_scale - ctx.lse_square_scale = lse_square_scale - ctx.ignore_index = ignore_index - ctx.total_classes = total_classes - ctx.class_start_idx = class_start_idx - ctx.inplace_backward = inplace_backward - return losses, z_losses - - @staticmethod - def backward(ctx, grad_losses, grad_z_losses): - del grad_z_losses # z_losses are only for logging. - - logits, lse, labels = ctx.saved_tensors - dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) - n_rows, n_cols = logits.shape - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) - num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) - grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_bwd_kernel[grid]( - dlogits, # data ptrs - grad_losses, - logits, - lse, - labels, - ctx.smoothing, - ctx.logit_scale, - ctx.lse_square_scale, - ctx.ignore_index, - ctx.total_classes, - ctx.class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - dlogits.stride(0), - grad_losses.stride(0), - BLOCK_SIZE=BLOCK_SIZE, # constants - num_warps=num_warps, - ) - return dlogits, None, None, None, None, None, None, None, None, None - - -def cross_entropy_loss( - logits: torch.Tensor, - labels: torch.Tensor, - precomputed_lse: Optional[torch.Tensor] = None, - label_smoothing: float = 0.0, - logit_scale: float = 1.0, - lse_square_scale: float = 0.0, - ignore_index=-100, - inplace_backward: bool = False, - process_group=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - logits: (batch, vocab_size) - labels: (batch,) - label_smoothing: float - logit_scale: float. Multiply logits by this scale before calculating the loss. - lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. - This is also referred to as "z-loss". - ignore_index: int. If labels == ignore_index, the loss is set to 0.0. - inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. - This saves memory. - process_group: if not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss will be aggregated across processes. - Returns: - losses: (batch,), float - z_losses: (batch,), float - """ - return CrossEntropyLoss.apply( - logits, - labels, - precomputed_lse, - label_smoothing, - logit_scale, - lse_square_scale, - ignore_index, - inplace_backward, - process_group, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/k_activations.py deleted file mode 100644 index efb83c358eb4a85d069ee340a3c83f418f9a805b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/k_activations.py +++ /dev/null @@ -1,162 +0,0 @@ -# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py -# Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from enum import Enum -from typing import Optional - -import triton -import triton.language as tl - -_sqrt2pi = math.sqrt(2.0 / math.pi) -_sqrt1_2 = math.sqrt(1.0 / 2) -_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) - - -class Activation(str, Enum): - SquaredReLU = "squared_relu" - GeLU = "gelu" - GeLUApprox = "gelu_approx" - LeakyReLU = "leaky_relu" - ReLU = "relu" - - -def get_triton_activation_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu, - Activation.LeakyReLU: leaky_relu, - Activation.GeLU: gelu, - Activation.GeLUApprox: gelu_approx, - Activation.SquaredReLU: squared_relu, - }[activation] - if activation - else None - ) - - -def get_triton_activation_bwd_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu_grad, - Activation.LeakyReLU: leaky_relu_grad, - Activation.GeLU: gelu_grad, - Activation.GeLUApprox: gelu_approx_grad, - Activation.SquaredReLU: squared_relu_grad, - }[activation] - if activation - else None - ) - - -@triton.jit -def tanh(x): - # Tanh is just a scaled sigmoid - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def cosh(x): - exp_x = tl.exp(x) - return (exp_x + 1.0 / exp_x) * 0.5 - - -# a Triton implementation of the most used activations -# See for instance http://arxiv.org/abs/1606.08415 for an overview - -# ReLU -@triton.jit -def relu(x): - """ - ReLU_ activation function - - .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html - """ - zero = 0.0 - return tl.where(x >= 0, x, zero.to(x.dtype)) - - -@triton.jit -def relu_grad(x): - # ReLU is different from other activations - # in that it does not require the input to retrospectively compute its gradient - # here the input is the downstream gradient, and we return the upstream gradient directly - zero = 0.0 - one = 1.0 - return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) - - -@triton.jit -def squared_relu(x): - """ - Squared ReLU activation, as proposed in the Primer_ paper. - - .. _Primer: https://arxiv.org/abs/2109.08668 - """ - x_ = relu(x) - return (x_ * x_).to(x.dtype) - - -@triton.jit -def squared_relu_grad(x): - return tl.where(x >= 0, 2.0 * x, 0.0) - - -# Leaky ReLU -@triton.jit -def leaky_relu(x): - """ - LeakyReLU_ activation - - .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html - """ - scale = 0.01 + 0.0 - scale = scale.to(x.dtype) - return tl.where(x >= 0, x, scale * x) - - -@triton.jit -def leaky_relu_grad(x): - min_grad = 0.01 - max_grad = 1 - - min_grad = min_grad.to(x.dtype) - max_grad = max_grad.to(x.dtype) - - return tl.where(x >= 0, max_grad, min_grad) - - -@triton.jit -def gelu(x): - """Gaussian Error Linear Unit (GELU)""" - return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - - -@triton.jit -def gelu_grad(x): - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization - return cdf + x * pdf - - -@triton.jit -def gelu_approx(x): - """ - GeLU_ activation - Gaussian error linear unit, with tanh approximation - - .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf - """ - return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) - - -@triton.jit -def gelu_approx_grad(x): - # CREDITS: Fast implementation proposed in - # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 - tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/layer_norm.py deleted file mode 100644 index 192cee474b160d1876fafd14c5e3d695e8ff237f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/layer_norm.py +++ /dev/null @@ -1,1252 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Implement dropout + residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math -from typing import Optional, List - -import torch -import torch.nn.functional as F -from torch import Tensor - -import triton -import triton.language as tl - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.utils.library import triton_op - - -def maybe_contiguous_lastdim(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def maybe_contiguous(x): - return x.contiguous() if x is not None else None - - -def triton_autotune_configs(): - # Return configs with a valid warp count for the current device - configs = [] - # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 - max_threads_per_block = 1024 - # Default to warp size 32 if not defined by device - warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) - # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit - return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] - if warp_count * warp_size <= max_threads_per_block] - # return [triton.Config({}, num_warps=8)] - - -def layer_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - 
if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = F.layer_norm( - x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps - ).to(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def rms_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( - dtype - ) - return (out, out1) if not prenorm else (out, out1, x) - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], -) -# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - DROPOUT_MASK1, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by 
zero - dropout_p, # Dropout probability - zero_centered_weight, # If true, add 1.0 to the weight - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - residual_dtype: Optional[torch.dtype] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - out: Optional[Tensor] = None, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, 
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library - # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None - # so that _layer_norm_fwd_impl doesn't have to return them. - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - if residual is not None: - residual_dtype = residual.dtype - if residual_out is None and ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - residual_out = torch.empty_like( - x, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - else: - residual_out = None - y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( - x, - weight, - bias, - eps, - out, - residual=residual, - x1=x1, - weight1=weight1, - bias1=bias1, - dropout_p=dropout_p, - rowscale=rowscale, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - residual_out=residual_out, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if residual_out is None: - residual_out = x - return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 - - -# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema -# since we're returning a tuple of tensors -@triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, - schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") -def _layer_norm_fwd_impl( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - out: Tensor, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - assert out.shape == x.shape - assert out.stride(-1) == 1 - if residual_out is not None: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - if weight1 is not None: - y1 = torch.empty_like(out) - assert y1.stride(-1) == 1 - else: - y1 = None - mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if dropout_p > 0.0: - seeds = torch.randint( - 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 - ) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) - if x1 is not None: - dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) - else: - dropout_mask1 = None - else: - dropout_mask, dropout_mask1 = None, None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( - x, - out, - weight, - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - dropout_mask1, - mean, - rstd, - x.stride(0), - out.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - x1.stride(0) if x1 is not None else 0, - y1.stride(0) if y1 is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - HAS_X1=x1 is not None, - HAS_W1=weight1 is not None, - HAS_B1=bias1 is not None, - ) - return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], -) -# torch compile doesn't like 
triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - W1, - DY1, - DX1, - DW1, - DB1, - DRESIDUAL_IN, - ROWSCALE, - SEEDS, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dy1_row, - stride_dx1_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, - zero_centered_weight, - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_DY1: tl.constexpr, - HAS_DX1: tl.constexpr, - HAS_B1: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. 
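-    # Each program owns a contiguous chunk of `rows_per_program` rows: dx (and, when
-    # enabled, dresidual_in and dx1) is computed and stored per row inside the loop,
-    # while the weight/bias gradients (dw, db and optionally dw1, db1) are accumulated
-    # in registers and written once per program as a partial-sum row; the host-side
-    # wrapper reduces those partial sums over the program dimension afterwards.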
- row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - # Do not early exit if row_start >= M, because we need to write DW and DB - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if HAS_DY1: - DY1 += row_start * stride_dy1_row - if HAS_DX1: - DX1 += row_start * stride_dx1_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - if HAS_DY1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_DY1: - dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_B1: - db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if HAS_DY1: - dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_DY1: - wdy += w1 * dy1 - dw1 += dy1 * xhat - if HAS_B1: - db1 += dy1 - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - if HAS_DX1: - if HAS_DROPOUT: - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - else: - dx1 = dx - tl.store(DX1 + cols, dx1, mask=mask) - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - dx *= rowscale - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - if HAS_DY1: - DY1 += stride_dy1_row - if HAS_DX1: - DX1 += stride_dx1_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - if HAS_DY1: - tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) - if HAS_B1: - tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) - - -def _layer_norm_bwd( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, 
- weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, - # which makes torch.library unhappy - dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - dropout_p, - rowscale, - has_residual, - has_x1, - zero_centered_weight, - is_rms_norm, - x_dtype=x_dtype, - recompute_output=recompute_output, - ) - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - - -@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={}, - schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", - allow_decomposition=False, # Don't let torch.compile trace inside - ) -def _layer_norm_bwd_impl( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - dy = maybe_contiguous_lastdim(dy) - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - dresidual = maybe_contiguous_lastdim(dresidual) - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if dy1 is not None: - dy1 = maybe_contiguous_lastdim(dy1) - assert weight1 is not None - assert dy1.shape == dy.shape - assert dy1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if seeds is not None: - assert seeds.is_contiguous() - assert seeds.shape == (M if not has_x1 else M * 2,) - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = ( - torch.empty_like(x) - if has_residual - and (dx.dtype != x.dtype or dropout_p > 0.0 or 
rowscale is not None or has_x1) - else None - ) - dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None - y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None - if recompute_output: - assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the - # latency of the gmem reads/writes, but will increase the time of summing up dw / db. - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - _dw1 = torch.empty_like(_dw) if weight1 is not None else None - _db1 = torch.empty_like(_db) if bias1 is not None else None - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - weight1, - dy1, - dx1, - _dw1, - _db1, - dresidual_in, - rowscale, - seeds, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dy1.stride(0) if dy1 is not None else 0, - dx1.stride(0) if dx1 is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - dropout_p > 0.0, - HAS_ROWSCALE=rowscale is not None, - HAS_DY1=dy1 is not None, - HAS_DX1=dx1 is not None, - HAS_B1=bias1 is not None, - RECOMPUTE_OUTPUT=y is not None, - ) - dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None - db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - -class LayerNormFn(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) - weight = weight.contiguous() - bias = maybe_contiguous(bias) - weight1 = maybe_contiguous(weight1) 
- bias1 = maybe_contiguous(bias1) - if rowscale is not None: - rowscale = rowscale.reshape(-1).contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - if out is not None: - out = out.reshape(-1, out.shape[-1]) - if residual_out is not None: - residual_out = residual_out.reshape(-1, residual_out.shape[-1]) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - out_dtype=out_dtype, - residual_dtype=residual_dtype, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - out=out, - residual_out=residual_out, - ) - ctx.save_for_backward( - residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd - ) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.dropout_p = dropout_p - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.has_x1 = x1 is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.zero_centered_weight = zero_centered_weight - y = y.reshape(x_shape_og) - y1 = y1.reshape(x_shape_og) if y1 is not None else None - residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None - dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None - dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None - if not return_dropout_mask: - if weight1 is None: - return y if not prenorm else (y, residual_out) - else: - return (y, y1) if not prenorm else (y, y1, residual_out) - else: - if weight1 is None: - return ( - (y, dropout_mask, dropout_mask1) - if not prenorm - else (y, residual_out, dropout_mask, dropout_mask1) - ) - else: - return ( - (y, y1, dropout_mask, dropout_mask1) - if not prenorm - else (y, y1, residual_out, dropout_mask, dropout_mask1) - ) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if weight1 is not None: - dy1, args = args[0], args[1:] - dy1 = dy1.reshape(-1, dy1.shape[-1]) - assert dy1.shape == x.shape - else: - dy1 = None - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - ctx.dropout_p, - rowscale, - ctx.has_residual, - ctx.has_x1, - ctx.zero_centered_weight, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=False, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, - dw1, - db1, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - 
residual_in_fp32, - zero_centered_weight, - is_rms_norm, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - zero_centered_weight, - True, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -class RMSNorm(torch.nn.Module): - - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, - device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - if dropout_p > 0.0: - self.drop = torch.nn.Dropout(dropout_p) - else: - self.drop = None - self.zero_centered_weight = zero_centered_weight - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - if not self.zero_centered_weight: - torch.nn.init.ones_(self.weight) - else: - torch.nn.init.zeros_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - zero_centered_weight=self.zero_centered_weight, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - norm_weight = norm_weight.contiguous() - norm_bias = maybe_contiguous(norm_bias) - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to(dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, 
norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - dy = maybe_contiguous_lastdim(dy) - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1])) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual=dresidual, - has_residual=ctx.has_residual, - is_rms_norm=ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/linear.py deleted file mode 100644 index a8966dbc345ab0e593df0124451ee7be3dae131a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/linear.py +++ /dev/null @@ -1,594 +0,0 @@ -# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py -# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py -from typing import Optional - -import torch -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -from flash_attn.ops.triton.k_activations import ( - gelu, - gelu_approx, - gelu_approx_grad, - gelu_grad, - squared_relu, - squared_relu_grad, -) - -# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k not used - # for split_k in [2, 4, 8, 16]: - # configs.append(triton.Config( - # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, 
num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_fwd( - C, # Pointers to matrices - ACT_INPUT, - A, - B, - bias, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bn, - stride_bk, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - A_ROWMAJOR: tl.constexpr, - B_COLMAJOR: tl.constexpr, - BIAS: tl.constexpr, - SAVE_ACT_INPUT: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Bias has shape (N,) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - if A_ROWMAJOR: - A = A + (ram[:, None] * stride_am + rk[None, :]) - else: - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - if B_COLMAJOR: - B = B + (rk[:, None] + rbn[None, :] * stride_bn) - else: - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - if A_ROWMAJOR: - A += BLOCK_K - else: - A += BLOCK_K * stride_ak - if B_COLMAJOR: - B += BLOCK_K - else: - B += BLOCK_K * stride_bk - - # Putting bias after the matmul (instead of before) is faster, idk why - if BIAS: - bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) - acc += bias[None, :] - - # optional: save the activation inputs - if SAVE_ACT_INPUT: - # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - tl.store(act_in_ptrs, acc) - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION == "gelu": - acc = gelu(acc) - elif ACTIVATION == "gelu_approx": - acc = gelu_approx(acc) - elif ACTIVATION == "squared_relu": - acc = squared_relu(acc) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc) - - -def triton_linear_act( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - 
activation: str = "id", - save_act_input: bool = False, -) -> torch.Tensor: - """ - Compute e = activation(x @ weight.T + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param x: input tensor - :param weight: weight matrix - :param bias: an optional bias tensor - :param activation: Activation name. Needs to be a Triton kernel. - :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - # if torch.is_autocast_enabled(): - # dtype = torch.get_autocast_gpu_dtype() - # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] - - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - x_reshaped = x.reshape(batch_dim, n) - - if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: - x_reshaped = x_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - bias = bias.contiguous() if bias is not None else None - - assert ( - x.dtype == weight.dtype - ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" - if bias is not None: - assert ( - x.dtype == bias.dtype - ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" - assert ( - x_reshaped.shape[1] == weight.shape[1] - ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" - - assert ( - bias is None or bias.shape[0] == weight.shape[0] - ), "Incompatible dimensions in between weight and bias" - - M, K = x_reshaped.shape - N, K = weight.shape - - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - act_input = torch.empty_like(output) if save_act_input else None - - # 1D launch kernel where each block gets its own program. 
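-    # One program per (BLOCK_M, BLOCK_N) output tile; the kernel recovers pid_m / pid_n
-    # from the flat program id and groups programs along M (GROUP_M) for better L2 reuse.
-    # The M // 32, N // 32, K // 32 arguments below are only autotune cache keys
-    # (CACHE_KEY_M / CACHE_KEY_N / CACHE_KEY_K), so shapes that agree up to a multiple
-    # of 32 share the same tuned configuration and avoid extra compilations.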
- grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_fwd[grid]( - output, - act_input, - x_reshaped, - weight, # data ptrs - bias if bias is not None else x, # auto skip bias if not present - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=output.stride(0), # strides - # stride_cn=output.stride(1), - stride_am=x_reshaped.stride(0), - stride_ak=x_reshaped.stride(1), - stride_bk=weight.stride(1), - stride_bn=weight.stride(0), - BIAS=bias is not None, # optional fused bias - SAVE_ACT_INPUT=save_act_input, # optional save activation inputs - ACTIVATION=activation, # optional fused activation - A_ROWMAJOR=x_reshaped.stride(1) == 1, - B_COLMAJOR=weight.stride(1) == 1, - GROUP_M=8, # speed optimization: group the programs - ) - - if not save_act_input: - return output.reshape(*batch_shape, output.shape[-1]) - else: - return ( - output.reshape(*batch_shape, output.shape[-1]), - act_input.reshape(*batch_shape, act_input.shape[-1]), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_bwd( - C, # Pointers to matrices - 
ACT_INPUT, - A, - B, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bk, - stride_bn, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION != "id": - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - act_input = tl.load(act_in_ptrs).to(acc.dtype) - if ACTIVATION == "gelu": - acc *= gelu_grad(act_input) - elif ACTIVATION == "gelu_approx": - acc *= gelu_approx_grad(act_input) - elif ACTIVATION == "squared_relu": - acc *= squared_relu_grad(act_input) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) - - -def triton_dgrad_act( - grad_output: torch.Tensor, - weight: torch.Tensor, - activation: str = "id", - act_input: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Compute e = activation(grad_output @ weight + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param grad_output: input tensor - :param weight: weight matrix - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] - batch_dim = batch_shape.numel() - grad_output_reshaped = grad_output.reshape(batch_dim, n) - - if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: - grad_output_reshaped = grad_output_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - - assert ( - grad_output.dtype == weight.dtype - ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" - assert ( - grad_output_reshaped.shape[1] == weight.shape[0] - ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": - assert act_input is not None, f"act_input is required for activation {activation}" - - # M, N, K in bwd are different from M, N, K in fwd - M, K = grad_output_reshaped.shape - K, N = weight.shape - - grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_bwd[grid]( - grad_input, - act_input, - grad_output_reshaped, - weight, # data ptrs - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), - ACTIVATION=activation, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/mlp.py deleted file mode 100644 index 059f4f8a5e174c1f4824e43d313fca18eaa799b8..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/mlp.py +++ /dev/null @@ -1,149 +0,0 @@ -# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared -# to naive implementation. 
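-# For that reason FusedDenseSqreluDenseFunc below dispatches on dtype: bf16 inputs take
-# the fused_dense_cuda.linear_bias_forward + sqrelu_fwd/sqrelu_bwd path, while all other
-# dtypes go through the Triton triton_linear_act / triton_dgrad_act kernels with the
-# squared_relu activation fused in.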
-import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd -from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act - - -class FusedDenseSqreluDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): - """checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute act_input and gelu_out in the bwd - """ - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - x, weight1, bias1, weight2, bias2 = [ - a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] - ] - is_bf16 = x.dtype == torch.bfloat16 - assert checkpoint_lvl in [0, 1, 2] - x = x.contiguous() - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - save_act_input = checkpoint_lvl != 2 - result = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=save_act_input, - ) - if save_act_input: - output1, act_input = result - else: - output1 = result - output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl == 0: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, bias1, weight2) - return output2.reshape(*batch_shape, output2.shape[-1]) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - x, weight1, bias1, weight2, *rest = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - is_bf16 = x.dtype == torch.bfloat16 - if checkpoint_lvl == 0: - act_input, output1 = rest - elif checkpoint_lvl == 1: - (act_input,) = rest - output1 = sqrelu_fwd(act_input) - elif checkpoint_lvl == 2: - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - output1, act_input = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=True, - ) - - if is_bf16: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_output1 = grad_output @ weight2 - grad_act_input = sqrelu_bwd(grad_output1, act_input) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - ) - else: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_act_input = triton_dgrad_act( - grad_output, weight2, activation="squared_relu", act_input=act_input - ) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - 
) - return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None - - -fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply - - -class FusedDenseSqreluDense(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - checkpoint_lvl=0, - device=None, - dtype=None, - ): - """ - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd - """ - assert checkpoint_lvl in [0, 1, 2] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" - assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" - self.checkpoint_lvl = checkpoint_lvl - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - assert x.is_cuda - return fused_dense_sqrelu_dense_function( - x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/rotary.py deleted file mode 100644 index ff4017fda3e4a6e18cf3a51b34f0fb073d8f678a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/flash_attn/ops/triton/rotary.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) 2025, Tri Dao. -# As of 2025-04-23, we require triton >= 3.0 - -from typing import Optional, Union - -import torch - -import triton -import triton.language as tl - - -@triton.jit -def rotary_kernel( - OUT, # Pointers to matrices - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer - # Matrix dimensions - seqlen, - nheads, - seqlen_ro, - # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - # Meta-parameters - # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that - # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 - ROTARY_DIM: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, -): - BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) - ROTARY_DIM_HALF = ROTARY_DIM // 2 - pid_head = tl.program_id(axis=0) - pid_m = tl.program_id(axis=1) - pid_batch = tl.program_id(axis=2) - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch - OUT = OUT + pid_batch * stride_out_batch - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen - OUT = OUT + start_idx * stride_out_seqlen - - if pid_m * BLOCK_M >= seqlen: - return - - rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - - rk_half = tl.arange(0, BLOCK_K // 2) - COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + 
rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) - mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) - cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) - sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) - if CONJUGATE: - sin = -sin - - if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) - x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - tl.store(OUT, o0, mask=mask) - tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) - else: - rk = tl.arange(0, BLOCK_K) - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) - x = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) - tl.store(OUT, o, mask=mask) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). 
- cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in [torch.int32, torch.int64] - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa - BLOCK_M = 8 if rotary_dim <= 128 else 4 - - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(rotary_kernel)[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - seqlen_ro, - output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - rotary_dim, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M=BLOCK_M, - BLOCK_H=2, - ) - return output diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py deleted file mode 100644 index ecc2f9d896b6c93f90b0a1499856dc0612177422..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py +++ /dev/null @@ -1,393 +0,0 @@ -from typing import Optional, List -import torch -from ._ops import ops as flash_attn_ops -from .flash_attn_interface import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, -) - - -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention. 
- - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Optional output tensor, same shape as q - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd( - q, - k, - v, - out, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def varlen_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - out: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with variable sequence lengths. 
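A minimal sketch of calling the fwd wrapper documented above. The sizes, dtype, and the assumption that the wrapper is importable from the flash_attn package are illustrative, not taken from this repository.

import torch
# from flash_attn import fwd   # assumed import path for the wrapper above

batch, seqlen, nheads, headdim = 2, 128, 8, 64     # hypothetical sizes
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# softmax_scale is filled in as 1/sqrt(headdim) by the wrapper when left as None.
results = fwd(q, k, v, is_causal=True)
out, softmax_lse = results[0], results[1]          # attention output and per-row log-sum-exp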
- - Args: - q: Query tensor of shape [total_q, num_heads, head_size] - k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - out: Optional output tensor of shape [total_q, num_heads, head_size] - seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size] - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_fwd( - q, - k, - v, - out, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention. 
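The corresponding variable-length call packs all sequences into one tensor and describes their boundaries with cu_seqlens. The sketch below is hypothetical (two sequences of length 3 and 5); the cumulative-sum construction mirrors the one used by unpad_input later in this diff.

import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")              # hypothetical lengths
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # [0, 3, 8]
total, nheads, headdim = int(cu_seqlens[-1]), 4, 64

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

results = varlen_fwd(
    q, k, v, cu_seqlens, cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    is_causal=True,
)
out = results[0]   # (total, nheads, headdim)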
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def varlen_bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention with variable sequence lengths. 
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def fwd_kvcache( - q: torch.Tensor, - kcache: torch.Tensor, - vcache: torch.Tensor, - k: Optional[torch.Tensor] = None, - v: Optional[torch.Tensor] = None, - seqlens_k: Optional[torch.Tensor] = None, - rotary_cos: Optional[torch.Tensor] = None, - rotary_sin: Optional[torch.Tensor] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - is_rotary_interleaved: bool = False, - num_splits: int = 1, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with KV cache. 
- - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - seqlens_k: Optional sequence lengths for keys of shape [batch_size] - rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2] - rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2] - cache_batch_idx: Optional indices to index into the KV cache - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - out: Optional output tensor, same shape as q - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - is_rotary_interleaved: Whether rotary embeddings are interleaved - num_splits: Number of splits for computation - - Returns: - List of tensors: [output, softmax_lse] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd_kvcache( - q, - kcache, - vcache, - k, - v, - seqlens_k, - rotary_cos, - rotary_sin, - cache_batch_idx, - leftpad_k, - block_table, - alibi_slopes, - out, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - is_rotary_interleaved, - num_splits, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index b378d0dea5074c177d5126f7e936f05616bbc685..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0764a49f4b0f342e5b4f1e42a0ac7ef9035fcd720e34fdb3371c06de7e52c275 -size 447253888 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py deleted file mode 100644 index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
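For the KV-cache path, a single decode step might look like the hypothetical sketch below: one new query token per sequence, with the new key/value written into the preallocated cache at the offsets given by seqlens_k. Cache size, batch size, and dtype are assumptions.

import torch

batch, nheads, headdim, max_cache_len = 2, 8, 64, 256                # hypothetical sizes
kcache = torch.zeros(batch, max_cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
vcache = torch.zeros_like(kcache)
seqlens_k = torch.tensor([17, 5], dtype=torch.int32, device="cuda")  # tokens already in the cache

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q)    # new keys/values for this decode step
v_new = torch.randn_like(q)

results = fwd_kvcache(q, kcache, vcache, k=k_new, v=v_new, seqlens_k=seqlens_k, is_causal=True)
out = results[0]               # (batch, 1, nheads, headdim)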
- """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/bert_padding.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/bert_padding.py deleted file mode 100644 index 3c2d35159a014a9d03aabead9e52e009168696ea..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/bert_padding.py +++ /dev/null @@ -1,218 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - - -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, - dtype=grad_output.dtype, - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - output = input[indices] - # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last - # memory format to channel_first. In other words, input might not be contiguous. 
- # If we don't detach, Pytorch complains about output being a view and is being modified inplace - return output, input.detach() - - @staticmethod - def backward(ctx, grad_output, grad_residual): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - assert grad_residual.shape[1:] == other_shape - grad_input = grad_residual - # grad_input[indices] += grad_output - indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) - indices = indices.expand_as(grad_output) - grad_input.scatter_add_(0, indices, grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis_residual = IndexFirstAxisResidual.apply - - -def unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): - """ - Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). - The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
- - For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: - ``` - [ - [2, 3, 0, 0, 0, 0], - [3, 2, 0, 0, 0, 0], - [6, 0, 0, 0, 0, 0] - ] - ``` - , which refers to the 3D-attention mask: - ``` - [ - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1] - ] - ] - ```. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - """ - length = attention_mask_in_length.sum(dim=-1) - seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) - real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() - seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] - indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... 
-> b s ...", b=batch) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/flash_attn_interface.py deleted file mode 100644 index 690d644f0a1c3d6ccfd26acbf8b22376a47cfff0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/flash_attn_interface.py +++ /dev/null @@ -1,1609 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Sequence, Tuple, Union - -import torch -import torch.nn as nn -import os - -# # isort: off -# # We need to import the CUDA kernels after importing torch -# USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -# if USE_TRITON_ROCM: -# from .flash_attn_triton_amd import interface_fa as flash_attn_gpu -# else: -# import flash_attn_2_cuda as flash_attn_gpu - - -from ._ops import ops as flash_attn_gpu - -# # isort: on - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _get_block_size_n(device, head_dim, is_dropout, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128 - if head_dim <= 64: - return 128 if not is_dropout else 64 - elif head_dim <= 96: - return 64 - elif head_dim <= 128: - if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 - else: - return 64 if not is_dropout else 32 - elif head_dim <= 192: - return 64 - elif head_dim <= 224: - return 64 - elif head_dim <= 256: - return 64 - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -# torch.compile() support is only enabled for pytorch >= 2.4 -# The reason for this is that we are using the new custom_op and register_fake -# APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": - _torch_custom_op_wrapper = torch.library.custom_op - _torch_register_fake_wrapper = torch.library.register_fake -else: - def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - if fn is None: - return wrap - return fn - def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): - def wrap(func): - return func - if fn is None: - return wrap - return fn - _torch_custom_op_wrapper = noop_custom_op_wrapper - _torch_register_fake_wrapper = noop_register_fake_wrapper - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q, - k, - v, - None, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward") -def _flash_attn_forward_fake( - q: torch.Tensor, - k: torch.Tensor, 
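The padding helpers in bert_padding.py above are round-trip companions. The sketch below, with a hypothetical 2x4 batch, shows unpad_input producing the packed tokens plus the indices and cu_seqlens needed to restore the padded layout with pad_input.

import torch
# from flash_attn.bert_padding import unpad_input, pad_input   # assumed import path

hidden_states = torch.randn(2, 4, 8)                # (batch, seqlen, hidden), hypothetical
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])       # valid lengths 3 and 2

h_unpad, indices, cu_seqlens, max_seqlen_in_batch, seqused = unpad_input(hidden_states, attention_mask)
# h_unpad: (5, 8), only the tokens where attention_mask == 1
# cu_seqlens: tensor([0, 3, 5], dtype=torch.int32); max_seqlen_in_batch: 3

# pad_input scatters the packed tokens back into the padded (batch, seqlen, ...) layout.
h_pad = pad_input(h_unpad, indices, 2, 4)
assert h_pad.shape == hidden_states.shape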
- v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - batch_size, seqlen_q, num_heads, head_size = q.shape - seqlen_k = k.shape[1] - out = torch.empty_like(q) - softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward -else: - _wrapped_flash_attn_forward = _flash_attn_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_attn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q, - k, - v, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward") -def _flash_attn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = 
round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward -else: - _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") -def _flash_attn_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward -else: - _wrapped_flash_attn_backward = _flash_attn_backward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, 
v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): - # breakpoint() - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward") -def _flash_attn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward -else: - _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(3) - dout_padded = dout 
- if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = 
torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, 0].detach(), kv[:, 1].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = 
ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = 
q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=block_table, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. 
- deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnQKVPackedFunc.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - kv: (batch_size, seqlen, 2, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. 
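A short, hypothetical sketch of flash_attn_qkvpacked_func as documented above, including a backward pass through the single packed tensor; shapes and dtype are assumptions.

import torch
# from flash_attn import flash_attn_qkvpacked_func   # assumed import path

batch, seqlen, nheads, headdim = 2, 128, 8, 64       # hypothetical sizes
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)   # (batch, seqlen, nheads, headdim)
out.sum().backward()   # dq, dk, dv come back as one gradient on the packed qkv tensor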
The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnKVPackedFunc.apply( - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). 
It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
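The cu_seqlens bookkeeping for the varlen interface can be built from per-sequence lengths; a minimal sketch (hypothetical lengths, not part of the original file):

import torch

# Hypothetical batch of three unpadded sequences with lengths 5, 3 and 7.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(max_seqlen)  # 7
# qkv for the varlen call would then have shape (total, 3, nheads, headdim) with total = 15.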
- """ - return FlashAttnVarlenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenKVPackedFunc.apply( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. 
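The sliding-window rule quoted above can be made concrete with a small helper (an illustration under the stated formula, not part of the original file):

def attended_key_range(i, seqlen_q, seqlen_k, window_size):
    # Keys attended by query position i, clamped to [0, seqlen_k - 1]; -1 means unbounded.
    off = i + seqlen_k - seqlen_q
    lo = max(0, off - window_size[0]) if window_size[0] >= 0 else 0
    hi = min(seqlen_k - 1, off + window_size[1]) if window_size[1] >= 0 else seqlen_k - 1
    return lo, hi

print(attended_key_range(i=10, seqlen_q=16, seqlen_k=16, window_size=(4, 0)))  # (6, 10)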
This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) - - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, - alibi_slopes=None, - num_splits=0, - return_softmax_lse=False, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. 
Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. - - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). 
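A single decoding step with the KV cache interface documented above might look like this (an illustration, not part of the original file; assumes a CUDA device and fp16 tensors):

import torch
from flash_attn import flash_attn_with_kvcache  # exported by this package's __init__

batch, nheads, nheads_k, headdim, max_seqlen = 2, 8, 8, 64, 1024
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 10, dtype=torch.int32, device="cuda")  # 10 tokens already cached

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)

# k_new / v_new are written into the caches at position cache_seqlens, then attention is
# computed against the updated caches, all in one kernel.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # torch.Size([2, 1, 8, 64])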
- """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - cache_batch_idx = maybe_contiguous(cache_batch_idx) - block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_gpu.fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - None, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - rotary_interleaved, - num_splits, - ) - return (out, softmax_lse) if return_softmax_lse else out diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/patch_embed.py deleted file mode 100644 index 05562f8e8bcdb58e947c6f402a49eacd2d031871..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/patch_embed.py +++ /dev/null @@ -1,67 +0,0 @@ -# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py -# But we use nn.Linear instead of Conv2d and it's about 8x faster. 
- -from functools import partial - -import torch.nn as nn -from einops import rearrange -from torch import _assert -from torch.nn.modules.utils import _pair - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, - fused_bias_fc=False, - ): - super().__init__() - img_size = _pair(img_size) - patch_size = _pair(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - - linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense - self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - _, _, H, W = x.shape - _assert( - H == self.img_size[0], - f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", - ) - _assert( - W == self.img_size[1], - f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", - ) - x = self.proj( - rearrange( - x, - "b c (h p1) (w p2) -> b h w (c p1 p2)", - p1=self.patch_size[0], - p2=self.patch_size[1], - ) - ) - if self.flatten: - x = rearrange(x, "b h w c -> b (h w) c") - x = self.norm(x) - return x diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/rotary.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/rotary.py deleted file mode 100644 index d1bfc21fc7de1dd287e8f382847b194a48075981..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/layers/rotary.py +++ /dev/null @@ -1,483 +0,0 @@ -# Copyright (c) 2025, Tri Dao - -import math -from functools import partial -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor - -from einops import rearrange, repeat -# from flash_attn.ops.triton.rotary import apply_rotary -from ..ops.triton.rotary import apply_rotary - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, - ): - out = apply_rotary( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - dx = apply_rotary( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, -): - """ - Arguments: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
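A usage sketch for apply_rotary_emb with an explicitly built cos/sin table (an illustration, not part of the original file; the underlying Triton kernel assumes a CUDA device):

import torch
from flash_attn.layers.rotary import apply_rotary_emb  # module path per the layout above

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 8, 64, 32
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

t = torch.arange(seqlen, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda", dtype=torch.float32) / rotary_dim))
freqs = torch.outer(t, inv_freq)  # (seqlen, rotary_dim / 2)
cos, sin = freqs.cos().half(), freqs.sin().half()

# Rotates only the first rotary_dim channels of each head; the rest pass through unchanged.
out = apply_rotary_emb(x, cos, sin, interleaved=False)
print(out.shape)  # torch.Size([2, 128, 8, 64])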
- """ - return ApplyRotaryEmb.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -# For backward compatibility -apply_rotary_emb_func = apply_rotary_emb - - -def _apply_rotary_emb_qkv( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - inplace=False, - conjugate=False, - seqlen_offsets: Union[int, Tensor] = 0, - num_heads_q: Optional[int] = None, -): - apply_rotary_fn = partial( - apply_rotary, - interleaved=interleaved, - inplace=inplace, - conjugate=conjugate, - seqlen_offsets=seqlen_offsets - ) - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - qk = apply_rotary_fn(qk, cos, sin) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - qk = qkv[:, :, :num_heads_q + num_heads_k] - qk = apply_rotary_fn(qk, cos, sin) - if not inplace: - if qkv.dim() == 5: - qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) - else: - qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - q, k = qkv[:, :, 0], qkv[:, :, 1] - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] - q = apply_rotary_fn(q, cos, sin) - k = apply_rotary_fn(k, cos_k, sin_k) - if not inplace: - if qkv.dim() == 5: - qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) - else: - qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) - return qkv - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, - ): - # apply_rotary_emb_qkv_inplace( - qkv = _apply_rotary_emb_qkv( - qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.num_heads_q = num_heads_q - return qkv - - @staticmethod - def backward(ctx, dqkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cos_k, sin_k = ctx.saved_tensors - dqkv = _apply_rotary_emb_qkv( - dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, - ) - return dqkv, None, None, None, None, None, None, None - - -def apply_rotary_emb_qkv_( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - 
seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, -): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim). - If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of Q and K. - """ - return ApplyRotaryEmbQKV_.apply( - qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q - ) - - -class ApplyRotaryEmbKV_(torch.autograd.Function): - - @staticmethod - def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): - batch, seqlen, two, nheads, headdim = kv.shape - assert two == 2 - k = kv[:, :, 0] - apply_rotary( - k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - return kv - - @staticmethod - def backward(ctx, dkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, seqlen_offsets = ctx.saved_tensors - else: - cos, sin = ctx.saved_tensors - apply_rotary( - dkv[:, :, 0], - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - return dkv, None, None, None, None - - -apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply - - -def apply_rotary_emb_kv_( - kv, - cos, - sin, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -): - """ - Arguments: - kv: (batch_size, seqlen, 2, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - kv: (batch_size, seqlen, 2, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of K. - """ - return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
- A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 - Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py - """ - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - """ - super().__init__() - self.dim = dim - self.base = float(base) - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
- # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to bf16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - num_heads_q: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) - if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim). - If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. 
- """ - seqlen = qkv.shape[1] - if max_seqlen is not None: - self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - if kv is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached if self.scale is not None else None, - self._sin_k_cached if self.scale is not None else None, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) - else: - q = qkv - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - inplace=True, - seqlen_offsets=seqlen_offset, - ) - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached if self.scale is None else self._cos_k_cached, - self._sin_cached if self.scale is None else self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - return q, kv diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/activations.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/activations.py deleted file mode 100644 index 7c09649fc41e12d5a360c5672825d8380bc7ec80..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/activations.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def bias_gelu(y, bias): - x = bias + y - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, y, bias): - """Assume that y has shape (B, D) and bias has shape (D)""" - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - grad_y = ff * g - return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(input, bias) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, input, bias) - return tmp, tmp - - -bias_gelu_impl = GeLUFunction.apply - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def gelu_fwd(x): - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def gelu_bwd(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - return (ff * g).to(dtype=x.dtype) - - -class FastGeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input): - ctx.save_for_backward(input) - return gelu_fwd(input) - - @staticmethod - def backward(ctx, grad_output): - (input,) = ctx.saved_tensors - tmp = gelu_bwd(grad_output, input) - return tmp - - -fast_gelu_impl = FastGeLUFunction.apply - - -@torch.jit.script -def relu_bwd(g, x): - return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_fwd(x): - r = F.relu(x) - return (r * r).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_bwd(g, x): - return (2.0 * g * F.relu(x)).to(dtype=x.dtype) - - -swiglu_fwd_codestring = """ -template T swiglu_fwd(T x, T y) { - return float(x) * float(y) / (1.0f + ::exp(-float(x))); -} -""" -swiglu_bwd_codestring = """ -template void swiglu_bwd(T x, T y, T g, T& dx, T& dy) { - float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); - dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); - dy = float(x) * x_sigmoid * float(g); -} -""" -swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) -swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) - - -class SwiGLUFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - return swiglu_fwd(x, y) - - @staticmethod - def backward(ctx, dout): - x, y = ctx.saved_tensors - return swiglu_bwd(x, y, dout) - -swiglu = SwiGLUFunction.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/fused_dense.py deleted file mode 100644 index 6b4033d134e4093fe278f7b3f8c7d3128ce9f36d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/fused_dense.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
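The tanh-based GELU used in the activations module above agrees with PyTorch's approximate GELU; a quick numerical check (an illustration, not part of the original files):

import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float32)
tanh_gelu = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print(torch.allclose(tanh_gelu, F.gelu(x, approximate="tanh"), atol=1e-5))  # True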
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py -# We make it work with pytorch amp and with bfloat16. -# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py -from functools import partial -from typing import Optional - -# import fused_dense_cuda # from apex -import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd -from flash_attn.utils.distributed import ( - all_gather_raw, - all_reduce, - all_reduce_raw, - reduce_scatter, - reduce_scatter_raw, -) - - -class FusedDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if 
process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel: - handle_x.wait() - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply( - x, weight, bias, return_residual, process_group, sequence_parallel - ) - else: - assert process_group is None - out = F.linear(x, weight, bias) - return out if not return_residual else (out, x) - - -class FusedDense(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - return_residual: bool = False, - device=None, - dtype=None, - ) -> None: - super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - self.return_residual = return_residual - - def forward(self, x, process_group=None): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - """ - return fused_dense_func( - x, - self.weight, - self.bias, - return_residual=self.return_residual, - process_group=process_group, - ) - - -class ColumnParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__( - in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
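The sharding arithmetic in ColumnParallelLinear/RowParallelLinear above splits the feature dimension in units of multiple_of, giving the first (multiple % world_size) ranks one extra unit; a standalone sketch of that arithmetic (an illustration, not part of the original file):

def local_out_features(out_features, multiple_of, world_size, rank):
    # Same arithmetic as above: split `out_features // multiple_of` units across ranks.
    multiple = out_features // multiple_of
    div, mod = divmod(multiple, world_size)
    return (div + (1 if rank < mod else 0)) * multiple_of

print([local_out_features(4864, 256, 8, r) for r in range(8)])
# [768, 768, 768, 512, 512, 512, 512, 512]  (sums to 4864)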
- return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - - -class RowParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = fused_dense_func(x, self.weight, self.bias) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) - - -class FusedMLPFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight1, - bias1, - weight2, - bias2, - activation="gelu_approx", - save_pre_act=True, - return_residual=False, - checkpoint_lvl=0, - heuristic=0, - process_group=None, - sequence_parallel=True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather of x before doing the matmul. - If sequence_parallel=False, then the input is already gathered. 
- - checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out / relu_out in the bwd - 2: recompute pre_act and gelu_out / relu_out in the bwd - """ - assert -1 <= heuristic <= 4 - assert activation in ["gelu_approx", "relu", "sqrelu"] - if activation == "sqrelu": - assert heuristic == -1 - if not save_pre_act: - checkpoint_lvl = 2 - assert checkpoint_lvl in [0, 1, 2] - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.checkpoint_lvl = checkpoint_lvl - ctx.activation = activation - ctx.heuristic = heuristic - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] - bias1 = bias1.to(dtype=dtype) if bias1 is not None else None - bias2 = bias2.to(dtype=dtype) if bias2 is not None else None - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() if bias1 is not None else None - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() if bias2 is not None else None - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - if heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - # This is before adding bias1 - # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) - # with torch.jit.fuser('fuser2'): - # output1 = bias_gelu(pre_act, bias1) - else: - is_gelu = activation == "gelu_approx" - output1, *rest = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic - ) - if save_pre_act: - pre_act = rest[0] - output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - # For RELU the pre_act is very small (just a bit-mask) so we just save it - ctx.save_for_backward(x, weight1, weight2, pre_act, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, weight2, pre_act) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, weight2, bias1) - output2 = output2.reshape(*batch_shape, output2.shape[-1]) - return output2 if not return_residual else (output2, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - activation = ctx.activation - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - x, 
weight1, weight2, *rest = ctx.saved_tensors - if process_group is None or not sequence_parallel: - total_x = x - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - if checkpoint_lvl in [0, 1]: - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - pre_act, output1 = rest - elif checkpoint_lvl == 1: - (pre_act,) = rest - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - elif checkpoint_lvl == 2: - (bias1,) = rest - if process_group is not None and sequence_parallel: - total_x, _ = all_gather_raw(x, process_group) - if ctx.heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - else: - output1, pre_act = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), - weight1, - bias1, - activation == "gelu_approx", - True, - ctx.heuristic, - ) - - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - output1 = output1.reshape(batch_dim, output1.shape[-1]) - pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) - if ctx.needs_input_grad[3]: - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( - output1, grad_output, ctx.needs_input_grad[4] - ) - else: - grad_weight2 = None - grad_bias2 = grad_output if ctx.needs_input_grad[4] else None - if ctx.heuristic == -1: - # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) - grad_output1 = F.linear(grad_output, weight2.t()) - activation_grad_fn = ( - gelu_bwd - if activation == "gelu_approx" - else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) - ) - with torch.jit.fuser("fuser2"): - grad_pre_act = activation_grad_fn(grad_output1, pre_act) - else: - # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't - # just compute gelu/relu grad - grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( - weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic - ) - if not ctx.needs_input_grad[2]: - grad_bias1 = None - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_pre_act, weight1.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.heuristic == -1: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_pre_act, - ctx.needs_input_grad[2], - ) - else: - grad_weight1 = None - grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None - else: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1 = F.linear( - grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() - ) - else: - grad_weight1 = None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return ( - grad_input, - grad_weight1, - grad_bias1, - grad_weight2, - grad_bias2, - None, - None, - 
None, - None, - None, - None, - None, - ) - - -def fused_mlp_func( - x: Tensor, - weight1: Tensor, - weight2: Tensor, - bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, - activation: str = "gelu_approx", - save_pre_act: bool = True, - return_residual: bool = False, - checkpoint_lvl: int = 0, - heuristic: int = 0, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - assert activation in ["gelu_approx", "relu", "sqrelu"] - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) - dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) - if ( - x.is_cuda - and weight1.is_cuda - and weight2.is_cuda - and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) - and dtype_eligible - and dim_eligible - ): - return FusedMLPFunc.apply( - x, - weight1, - bias1, - weight2, - bias2, - activation, - save_pre_act, - return_residual, - checkpoint_lvl, - heuristic, - process_group, - sequence_parallel, - ) - else: - assert process_group is None - pre_act = F.linear(x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else partial(F.relu, inplace=True) - ) - output1 = activation_fn(pre_act) - output2 = F.linear(output1, weight2, bias2) - return output2 if not return_residual else (output2, x) - - -class FusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - activation="gelu_approx", - return_residual=False, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. - For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation - is slower than the unfused version. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. 
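For reference, the fused path above computes the same thing as a plain two-layer MLP with a tanh-approximate GELU; a minimal unfused sketch (editor's illustration, single GPU, no tensor parallelism, mirroring the fallback branch of fused_mlp_func):

import torch.nn.functional as F

def mlp_reference(x, weight1, weight2, bias1=None, bias2=None):
    # fc1 -> tanh-approximate GELU -> fc2; numerically what the fused kernels aim to match
    pre_act = F.linear(x, weight1, bias1)
    hidden = F.gelu(pre_act, approximate="tanh")
    return F.linear(hidden, weight2, bias2)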
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.return_residual = return_residual - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x, process_group=None): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - if torch.cuda.get_device_capability("cuda") == (9, 0): - heuristic = -1 - else: - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=process_group, - ) - if self.return_residual: - out, x = out - if process_group is not None: - out = reduce_scatter(out, process_group) - return out if not self.return_residual else (out, x) - - -class ParallelFusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation="gelu_approx", - process_group: ProcessGroup = None, - bias1=True, - bias2=True, - sequence_parallel=True, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - process_group is required. We're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - assert process_group is not None - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = ColumnParallelLinear( - in_features, hidden_features, process_group, bias=bias1, **factory_kwargs - ) - self.fc2 = RowParallelLinear( - hidden_features, out_features, process_group, bias=bias2, **factory_kwargs - ) - - def forward(self, x): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/layer_norm.py deleted file mode 100644 index 4b6cd798fd02844ef9cd3897f8ab95e490e638bf..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/layer_norm.py +++ /dev/null @@ -1,800 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import dropout_layer_norm -import torch -from torch.nn import init - - -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 - return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). 
- """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) - else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) - - -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
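For reference, dropout_add_layer_norm above fuses dropout, the residual add, and LayerNorm into one kernel; an unfused sketch of the same computation (editor's illustration, assuming rowscale=None, layerscale=None, and ignoring the residual_in_fp32 dtype handling):

import torch.nn.functional as F

def dropout_add_layer_norm_reference(x0, residual, weight, bias, dropout_p, eps, prenorm=False):
    # dropout(x0) (+ residual), then LayerNorm over the last dimension
    x = F.dropout(x0, p=dropout_p)
    if residual is not None:
        x = x + residual
    out = F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, eps=eps)
    # prenorm additionally returns the pre-normalization sum for the next residual branch
    return out if not prenorm else (out, x)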
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x0, residual=None): - return dropout_add_layer_norm( - x0, - residual, - self.weight, - self.bias, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/rms_norm.py deleted file mode 100644 index 068348d61290e3839dd082b540d898578ba1e8e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/rms_norm.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import torch -from torch.nn import init - -from flash_attn.ops.layer_norm import ( - DropoutAddLayerNormFn, - DropoutAddLayerNormParallelResidualFn, - DropoutAddLayerNormSubsetFn, -) - - -def rms_norm(x, weight, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, None, None, None, 0.0, epsilon, False, False, True - ) - - -def dropout_add_rms_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) - - -class DropoutAddRMSNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x0, residual=None): - return dropout_add_rms_norm( - x0, - residual, - self.weight, - None, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/cross_entropy.py deleted file mode 100644 index 1b5a415b73f236f3e05fb14b9141959559e18526..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/cross_entropy.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Tuple, Optional, Union - -import torch -import torch.nn.functional as F - -import triton -import triton.language as tl - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_fwd_kernel( - loss_ptr, # data ptrs - lse_ptr, - z_loss_ptr, - logits_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, - # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE - SPLIT: tl.constexpr, - PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) -): - row_idx = tl.program_id(0) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - sum_logits = 0.0 # For smoothing - if not PRECOMPUTED_LSE: - # Statistics for online softmax - m_i = -float("inf") - l_i = 0.0 - for col_offset in range(0, n_cols, BLOCK_SIZE): - cols = col_offset + tl.arange(0, BLOCK_SIZE) - logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - if HAS_SMOOTHING: - sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) - m_i_new = tl.maximum(m_i, tl.max(logits)) - l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) - m_i = m_i_new - lse = tl.log(l_i) + m_i - tl.store(lse_ptr + row_idx, lse) - else: - lse = tl.load(lse_ptr + row_idx) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx == ignore_index: - loss = 0.0 - z_loss = 0.0 - else: - label_idx -= class_start_idx - if label_idx >= 0 and label_idx < n_cols: - logits_label = tl.load(logits_ptr + label_idx) * logit_scale - if HAS_SMOOTHING: - loss = ( - (lse if not SPLIT else 0.0) - - smoothing * sum_logits / total_classes - - (1 - smoothing) * logits_label - ) - else: - loss = (lse if not SPLIT else 0.0) - logits_label - else: - # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss - if HAS_SMOOTHING: - loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) - else: - loss = 0.0 - if not SPLIT: - z_loss = lse_square_scale * lse * lse - loss += z_loss - else: - z_loss = 0.0 - tl.store(loss_ptr + row_idx, loss) - if not SPLIT: - tl.store(z_loss_ptr + row_idx, z_loss) - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_bwd_kernel( - dlogits_ptr, # data ptrs - dloss_ptr, - logits_ptr, - lse_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - dlogits_row_stride, - dloss_row_stride, - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, -): - row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx != ignore_index: - dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) - else: - dloss = 0.0 - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - lse = tl.load(lse_ptr + row_idx) - probs = tl.exp(logits - lse) - probs += 2.0 * lse_square_scale * lse * probs - label_idx -= class_start_idx - if HAS_SMOOTHING: - smooth_positive = 1.0 - smoothing - smooth_negative = smoothing / total_classes - probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative - else: - probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) - tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) - - -class CrossEntropyLoss(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - logits, - 
labels, - precomputed_lse=None, - smoothing=0.0, - logit_scale=1.0, - lse_square_scale=0.0, - ignore_index=-100, - inplace_backward=False, - process_group=None, - ): - # For some reason Triton generates wrong code when labels has dtype long and its address - # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. - if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: - labels = F.pad(labels, (0, 1))[..., :-1] - assert labels.data_ptr() % 16 == 0 - assert logit_scale > 0.0 - n_rows, n_cols = logits.shape - assert labels.shape == (n_rows,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - total_classes = world_size * n_cols - rank = 0 if process_group is None else torch.distributed.get_rank(process_group) - class_start_idx = rank * n_cols - use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 - - if logits.stride(-1) != 1: - logits = logits.contiguous() - MAX_BLOCK_SIZE = 16 * 1024 - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) - num_warps = ( - 4 - if BLOCK_SIZE < 2048 - else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) - ) - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if use_precomputed_lse: - assert precomputed_lse.shape == (n_rows,) - lse = precomputed_lse.contiguous() - else: - lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) - z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_fwd_kernel[(n_rows,)]( - losses, # data ptrs - lse, - z_losses, - logits, - labels, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - BLOCK_SIZE=BLOCK_SIZE, # constants - SPLIT=world_size > 1, - PRECOMPUTED_LSE=use_precomputed_lse, - num_warps=num_warps, - ) - - if world_size > 1: - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # -0.9 * predicted logit - 0.1 * sum logit / total_classes. - # For labels not in the vocab of this partition, losses contains - # -0.1 * sum logit / total_classes. - if world_size > 1: - lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - handle_losses.wait() - # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, - # we just have to add the (global) lse. - # If there's smoothing=0.1, the total losses are - # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. - # Again, we just have to add the (global) lse. 
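# Illustration (editor's sketch, not part of the original file): with two ranks and no
# smoothing, if the label's logit z_label lives on rank 0 and the per-rank LSEs are
# lse0 and lse1, then after the fwd kernel with SPLIT=True rank 0 holds -z_label and
# rank 1 holds 0.0. The all_reduce sums these to -z_label on every rank, the gathered
# global lse = logsumexp([lse0, lse1]) is added just below, and the result is the usual
# cross-entropy loss = lse - z_label.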
- losses += lse - if lse_square_scale != 0.0: - z_losses = lse_square_scale * lse.square() - z_losses.masked_fill_(labels == ignore_index, 0.0) - losses += z_losses - else: - z_losses = torch.zeros_like(losses) - losses.masked_fill_(labels == ignore_index, 0.0) - - ctx.save_for_backward(logits, lse, labels) - ctx.mark_non_differentiable(z_losses) - ctx.smoothing = smoothing - ctx.logit_scale = logit_scale - ctx.lse_square_scale = lse_square_scale - ctx.ignore_index = ignore_index - ctx.total_classes = total_classes - ctx.class_start_idx = class_start_idx - ctx.inplace_backward = inplace_backward - return losses, z_losses - - @staticmethod - def backward(ctx, grad_losses, grad_z_losses): - del grad_z_losses # z_losses are only for logging. - - logits, lse, labels = ctx.saved_tensors - dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) - n_rows, n_cols = logits.shape - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) - num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) - grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_bwd_kernel[grid]( - dlogits, # data ptrs - grad_losses, - logits, - lse, - labels, - ctx.smoothing, - ctx.logit_scale, - ctx.lse_square_scale, - ctx.ignore_index, - ctx.total_classes, - ctx.class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - dlogits.stride(0), - grad_losses.stride(0), - BLOCK_SIZE=BLOCK_SIZE, # constants - num_warps=num_warps, - ) - return dlogits, None, None, None, None, None, None, None, None, None - - -def cross_entropy_loss( - logits: torch.Tensor, - labels: torch.Tensor, - precomputed_lse: Optional[torch.Tensor] = None, - label_smoothing: float = 0.0, - logit_scale: float = 1.0, - lse_square_scale: float = 0.0, - ignore_index=-100, - inplace_backward: bool = False, - process_group=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - logits: (batch, vocab_size) - labels: (batch,) - label_smoothing: float - logit_scale: float. Multiply logits by this scale before calculating the loss. - lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. - This is also referred to as "z-loss". - ignore_index: int. If labels == ignore_index, the loss is set to 0.0. - inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. - This saves memory. - process_group: if not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss will be aggregated across processes. - Returns: - losses: (batch,), float - z_losses: (batch,), float - """ - return CrossEntropyLoss.apply( - logits, - labels, - precomputed_lse, - label_smoothing, - logit_scale, - lse_square_scale, - ignore_index, - inplace_backward, - process_group, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/k_activations.py deleted file mode 100644 index efb83c358eb4a85d069ee340a3c83f418f9a805b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/k_activations.py +++ /dev/null @@ -1,162 +0,0 @@ -# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py -# Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from enum import Enum -from typing import Optional - -import triton -import triton.language as tl - -_sqrt2pi = math.sqrt(2.0 / math.pi) -_sqrt1_2 = math.sqrt(1.0 / 2) -_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) - - -class Activation(str, Enum): - SquaredReLU = "squared_relu" - GeLU = "gelu" - GeLUApprox = "gelu_approx" - LeakyReLU = "leaky_relu" - ReLU = "relu" - - -def get_triton_activation_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu, - Activation.LeakyReLU: leaky_relu, - Activation.GeLU: gelu, - Activation.GeLUApprox: gelu_approx, - Activation.SquaredReLU: squared_relu, - }[activation] - if activation - else None - ) - - -def get_triton_activation_bwd_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu_grad, - Activation.LeakyReLU: leaky_relu_grad, - Activation.GeLU: gelu_grad, - Activation.GeLUApprox: gelu_approx_grad, - Activation.SquaredReLU: squared_relu_grad, - }[activation] - if activation - else None - ) - - -@triton.jit -def tanh(x): - # Tanh is just a scaled sigmoid - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def cosh(x): - exp_x = tl.exp(x) - return (exp_x + 1.0 / exp_x) * 0.5 - - -# a Triton implementation of the most used activations -# See for instance http://arxiv.org/abs/1606.08415 for an overview - -# ReLU -@triton.jit -def relu(x): - """ - ReLU_ activation function - - .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html - """ - zero = 0.0 - return tl.where(x >= 0, x, zero.to(x.dtype)) - - -@triton.jit -def relu_grad(x): - # ReLU is different from other activations - # in that it does not require the input to retrospectively compute its gradient - # here the input is the downstream gradient, and we return the upstream gradient directly - zero = 0.0 - one = 1.0 - return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) - - -@triton.jit -def squared_relu(x): - """ - Squared ReLU activation, as proposed in the Primer_ paper. - - .. _Primer: https://arxiv.org/abs/2109.08668 - """ - x_ = relu(x) - return (x_ * x_).to(x.dtype) - - -@triton.jit -def squared_relu_grad(x): - return tl.where(x >= 0, 2.0 * x, 0.0) - - -# Leaky ReLU -@triton.jit -def leaky_relu(x): - """ - LeakyReLU_ activation - - .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html - """ - scale = 0.01 + 0.0 - scale = scale.to(x.dtype) - return tl.where(x >= 0, x, scale * x) - - -@triton.jit -def leaky_relu_grad(x): - min_grad = 0.01 - max_grad = 1 - - min_grad = min_grad.to(x.dtype) - max_grad = max_grad.to(x.dtype) - - return tl.where(x >= 0, max_grad, min_grad) - - -@triton.jit -def gelu(x): - """Gaussian Error Linear Unit (GELU)""" - return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - - -@triton.jit -def gelu_grad(x): - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization - return cdf + x * pdf - - -@triton.jit -def gelu_approx(x): - """ - GeLU_ activation - Gaussian error linear unit, with tanh approximation - - .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf - """ - return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) - - -@triton.jit -def gelu_approx_grad(x): - # CREDITS: Fast implementation proposed in - # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 - tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/layer_norm.py deleted file mode 100644 index 192cee474b160d1876fafd14c5e3d695e8ff237f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/layer_norm.py +++ /dev/null @@ -1,1252 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Implement dropout + residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math -from typing import Optional, List - -import torch -import torch.nn.functional as F -from torch import Tensor - -import triton -import triton.language as tl - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.utils.library import triton_op - - -def maybe_contiguous_lastdim(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def maybe_contiguous(x): - return x.contiguous() if x is not None else None - - -def triton_autotune_configs(): - # Return configs with a valid warp count for the current device - configs = [] - # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 - max_threads_per_block = 1024 - # Default to warp size 32 if not defined by device - warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) - # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit - return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] - if warp_count * warp_size <= max_threads_per_block] - # return [triton.Config({}, num_warps=8)] - - -def layer_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - 
if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = F.layer_norm( - x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps - ).to(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def rms_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( - dtype - ) - return (out, out1) if not prenorm else (out, out1, x) - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], -) -# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - DROPOUT_MASK1, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by 
zero - dropout_p, # Dropout probability - zero_centered_weight, # If true, add 1.0 to the weight - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - residual_dtype: Optional[torch.dtype] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - out: Optional[Tensor] = None, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, 
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library - # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None - # so that _layer_norm_fwd_impl doesn't have to return them. - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - if residual is not None: - residual_dtype = residual.dtype - if residual_out is None and ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - residual_out = torch.empty_like( - x, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - else: - residual_out = None - y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( - x, - weight, - bias, - eps, - out, - residual=residual, - x1=x1, - weight1=weight1, - bias1=bias1, - dropout_p=dropout_p, - rowscale=rowscale, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - residual_out=residual_out, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if residual_out is None: - residual_out = x - return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 - - -# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema -# since we're returning a tuple of tensors -@triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, - schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") -def _layer_norm_fwd_impl( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - out: Tensor, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - assert out.shape == x.shape - assert out.stride(-1) == 1 - if residual_out is not None: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - if weight1 is not None: - y1 = torch.empty_like(out) - assert y1.stride(-1) == 1 - else: - y1 = None - mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if dropout_p > 0.0: - seeds = torch.randint( - 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 - ) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) - if x1 is not None: - dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) - else: - dropout_mask1 = None - else: - dropout_mask, dropout_mask1 = None, None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( - x, - out, - weight, - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - dropout_mask1, - mean, - rstd, - x.stride(0), - out.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - x1.stride(0) if x1 is not None else 0, - y1.stride(0) if y1 is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - HAS_X1=x1 is not None, - HAS_W1=weight1 is not None, - HAS_B1=bias1 is not None, - ) - return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], -) -# torch compile doesn't like 
triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - W1, - DY1, - DX1, - DW1, - DB1, - DRESIDUAL_IN, - ROWSCALE, - SEEDS, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dy1_row, - stride_dx1_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, - zero_centered_weight, - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_DY1: tl.constexpr, - HAS_DX1: tl.constexpr, - HAS_B1: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. 
- row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - # Do not early exit if row_start >= M, because we need to write DW and DB - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if HAS_DY1: - DY1 += row_start * stride_dy1_row - if HAS_DX1: - DX1 += row_start * stride_dx1_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - if HAS_DY1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_DY1: - dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_B1: - db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if HAS_DY1: - dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_DY1: - wdy += w1 * dy1 - dw1 += dy1 * xhat - if HAS_B1: - db1 += dy1 - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - if HAS_DX1: - if HAS_DROPOUT: - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - else: - dx1 = dx - tl.store(DX1 + cols, dx1, mask=mask) - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - dx *= rowscale - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - if HAS_DY1: - DY1 += stride_dy1_row - if HAS_DX1: - DX1 += stride_dx1_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - if HAS_DY1: - tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) - if HAS_B1: - tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) - - -def _layer_norm_bwd( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, 
- weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, - # which makes torch.library unhappy - dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - dropout_p, - rowscale, - has_residual, - has_x1, - zero_centered_weight, - is_rms_norm, - x_dtype=x_dtype, - recompute_output=recompute_output, - ) - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - - -@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={}, - schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", - allow_decomposition=False, # Don't let torch.compile trace inside - ) -def _layer_norm_bwd_impl( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - dy = maybe_contiguous_lastdim(dy) - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - dresidual = maybe_contiguous_lastdim(dresidual) - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if dy1 is not None: - dy1 = maybe_contiguous_lastdim(dy1) - assert weight1 is not None - assert dy1.shape == dy.shape - assert dy1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if seeds is not None: - assert seeds.is_contiguous() - assert seeds.shape == (M if not has_x1 else M * 2,) - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = ( - torch.empty_like(x) - if has_residual - and (dx.dtype != x.dtype or dropout_p > 0.0 or 
rowscale is not None or has_x1) - else None - ) - dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None - y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None - if recompute_output: - assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the - # latency of the gmem reads/writes, but will increase the time of summing up dw / db. - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - _dw1 = torch.empty_like(_dw) if weight1 is not None else None - _db1 = torch.empty_like(_db) if bias1 is not None else None - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - weight1, - dy1, - dx1, - _dw1, - _db1, - dresidual_in, - rowscale, - seeds, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dy1.stride(0) if dy1 is not None else 0, - dx1.stride(0) if dx1 is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - dropout_p > 0.0, - HAS_ROWSCALE=rowscale is not None, - HAS_DY1=dy1 is not None, - HAS_DX1=dx1 is not None, - HAS_B1=bias1 is not None, - RECOMPUTE_OUTPUT=y is not None, - ) - dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None - db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - -class LayerNormFn(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) - weight = weight.contiguous() - bias = maybe_contiguous(bias) - weight1 = maybe_contiguous(weight1) 
- bias1 = maybe_contiguous(bias1) - if rowscale is not None: - rowscale = rowscale.reshape(-1).contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - if out is not None: - out = out.reshape(-1, out.shape[-1]) - if residual_out is not None: - residual_out = residual_out.reshape(-1, residual_out.shape[-1]) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - out_dtype=out_dtype, - residual_dtype=residual_dtype, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - out=out, - residual_out=residual_out, - ) - ctx.save_for_backward( - residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd - ) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.dropout_p = dropout_p - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.has_x1 = x1 is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.zero_centered_weight = zero_centered_weight - y = y.reshape(x_shape_og) - y1 = y1.reshape(x_shape_og) if y1 is not None else None - residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None - dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None - dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None - if not return_dropout_mask: - if weight1 is None: - return y if not prenorm else (y, residual_out) - else: - return (y, y1) if not prenorm else (y, y1, residual_out) - else: - if weight1 is None: - return ( - (y, dropout_mask, dropout_mask1) - if not prenorm - else (y, residual_out, dropout_mask, dropout_mask1) - ) - else: - return ( - (y, y1, dropout_mask, dropout_mask1) - if not prenorm - else (y, y1, residual_out, dropout_mask, dropout_mask1) - ) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if weight1 is not None: - dy1, args = args[0], args[1:] - dy1 = dy1.reshape(-1, dy1.shape[-1]) - assert dy1.shape == x.shape - else: - dy1 = None - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - ctx.dropout_p, - rowscale, - ctx.has_residual, - ctx.has_x1, - ctx.zero_centered_weight, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=False, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, - dw1, - db1, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - 
residual_in_fp32, - zero_centered_weight, - is_rms_norm, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - zero_centered_weight, - True, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -class RMSNorm(torch.nn.Module): - - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, - device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - if dropout_p > 0.0: - self.drop = torch.nn.Dropout(dropout_p) - else: - self.drop = None - self.zero_centered_weight = zero_centered_weight - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - if not self.zero_centered_weight: - torch.nn.init.ones_(self.weight) - else: - torch.nn.init.zeros_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - zero_centered_weight=self.zero_centered_weight, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - norm_weight = norm_weight.contiguous() - norm_bias = maybe_contiguous(norm_bias) - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to(dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, 
norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - dy = maybe_contiguous_lastdim(dy) - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1])) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual=dresidual, - has_residual=ctx.has_residual, - is_rms_norm=ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/linear.py deleted file mode 100644 index a8966dbc345ab0e593df0124451ee7be3dae131a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/linear.py +++ /dev/null @@ -1,594 +0,0 @@ -# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py -# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py -from typing import Optional - -import torch -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -from flash_attn.ops.triton.k_activations import ( - gelu, - gelu_approx, - gelu_approx_grad, - gelu_grad, - squared_relu, - squared_relu_grad, -) - -# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k not used - # for split_k in [2, 4, 8, 16]: - # configs.append(triton.Config( - # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, 
num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_fwd( - C, # Pointers to matrices - ACT_INPUT, - A, - B, - bias, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bn, - stride_bk, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - A_ROWMAJOR: tl.constexpr, - B_COLMAJOR: tl.constexpr, - BIAS: tl.constexpr, - SAVE_ACT_INPUT: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Bias has shape (N,) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - if A_ROWMAJOR: - A = A + (ram[:, None] * stride_am + rk[None, :]) - else: - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - if B_COLMAJOR: - B = B + (rk[:, None] + rbn[None, :] * stride_bn) - else: - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - if A_ROWMAJOR: - A += BLOCK_K - else: - A += BLOCK_K * stride_ak - if B_COLMAJOR: - B += BLOCK_K - else: - B += BLOCK_K * stride_bk - - # Putting bias after the matmul (instead of before) is faster, idk why - if BIAS: - bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) - acc += bias[None, :] - - # optional: save the activation inputs - if SAVE_ACT_INPUT: - # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - tl.store(act_in_ptrs, acc) - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION == "gelu": - acc = gelu(acc) - elif ACTIVATION == "gelu_approx": - acc = gelu_approx(acc) - elif ACTIVATION == "squared_relu": - acc = squared_relu(acc) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc) - - -def triton_linear_act( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - 
activation: str = "id", - save_act_input: bool = False, -) -> torch.Tensor: - """ - Compute e = activation(x @ weight.T + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param x: input tensor - :param weight: weight matrix - :param bias: an optional bias tensor - :param activation: Activation name. Needs to be a Triton kernel. - :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - # if torch.is_autocast_enabled(): - # dtype = torch.get_autocast_gpu_dtype() - # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] - - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - x_reshaped = x.reshape(batch_dim, n) - - if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: - x_reshaped = x_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - bias = bias.contiguous() if bias is not None else None - - assert ( - x.dtype == weight.dtype - ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" - if bias is not None: - assert ( - x.dtype == bias.dtype - ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" - assert ( - x_reshaped.shape[1] == weight.shape[1] - ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" - - assert ( - bias is None or bias.shape[0] == weight.shape[0] - ), "Incompatible dimensions in between weight and bias" - - M, K = x_reshaped.shape - N, K = weight.shape - - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - act_input = torch.empty_like(output) if save_act_input else None - - # 1D launch kernel where each block gets its own program. 
- grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_fwd[grid]( - output, - act_input, - x_reshaped, - weight, # data ptrs - bias if bias is not None else x, # auto skip bias if not present - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=output.stride(0), # strides - # stride_cn=output.stride(1), - stride_am=x_reshaped.stride(0), - stride_ak=x_reshaped.stride(1), - stride_bk=weight.stride(1), - stride_bn=weight.stride(0), - BIAS=bias is not None, # optional fused bias - SAVE_ACT_INPUT=save_act_input, # optional save activation inputs - ACTIVATION=activation, # optional fused activation - A_ROWMAJOR=x_reshaped.stride(1) == 1, - B_COLMAJOR=weight.stride(1) == 1, - GROUP_M=8, # speed optimization: group the programs - ) - - if not save_act_input: - return output.reshape(*batch_shape, output.shape[-1]) - else: - return ( - output.reshape(*batch_shape, output.shape[-1]), - act_input.reshape(*batch_shape, act_input.shape[-1]), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_bwd( - C, # Pointers to matrices - 
ACT_INPUT, - A, - B, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bk, - stride_bn, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION != "id": - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - act_input = tl.load(act_in_ptrs).to(acc.dtype) - if ACTIVATION == "gelu": - acc *= gelu_grad(act_input) - elif ACTIVATION == "gelu_approx": - acc *= gelu_approx_grad(act_input) - elif ACTIVATION == "squared_relu": - acc *= squared_relu_grad(act_input) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) - - -def triton_dgrad_act( - grad_output: torch.Tensor, - weight: torch.Tensor, - activation: str = "id", - act_input: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Compute e = activation(grad_output @ weight + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param grad_output: input tensor - :param weight: weight matrix - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] - batch_dim = batch_shape.numel() - grad_output_reshaped = grad_output.reshape(batch_dim, n) - - if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: - grad_output_reshaped = grad_output_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - - assert ( - grad_output.dtype == weight.dtype - ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" - assert ( - grad_output_reshaped.shape[1] == weight.shape[0] - ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": - assert act_input is not None, f"act_input is required for activation {activation}" - - # M, N, K in bwd are different from M, N, K in fwd - M, K = grad_output_reshaped.shape - K, N = weight.shape - - grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_bwd[grid]( - grad_input, - act_input, - grad_output_reshaped, - weight, # data ptrs - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), - ACTIVATION=activation, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/mlp.py deleted file mode 100644 index 059f4f8a5e174c1f4824e43d313fca18eaa799b8..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/mlp.py +++ /dev/null @@ -1,149 +0,0 @@ -# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared -# to naive implementation. 
-import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd -from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act - - -class FusedDenseSqreluDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): - """checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute act_input and gelu_out in the bwd - """ - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - x, weight1, bias1, weight2, bias2 = [ - a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] - ] - is_bf16 = x.dtype == torch.bfloat16 - assert checkpoint_lvl in [0, 1, 2] - x = x.contiguous() - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - save_act_input = checkpoint_lvl != 2 - result = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=save_act_input, - ) - if save_act_input: - output1, act_input = result - else: - output1 = result - output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl == 0: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, bias1, weight2) - return output2.reshape(*batch_shape, output2.shape[-1]) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - x, weight1, bias1, weight2, *rest = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - is_bf16 = x.dtype == torch.bfloat16 - if checkpoint_lvl == 0: - act_input, output1 = rest - elif checkpoint_lvl == 1: - (act_input,) = rest - output1 = sqrelu_fwd(act_input) - elif checkpoint_lvl == 2: - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - output1, act_input = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=True, - ) - - if is_bf16: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_output1 = grad_output @ weight2 - grad_act_input = sqrelu_bwd(grad_output1, act_input) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - ) - else: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_act_input = triton_dgrad_act( - grad_output, weight2, activation="squared_relu", act_input=act_input - ) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - 
) - return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None - - -fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply - - -class FusedDenseSqreluDense(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - checkpoint_lvl=0, - device=None, - dtype=None, - ): - """ - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd - """ - assert checkpoint_lvl in [0, 1, 2] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" - assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" - self.checkpoint_lvl = checkpoint_lvl - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - assert x.is_cuda - return fused_dense_sqrelu_dense_function( - x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/rotary.py deleted file mode 100644 index ff4017fda3e4a6e18cf3a51b34f0fb073d8f678a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/flash_attn/ops/triton/rotary.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) 2025, Tri Dao. -# As of 2025-04-23, we require triton >= 3.0 - -from typing import Optional, Union - -import torch - -import triton -import triton.language as tl - - -@triton.jit -def rotary_kernel( - OUT, # Pointers to matrices - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer - # Matrix dimensions - seqlen, - nheads, - seqlen_ro, - # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - # Meta-parameters - # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that - # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 - ROTARY_DIM: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, -): - BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) - ROTARY_DIM_HALF = ROTARY_DIM // 2 - pid_head = tl.program_id(axis=0) - pid_m = tl.program_id(axis=1) - pid_batch = tl.program_id(axis=2) - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch - OUT = OUT + pid_batch * stride_out_batch - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen - OUT = OUT + start_idx * stride_out_seqlen - - if pid_m * BLOCK_M >= seqlen: - return - - rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - - rk_half = tl.arange(0, BLOCK_K // 2) - COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + 
rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) - mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) - cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) - sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) - if CONJUGATE: - sin = -sin - - if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) - x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - tl.store(OUT, o0, mask=mask) - tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) - else: - rk = tl.arange(0, BLOCK_K) - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) - x = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) - tl.store(OUT, o, mask=mask) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). 
- cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in [torch.int32, torch.int64] - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa - BLOCK_M = 8 if rotary_dim <= 128 else 4 - - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(rotary_kernel)[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - seqlen_ro, - output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - rotary_dim, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M=BLOCK_M, - BLOCK_H=2, - ) - return output diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py deleted file mode 100644 index ecc2f9d896b6c93f90b0a1499856dc0612177422..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py +++ /dev/null @@ -1,393 +0,0 @@ -from typing import Optional, List -import torch -from ._ops import ops as flash_attn_ops -from .flash_attn_interface import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, -) - - -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention. 
- - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Optional output tensor, same shape as q - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd( - q, - k, - v, - out, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def varlen_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - out: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - return_softmax: bool = False, - gen: Optional[torch.Generator] = None, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with variable sequence lengths. 
- - Args: - q: Query tensor of shape [total_q, num_heads, head_size] - k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - out: Optional output tensor of shape [total_q, num_heads, head_size] - seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size] - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - return_softmax: Whether to return softmax weights - gen: Optional random number generator - - Returns: - List of tensors: [output, softmax_lse, (softmax if return_softmax)] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_fwd( - q, - k, - v, - out, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen, - ) - - -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention. 
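The main subtlety with `varlen_fwd` is packing the ragged batch and building `cu_seqlens`. A sketch under the same assumptions as above (importable `flash_attn` package, CUDA device):

```python
import torch
import torch.nn.functional as F
from flash_attn import varlen_fwd

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")      # per-sequence lengths
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # tensor([0, 5, 8, 15])
total = int(seqlens.sum())

q = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)        # packed [total_q, nheads, head_size]
k = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)

out, softmax_lse, *_ = varlen_fwd(
    q, k, v, cu_seqlens, cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    is_causal=True,
)
```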
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def varlen_bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - dq: Optional[torch.Tensor] = None, - dk: Optional[torch.Tensor] = None, - dv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - max_seqlen_q: int = 0, - max_seqlen_k: int = 0, - p_dropout: float = 0.0, - softmax_scale: Optional[float] = None, - zero_tensors: bool = False, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - gen: Optional[torch.Generator] = None, - rng_state: Optional[torch.Tensor] = None, -) -> List[torch.Tensor]: - """ - Backward pass for multi-head attention with variable sequence lengths. 
- - Args: - dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size] - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size] - out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size] - softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q] - cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1] - cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1] - dq: Optional gradient tensor for queries, same shape as q - dk: Optional gradient tensor for keys, same shape as k - dv: Optional gradient tensor for values, same shape as v - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - max_seqlen_q: Maximum sequence length for queries - max_seqlen_k: Maximum sequence length for keys - p_dropout: Dropout probability - softmax_scale: Scale factor for softmax - zero_tensors: Whether to zero tensors before computation - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - deterministic: Whether to use deterministic algorithms - gen: Optional random number generator - rng_state: Optional RNG state from forward pass - - Returns: - List of tensors: [dq, dk, dv] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen, - rng_state, - ) - - -def fwd_kvcache( - q: torch.Tensor, - kcache: torch.Tensor, - vcache: torch.Tensor, - k: Optional[torch.Tensor] = None, - v: Optional[torch.Tensor] = None, - seqlens_k: Optional[torch.Tensor] = None, - rotary_cos: Optional[torch.Tensor] = None, - rotary_sin: Optional[torch.Tensor] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - is_causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - is_rotary_interleaved: bool = False, - num_splits: int = 1, -) -> List[torch.Tensor]: - """ - Forward pass for multi-head attention with KV cache. 
- - Args: - q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size] - kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size] - k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size] - seqlens_k: Optional sequence lengths for keys of shape [batch_size] - rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2] - rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2] - cache_batch_idx: Optional indices to index into the KV cache - leftpad_k: Optional left padding for keys of shape [batch_size] - block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq] - alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads] - out: Optional output tensor, same shape as q - softmax_scale: Scale factor for softmax - is_causal: Whether to use causal attention - window_size_left: Window size for left context (-1 for unlimited) - window_size_right: Window size for right context (-1 for unlimited) - softcap: Soft cap for attention weights - is_rotary_interleaved: Whether rotary embeddings are interleaved - num_splits: Number of splits for computation - - Returns: - List of tensors: [output, softmax_lse] - """ - if softmax_scale is None: - attention_head_dim = q.shape[-1] - softmax_scale = 1.0 / (attention_head_dim**0.5) - - return flash_attn_ops.fwd_kvcache( - q, - kcache, - vcache, - k, - v, - seqlens_k, - rotary_cos, - rotary_sin, - cache_batch_idx, - leftpad_k, - block_table, - alibi_slopes, - out, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - softcap, - is_rotary_interleaved, - num_splits, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index d44eff5ebd602dfdb3e4d4cd60888bcdc6002bdd..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b2dde56a1e2cca8b30a68fc9da5f238b1d44d23f8e9a77ed70d3b8147166f739 -size 448643480 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py deleted file mode 100644 index a9819140ce922d5d25722ffeb3c2416285a9d068..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
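Typical decode-step usage of the `fwd_kvcache` wrapper deleted above: append one new token's K/V into a preallocated cache and attend against everything cached so far. A sketch with the same importability/CUDA assumptions; sizes are arbitrary:

```python
import torch
from flash_attn import fwd_kvcache

batch, nheads, head_size, max_cache_len = 2, 8, 64, 1024
kcache = torch.zeros(batch, max_cache_len, nheads, head_size, device="cuda", dtype=torch.float16)
vcache = torch.zeros_like(kcache)
cache_seqlens = torch.tensor([10, 3], dtype=torch.int32, device="cuda")  # tokens already in the cache

q     = torch.randn(batch, 1, nheads, head_size, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, head_size, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads, head_size, device="cuda", dtype=torch.float16)

# k_new / v_new are appended into kcache / vcache in place at position cache_seqlens
out, softmax_lse = fwd_kvcache(q, kcache, vcache, k=k_new, v=v_new,
                               seqlens_k=cache_seqlens, is_causal=True)
```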
- """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/bert_padding.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/bert_padding.py deleted file mode 100644 index 3c2d35159a014a9d03aabead9e52e009168696ea..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/bert_padding.py +++ /dev/null @@ -1,218 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - - -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, - dtype=grad_output.dtype, - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - output = input[indices] - # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last - # memory format to channel_first. In other words, input might not be contiguous. 
- # If we don't detach, Pytorch complains about output being a view and is being modified inplace - return output, input.detach() - - @staticmethod - def backward(ctx, grad_output, grad_residual): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - assert grad_residual.shape[1:] == other_shape - grad_input = grad_residual - # grad_input[indices] += grad_output - indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) - indices = indices.expand_as(grad_output) - grad_input.scatter_add_(0, indices, grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis_residual = IndexFirstAxisResidual.apply - - -def unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): - """ - Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). - The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
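A small round-trip example for `unpad_input` above and its inverse `pad_input` (defined further down in this same `bert_padding.py`); the module path is assumed from the file location:

```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 6, 16
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

h_unpad, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
# h_unpad: (8, 16) packed valid tokens, cu_seqlens: tensor([0, 3, 8]), max_seqlen: 5

restored = pad_input(h_unpad, indices, batch, seqlen)  # back to (2, 6, 16); padded slots are zero
```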
- - For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: - ``` - [ - [2, 3, 0, 0, 0, 0], - [3, 2, 0, 0, 0, 0], - [6, 0, 0, 0, 0, 0] - ] - ``` - , which refers to the 3D-attention mask: - ``` - [ - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1] - ] - ] - ```. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - """ - length = attention_mask_in_length.sum(dim=-1) - seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) - real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() - seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] - indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... 
-> b s ...", b=batch) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/flash_attn_interface.py deleted file mode 100644 index 690d644f0a1c3d6ccfd26acbf8b22376a47cfff0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/flash_attn_interface.py +++ /dev/null @@ -1,1609 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Sequence, Tuple, Union - -import torch -import torch.nn as nn -import os - -# # isort: off -# # We need to import the CUDA kernels after importing torch -# USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -# if USE_TRITON_ROCM: -# from .flash_attn_triton_amd import interface_fa as flash_attn_gpu -# else: -# import flash_attn_2_cuda as flash_attn_gpu - - -from ._ops import ops as flash_attn_gpu - -# # isort: on - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _get_block_size_n(device, head_dim, is_dropout, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128 - if head_dim <= 64: - return 128 if not is_dropout else 64 - elif head_dim <= 96: - return 64 - elif head_dim <= 128: - if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 - else: - return 64 if not is_dropout else 32 - elif head_dim <= 192: - return 64 - elif head_dim <= 224: - return 64 - elif head_dim <= 256: - return 64 - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -# torch.compile() support is only enabled for pytorch >= 2.4 -# The reason for this is that we are using the new custom_op and register_fake -# APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": - _torch_custom_op_wrapper = torch.library.custom_op - _torch_register_fake_wrapper = torch.library.register_fake -else: - def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - if fn is None: - return wrap - return fn - def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): - def wrap(func): - return func - if fn is None: - return wrap - return fn - _torch_custom_op_wrapper = noop_custom_op_wrapper - _torch_register_fake_wrapper = noop_register_fake_wrapper - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q, - k, - v, - None, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward") -def _flash_attn_forward_fake( - q: torch.Tensor, - k: torch.Tensor, 
- v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - batch_size, seqlen_q, num_heads, head_size = q.shape - seqlen_k = k.shape[1] - out = torch.empty_like(q) - softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward -else: - _wrapped_flash_attn_forward = _flash_attn_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_attn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q, - k, - v, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - return out, softmax_lse, S_dmask, rng_state - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward") -def _flash_attn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = 
round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward -else: - _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") -def _flash_attn_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward -else: - _wrapped_flash_attn_backward = _flash_attn_backward - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_attn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, 
v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_attn_gpu.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): - # breakpoint() - return softmax_d - - -@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward") -def _flash_attn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -if torch.__version__ >= "2.4.0": - _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward -else: - _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(3) - dout_padded = dout 
- if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = 
torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, kv] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k, v = kv[:, 0].detach(), kv[:, 1].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=None, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = 
ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = 
q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - block_table=block_table, - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. 
- deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnQKVPackedFunc.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - kv: (batch_size, seqlen, 2, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. 
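For the packed-QKV entry point whose definition finishes just above, a minimal call sketch (assuming the `flash_attn` import path and a CUDA device, as before):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# (batch_size, seqlen, 3, nheads, headdim) with Q, K, V stacked along dim 2
qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (2, 128, 8, 64)
out.sum().backward()                                              # gradient lands in qkv.grad
```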
The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnKVPackedFunc.apply( - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). 
It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
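The dense `flash_attn_func` path above supports MQA/GQA simply by giving K and V fewer heads than Q; a sketch under the same import-path and CUDA assumptions:

```python
import torch
from flash_attn import flash_attn_func

# GQA: 8 query heads attending over 2 KV heads (8 is divisible by 2)
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 2, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 2, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)                 # (2, 128, 8, 64)
out_local = flash_attn_func(q, k, v, window_size=(256, 0))  # sliding-window local attention
```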
- """ - return FlashAttnVarlenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenKVPackedFunc.apply( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. 
This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) - - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, - alibi_slopes=None, - num_splits=0, - return_softmax_lse=False, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. 
Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. - - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). 
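# Usage sketch for flash_attn_with_kvcache (illustrative shapes and dtypes; the
# cache must be pre-allocated large enough to hold the appended k/v): decode one
# new token per sequence against an in-place-updated KV cache.
import torch

batch, nheads, nheads_k, headdim, max_seqlen = 2, 8, 2, 64, 1024
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device="cuda")  # current lengths

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)

# k_cache / v_cache are updated in place at the positions given by cache_seqlens,
# and attention runs over the updated cache in a single kernel.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)  # (batch, 1, nheads, headdim)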
- """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - cache_batch_idx = maybe_contiguous(cache_batch_idx) - block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_gpu.fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - None, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - rotary_interleaved, - num_splits, - ) - return (out, softmax_lse) if return_softmax_lse else out diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/patch_embed.py deleted file mode 100644 index 05562f8e8bcdb58e947c6f402a49eacd2d031871..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/patch_embed.py +++ /dev/null @@ -1,67 +0,0 @@ -# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py -# But we use nn.Linear instead of Conv2d and it's about 8x faster. 
- -from functools import partial - -import torch.nn as nn -from einops import rearrange -from torch import _assert -from torch.nn.modules.utils import _pair - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, - fused_bias_fc=False, - ): - super().__init__() - img_size = _pair(img_size) - patch_size = _pair(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - - linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense - self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - _, _, H, W = x.shape - _assert( - H == self.img_size[0], - f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", - ) - _assert( - W == self.img_size[1], - f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", - ) - x = self.proj( - rearrange( - x, - "b c (h p1) (w p2) -> b h w (c p1 p2)", - p1=self.patch_size[0], - p2=self.patch_size[1], - ) - ) - if self.flatten: - x = rearrange(x, "b h w c -> b (h w) c") - x = self.norm(x) - return x diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/rotary.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/rotary.py deleted file mode 100644 index d1bfc21fc7de1dd287e8f382847b194a48075981..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/layers/rotary.py +++ /dev/null @@ -1,483 +0,0 @@ -# Copyright (c) 2025, Tri Dao - -import math -from functools import partial -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor - -from einops import rearrange, repeat -# from flash_attn.ops.triton.rotary import apply_rotary -from ..ops.triton.rotary import apply_rotary - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, - ): - out = apply_rotary( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - dx = apply_rotary( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, Tensor] = 0, - cu_seqlens: Optional[Tensor] = None, - max_seqlen: Optional[int] = None, -): - """ - Arguments: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - return ApplyRotaryEmb.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -# For backward compatibility -apply_rotary_emb_func = apply_rotary_emb - - -def _apply_rotary_emb_qkv( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - inplace=False, - conjugate=False, - seqlen_offsets: Union[int, Tensor] = 0, - num_heads_q: Optional[int] = None, -): - apply_rotary_fn = partial( - apply_rotary, - interleaved=interleaved, - inplace=inplace, - conjugate=conjugate, - seqlen_offsets=seqlen_offsets - ) - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - qk = apply_rotary_fn(qk, cos, sin) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - qk = qkv[:, :, :num_heads_q + num_heads_k] - qk = apply_rotary_fn(qk, cos, sin) - if not inplace: - if qkv.dim() == 5: - qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) - else: - qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - q, k = qkv[:, :, 0], qkv[:, :, 1] - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] - q = apply_rotary_fn(q, cos, sin) - k = apply_rotary_fn(k, cos_k, sin_k) - if not inplace: - if qkv.dim() == 5: - qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) - else: - qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) - return qkv - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, - ): - # apply_rotary_emb_qkv_inplace( - qkv = _apply_rotary_emb_qkv( - qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.num_heads_q = num_heads_q - return qkv - - @staticmethod - def backward(ctx, dqkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cos_k, sin_k = ctx.saved_tensors - dqkv = _apply_rotary_emb_qkv( - dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, - seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, - ) - return dqkv, None, None, None, None, None, None, None - - -def apply_rotary_emb_qkv_( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - 
seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Optional[int] = None, -): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim). - If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of Q and K. - """ - return ApplyRotaryEmbQKV_.apply( - qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q - ) - - -class ApplyRotaryEmbKV_(torch.autograd.Function): - - @staticmethod - def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): - batch, seqlen, two, nheads, headdim = kv.shape - assert two == 2 - k = kv[:, :, 0] - apply_rotary( - k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - return kv - - @staticmethod - def backward(ctx, dkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, seqlen_offsets = ctx.saved_tensors - else: - cos, sin = ctx.saved_tensors - apply_rotary( - dkv[:, :, 0], - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - return dkv, None, None, None, None - - -apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply - - -def apply_rotary_emb_kv_( - kv, - cos, - sin, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -): - """ - Arguments: - kv: (batch_size, seqlen, 2, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - kv: (batch_size, seqlen, 2, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of K. - """ - return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
- A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 - Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py - """ - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - """ - super().__init__() - self.dim = dim - self.base = float(base) - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
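# (Illustrative magnitude: bf16 spaces representable values 128 apart around
#  position 20000, so many adjacent positions would collapse to the same angle
#  if t or freqs were computed in bf16; fp32 keeps them distinct.)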
- # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to bf16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - num_heads_q: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) - if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim). - If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), - then num_heads_q must be provided. - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. 
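# Usage sketch (illustrative; the underlying Triton rotary kernel expects CUDA
# tensors): rotate a packed qkv in place during prefill, then offset by the
# current KV-cache length during decoding so new tokens get the right positions.
import torch

batch, seqlen, nheads, headdim = 2, 128, 16, 64
rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16)
qkv = rotary(qkv, seqlen_offset=0)  # prefill: positions 0 .. seqlen - 1
qkv_step = torch.randn(batch, 1, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16)
qkv_step = rotary(qkv_step, seqlen_offset=seqlen)  # decode step at position seqlen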
- """ - seqlen = qkv.shape[1] - if max_seqlen is not None: - self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - if kv is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached if self.scale is not None else None, - self._sin_k_cached if self.scale is not None else None, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) - else: - q = qkv - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - inplace=True, - seqlen_offsets=seqlen_offset, - ) - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached if self.scale is None else self._cos_k_cached, - self._sin_cached if self.scale is None else self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - return q, kv diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/activations.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/activations.py deleted file mode 100644 index 7c09649fc41e12d5a360c5672825d8380bc7ec80..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/activations.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def bias_gelu(y, bias): - x = bias + y - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, y, bias): - """Assume that y has shape (B, D) and bias has shape (D)""" - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - grad_y = ff * g - return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(input, bias) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, input, bias) - return tmp, tmp - - -bias_gelu_impl = GeLUFunction.apply - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def gelu_fwd(x): - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def gelu_bwd(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - return (ff * g).to(dtype=x.dtype) - - -class FastGeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input): - ctx.save_for_backward(input) - return gelu_fwd(input) - - @staticmethod - def backward(ctx, grad_output): - (input,) = ctx.saved_tensors - tmp = gelu_bwd(grad_output, input) - return tmp - - -fast_gelu_impl = FastGeLUFunction.apply - - -@torch.jit.script -def relu_bwd(g, x): - return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_fwd(x): - r = F.relu(x) - return (r * r).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_bwd(g, x): - return (2.0 * g * F.relu(x)).to(dtype=x.dtype) - - -swiglu_fwd_codestring = """ -template T swiglu_fwd(T x, T y) { - return float(x) * float(y) / (1.0f + ::exp(-float(x))); -} -""" -swiglu_bwd_codestring = """ -template void swiglu_bwd(T x, T y, T g, T& dx, T& dy) { - float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); - dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); - dy = float(x) * x_sigmoid * float(g); -} -""" -swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) -swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) - - -class SwiGLUFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - return swiglu_fwd(x, y) - - @staticmethod - def backward(ctx, dout): - x, y = ctx.saved_tensors - return swiglu_bwd(x, y, dout) - -swiglu = SwiGLUFunction.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/fused_dense.py deleted file mode 100644 index 6b4033d134e4093fe278f7b3f8c7d3128ce9f36d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/fused_dense.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
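# Numerical sanity sketch (illustrative, not part of the original file) for the
# tanh-approximation GeLU used in ops/activations.py above: 0.79788456 is
# sqrt(2/pi), and the approximation stays close to the exact erf-based GeLU.
import math
import torch

x = torch.linspace(-4.0, 4.0, steps=101)
gelu_exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
gelu_tanh = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
assert math.isclose(math.sqrt(2 / math.pi), 0.79788456, rel_tol=1e-7)
assert (gelu_exact - gelu_tanh).abs().max().item() < 1e-2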
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py -# We make it work with pytorch amp and with bfloat16. -# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py -from functools import partial -from typing import Optional - -# import fused_dense_cuda # from apex -import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd -from flash_attn.utils.distributed import ( - all_gather_raw, - all_reduce, - all_reduce_raw, - reduce_scatter, - reduce_scatter_raw, -) - - -class FusedDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if 
process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel: - handle_x.wait() - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply( - x, weight, bias, return_residual, process_group, sequence_parallel - ) - else: - assert process_group is None - out = F.linear(x, weight, bias) - return out if not return_residual else (out, x) - - -class FusedDense(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - return_residual: bool = False, - device=None, - dtype=None, - ) -> None: - super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - self.return_residual = return_residual - - def forward(self, x, process_group=None): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - """ - return fused_dense_func( - x, - self.weight, - self.bias, - return_residual=self.return_residual, - process_group=process_group, - ) - - -class ColumnParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__( - in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
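# Worked example for the sharding computed in __init__ above (illustrative
# numbers): out_features=10, multiple_of=2, world_size=4 gives multiple=5,
# div=1, mod=1, so rank 0 holds 4 output features and ranks 1-3 hold 2 each
# (4 + 2 + 2 + 2 == 10).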
- return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - - -class RowParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = fused_dense_func(x, self.weight, self.bias) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) - - -class FusedMLPFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight1, - bias1, - weight2, - bias2, - activation="gelu_approx", - save_pre_act=True, - return_residual=False, - checkpoint_lvl=0, - heuristic=0, - process_group=None, - sequence_parallel=True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather of x before doing the matmul. - If sequence_parallel=False, then the input is already gathered. 
- - checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out / relu_out in the bwd - 2: recompute pre_act and gelu_out / relu_out in the bwd - """ - assert -1 <= heuristic <= 4 - assert activation in ["gelu_approx", "relu", "sqrelu"] - if activation == "sqrelu": - assert heuristic == -1 - if not save_pre_act: - checkpoint_lvl = 2 - assert checkpoint_lvl in [0, 1, 2] - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.checkpoint_lvl = checkpoint_lvl - ctx.activation = activation - ctx.heuristic = heuristic - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] - bias1 = bias1.to(dtype=dtype) if bias1 is not None else None - bias2 = bias2.to(dtype=dtype) if bias2 is not None else None - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() if bias1 is not None else None - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() if bias2 is not None else None - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - if heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - # This is before adding bias1 - # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) - # with torch.jit.fuser('fuser2'): - # output1 = bias_gelu(pre_act, bias1) - else: - is_gelu = activation == "gelu_approx" - output1, *rest = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic - ) - if save_pre_act: - pre_act = rest[0] - output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - # For RELU the pre_act is very small (just a bit-mask) so we just save it - ctx.save_for_backward(x, weight1, weight2, pre_act, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, weight2, pre_act) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, weight2, bias1) - output2 = output2.reshape(*batch_shape, output2.shape[-1]) - return output2 if not return_residual else (output2, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - activation = ctx.activation - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - x, 
weight1, weight2, *rest = ctx.saved_tensors - if process_group is None or not sequence_parallel: - total_x = x - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - if checkpoint_lvl in [0, 1]: - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - pre_act, output1 = rest - elif checkpoint_lvl == 1: - (pre_act,) = rest - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - elif checkpoint_lvl == 2: - (bias1,) = rest - if process_group is not None and sequence_parallel: - total_x, _ = all_gather_raw(x, process_group) - if ctx.heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - else: - output1, pre_act = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), - weight1, - bias1, - activation == "gelu_approx", - True, - ctx.heuristic, - ) - - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - output1 = output1.reshape(batch_dim, output1.shape[-1]) - pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) - if ctx.needs_input_grad[3]: - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( - output1, grad_output, ctx.needs_input_grad[4] - ) - else: - grad_weight2 = None - grad_bias2 = grad_output if ctx.needs_input_grad[4] else None - if ctx.heuristic == -1: - # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) - grad_output1 = F.linear(grad_output, weight2.t()) - activation_grad_fn = ( - gelu_bwd - if activation == "gelu_approx" - else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) - ) - with torch.jit.fuser("fuser2"): - grad_pre_act = activation_grad_fn(grad_output1, pre_act) - else: - # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't - # just compute gelu/relu grad - grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( - weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic - ) - if not ctx.needs_input_grad[2]: - grad_bias1 = None - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_pre_act, weight1.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.heuristic == -1: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_pre_act, - ctx.needs_input_grad[2], - ) - else: - grad_weight1 = None - grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None - else: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1 = F.linear( - grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() - ) - else: - grad_weight1 = None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return ( - grad_input, - grad_weight1, - grad_bias1, - grad_weight2, - grad_bias2, - None, - None, - 
None, - None, - None, - None, - None, - ) - - -def fused_mlp_func( - x: Tensor, - weight1: Tensor, - weight2: Tensor, - bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, - activation: str = "gelu_approx", - save_pre_act: bool = True, - return_residual: bool = False, - checkpoint_lvl: int = 0, - heuristic: int = 0, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - assert activation in ["gelu_approx", "relu", "sqrelu"] - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) - dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) - if ( - x.is_cuda - and weight1.is_cuda - and weight2.is_cuda - and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) - and dtype_eligible - and dim_eligible - ): - return FusedMLPFunc.apply( - x, - weight1, - bias1, - weight2, - bias2, - activation, - save_pre_act, - return_residual, - checkpoint_lvl, - heuristic, - process_group, - sequence_parallel, - ) - else: - assert process_group is None - pre_act = F.linear(x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else partial(F.relu, inplace=True) - ) - output1 = activation_fn(pre_act) - output2 = F.linear(output1, weight2, bias2) - return output2 if not return_residual else (output2, x) - - -class FusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - activation="gelu_approx", - return_residual=False, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. - For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation - is slower than the unfused version. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. 
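# Usage sketch (illustrative; requires the fused_dense_lib CUDA extension that is
# imported at the top of this file, and fp16/bf16 CUDA tensors so the fused path
# in fused_mlp_func is taken):
import torch

mlp = FusedMLP(1024, hidden_features=4096, activation="gelu_approx",
               checkpoint_lvl=1, device="cuda", dtype=torch.float16)
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
y = mlp(x)  # (8, 512, 1024)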
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.return_residual = return_residual - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x, process_group=None): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - if torch.cuda.get_device_capability("cuda") == (9, 0): - heuristic = -1 - else: - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=process_group, - ) - if self.return_residual: - out, x = out - if process_group is not None: - out = reduce_scatter(out, process_group) - return out if not self.return_residual else (out, x) - - -class ParallelFusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation="gelu_approx", - process_group: ProcessGroup = None, - bias1=True, - bias2=True, - sequence_parallel=True, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - process_group is required. We're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - assert process_group is not None - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = ColumnParallelLinear( - in_features, hidden_features, process_group, bias=bias1, **factory_kwargs - ) - self.fc2 = RowParallelLinear( - hidden_features, out_features, process_group, bias=bias2, **factory_kwargs - ) - - def forward(self, x): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/layer_norm.py deleted file mode 100644 index 4b6cd798fd02844ef9cd3897f8ab95e490e638bf..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/layer_norm.py +++ /dev/null @@ -1,800 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import dropout_layer_norm -import torch -from torch.nn import init - - -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 - return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
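    For reference, the fused forward pass this corresponds to is roughly (ignoring
    rowscale/colscale):

        x = dropout(x0, p=dropout_p) + residual      # also returned in the pre-norm case
        z = layer_norm(x, gamma, beta, eps=epsilon)  # rms_norm instead when is_rms_norm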
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). 
- """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) - else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) - - -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
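    A minimal usage sketch (assumes CUDA tensors and the dropout_layer_norm extension):

        z0, z1 = dropout_add_layer_norm_parallel_residual(
            x0, x1, residual, weight0, bias0, weight1, bias1,
            dropout_p=0.1, epsilon=1e-5,
        )
        # z1 is None when weight1/bias1 are None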
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x0, residual=None): - return dropout_add_layer_norm( - x0, - residual, - self.weight, - self.bias, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/rms_norm.py deleted file mode 100644 index 068348d61290e3839dd082b540d898578ba1e8e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/rms_norm.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import torch -from torch.nn import init - -from flash_attn.ops.layer_norm import ( - DropoutAddLayerNormFn, - DropoutAddLayerNormParallelResidualFn, - DropoutAddLayerNormSubsetFn, -) - - -def rms_norm(x, weight, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, None, None, None, 0.0, epsilon, False, False, True - ) - - -def dropout_add_rms_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) - - -class DropoutAddRMSNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x0, residual=None): - return dropout_add_rms_norm( - x0, - residual, - self.weight, - None, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py deleted file mode 100644 index 1b5a415b73f236f3e05fb14b9141959559e18526..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Tuple, Optional, Union - -import torch -import torch.nn.functional as F - -import triton -import triton.language as tl - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_fwd_kernel( - loss_ptr, # data ptrs - lse_ptr, - z_loss_ptr, - logits_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, - # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE - SPLIT: tl.constexpr, - PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) -): - row_idx = tl.program_id(0) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - sum_logits = 0.0 # For smoothing - if not PRECOMPUTED_LSE: - # Statistics for online softmax - m_i = -float("inf") - l_i = 0.0 - for col_offset in range(0, n_cols, BLOCK_SIZE): - cols = col_offset + tl.arange(0, BLOCK_SIZE) - logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - if HAS_SMOOTHING: - sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) - m_i_new = tl.maximum(m_i, tl.max(logits)) - l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) - m_i = m_i_new - lse = tl.log(l_i) + m_i - tl.store(lse_ptr + row_idx, lse) - else: - lse = tl.load(lse_ptr + row_idx) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx == ignore_index: - loss = 0.0 - z_loss = 0.0 - else: - label_idx -= class_start_idx - if label_idx >= 0 and label_idx < n_cols: - logits_label = tl.load(logits_ptr + label_idx) * logit_scale - if HAS_SMOOTHING: - loss = ( - (lse if not SPLIT else 0.0) - - smoothing * sum_logits / total_classes - - (1 - smoothing) * logits_label - ) - else: - loss = (lse if not SPLIT else 0.0) - logits_label - else: - # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss - if HAS_SMOOTHING: - loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) - else: - loss = 0.0 - if not SPLIT: - z_loss = lse_square_scale * lse * lse - loss += z_loss - else: - z_loss = 0.0 - tl.store(loss_ptr + row_idx, loss) - if not SPLIT: - tl.store(z_loss_ptr + row_idx, z_loss) - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_bwd_kernel( - dlogits_ptr, # data ptrs - dloss_ptr, - logits_ptr, - lse_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - dlogits_row_stride, - dloss_row_stride, - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, -): - row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx != ignore_index: - dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) - else: - dloss = 0.0 - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - lse = tl.load(lse_ptr + row_idx) - probs = tl.exp(logits - lse) - probs += 2.0 * lse_square_scale * lse * probs - label_idx -= class_start_idx - if HAS_SMOOTHING: - smooth_positive = 1.0 - smoothing - smooth_negative = smoothing / total_classes - probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative - else: - probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) - tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) - - -class CrossEntropyLoss(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - logits, - 
labels, - precomputed_lse=None, - smoothing=0.0, - logit_scale=1.0, - lse_square_scale=0.0, - ignore_index=-100, - inplace_backward=False, - process_group=None, - ): - # For some reason Triton generates wrong code when labels has dtype long and its address - # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. - if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: - labels = F.pad(labels, (0, 1))[..., :-1] - assert labels.data_ptr() % 16 == 0 - assert logit_scale > 0.0 - n_rows, n_cols = logits.shape - assert labels.shape == (n_rows,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - total_classes = world_size * n_cols - rank = 0 if process_group is None else torch.distributed.get_rank(process_group) - class_start_idx = rank * n_cols - use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 - - if logits.stride(-1) != 1: - logits = logits.contiguous() - MAX_BLOCK_SIZE = 16 * 1024 - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) - num_warps = ( - 4 - if BLOCK_SIZE < 2048 - else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) - ) - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if use_precomputed_lse: - assert precomputed_lse.shape == (n_rows,) - lse = precomputed_lse.contiguous() - else: - lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) - z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_fwd_kernel[(n_rows,)]( - losses, # data ptrs - lse, - z_losses, - logits, - labels, - smoothing, - logit_scale, - lse_square_scale, - ignore_index, - total_classes, - class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - BLOCK_SIZE=BLOCK_SIZE, # constants - SPLIT=world_size > 1, - PRECOMPUTED_LSE=use_precomputed_lse, - num_warps=num_warps, - ) - - if world_size > 1: - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # -0.9 * predicted logit - 0.1 * sum logit / total_classes. - # For labels not in the vocab of this partition, losses contains - # -0.1 * sum logit / total_classes. - if world_size > 1: - lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - handle_losses.wait() - # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, - # we just have to add the (global) lse. - # If there's smoothing=0.1, the total losses are - # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. - # Again, we just have to add the (global) lse. 
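            # In short, with per-rank lse_r and each label owned by exactly one rank:
            #   global_lse = logsumexp_r(lse_r)
            #   loss = global_lse - logit[label]                                  (no smoothing)
            #   loss = global_lse - 0.9 * logit[label] - 0.1 * mean(all logits)   (smoothing=0.1)
            # i.e. the usual -log_softmax(logits)[label], reassembled from per-rank partials.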
- losses += lse - if lse_square_scale != 0.0: - z_losses = lse_square_scale * lse.square() - z_losses.masked_fill_(labels == ignore_index, 0.0) - losses += z_losses - else: - z_losses = torch.zeros_like(losses) - losses.masked_fill_(labels == ignore_index, 0.0) - - ctx.save_for_backward(logits, lse, labels) - ctx.mark_non_differentiable(z_losses) - ctx.smoothing = smoothing - ctx.logit_scale = logit_scale - ctx.lse_square_scale = lse_square_scale - ctx.ignore_index = ignore_index - ctx.total_classes = total_classes - ctx.class_start_idx = class_start_idx - ctx.inplace_backward = inplace_backward - return losses, z_losses - - @staticmethod - def backward(ctx, grad_losses, grad_z_losses): - del grad_z_losses # z_losses are only for logging. - - logits, lse, labels = ctx.saved_tensors - dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) - n_rows, n_cols = logits.shape - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) - num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) - grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_bwd_kernel[grid]( - dlogits, # data ptrs - grad_losses, - logits, - lse, - labels, - ctx.smoothing, - ctx.logit_scale, - ctx.lse_square_scale, - ctx.ignore_index, - ctx.total_classes, - ctx.class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - dlogits.stride(0), - grad_losses.stride(0), - BLOCK_SIZE=BLOCK_SIZE, # constants - num_warps=num_warps, - ) - return dlogits, None, None, None, None, None, None, None, None, None - - -def cross_entropy_loss( - logits: torch.Tensor, - labels: torch.Tensor, - precomputed_lse: Optional[torch.Tensor] = None, - label_smoothing: float = 0.0, - logit_scale: float = 1.0, - lse_square_scale: float = 0.0, - ignore_index=-100, - inplace_backward: bool = False, - process_group=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - logits: (batch, vocab_size) - labels: (batch,) - label_smoothing: float - logit_scale: float. Multiply logits by this scale before calculating the loss. - lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. - This is also referred to as "z-loss". - ignore_index: int. If labels == ignore_index, the loss is set to 0.0. - inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. - This saves memory. - process_group: if not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss will be aggregated across processes. - Returns: - losses: (batch,), float - z_losses: (batch,), float - """ - return CrossEntropyLoss.apply( - logits, - labels, - precomputed_lse, - label_smoothing, - logit_scale, - lse_square_scale, - ignore_index, - inplace_backward, - process_group, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py deleted file mode 100644 index efb83c358eb4a85d069ee340a3c83f418f9a805b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py +++ /dev/null @@ -1,162 +0,0 @@ -# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py -# Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from enum import Enum -from typing import Optional - -import triton -import triton.language as tl - -_sqrt2pi = math.sqrt(2.0 / math.pi) -_sqrt1_2 = math.sqrt(1.0 / 2) -_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) - - -class Activation(str, Enum): - SquaredReLU = "squared_relu" - GeLU = "gelu" - GeLUApprox = "gelu_approx" - LeakyReLU = "leaky_relu" - ReLU = "relu" - - -def get_triton_activation_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu, - Activation.LeakyReLU: leaky_relu, - Activation.GeLU: gelu, - Activation.GeLUApprox: gelu_approx, - Activation.SquaredReLU: squared_relu, - }[activation] - if activation - else None - ) - - -def get_triton_activation_bwd_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu_grad, - Activation.LeakyReLU: leaky_relu_grad, - Activation.GeLU: gelu_grad, - Activation.GeLUApprox: gelu_approx_grad, - Activation.SquaredReLU: squared_relu_grad, - }[activation] - if activation - else None - ) - - -@triton.jit -def tanh(x): - # Tanh is just a scaled sigmoid - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def cosh(x): - exp_x = tl.exp(x) - return (exp_x + 1.0 / exp_x) * 0.5 - - -# a Triton implementation of the most used activations -# See for instance http://arxiv.org/abs/1606.08415 for an overview - -# ReLU -@triton.jit -def relu(x): - """ - ReLU_ activation function - - .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html - """ - zero = 0.0 - return tl.where(x >= 0, x, zero.to(x.dtype)) - - -@triton.jit -def relu_grad(x): - # ReLU is different from other activations - # in that it does not require the input to retrospectively compute its gradient - # here the input is the downstream gradient, and we return the upstream gradient directly - zero = 0.0 - one = 1.0 - return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) - - -@triton.jit -def squared_relu(x): - """ - Squared ReLU activation, as proposed in the Primer_ paper. - - .. _Primer: https://arxiv.org/abs/2109.08668 - """ - x_ = relu(x) - return (x_ * x_).to(x.dtype) - - -@triton.jit -def squared_relu_grad(x): - return tl.where(x >= 0, 2.0 * x, 0.0) - - -# Leaky ReLU -@triton.jit -def leaky_relu(x): - """ - LeakyReLU_ activation - - .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html - """ - scale = 0.01 + 0.0 - scale = scale.to(x.dtype) - return tl.where(x >= 0, x, scale * x) - - -@triton.jit -def leaky_relu_grad(x): - min_grad = 0.01 - max_grad = 1 - - min_grad = min_grad.to(x.dtype) - max_grad = max_grad.to(x.dtype) - - return tl.where(x >= 0, max_grad, min_grad) - - -@triton.jit -def gelu(x): - """Gaussian Error Linear Unit (GELU)""" - return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - - -@triton.jit -def gelu_grad(x): - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization - return cdf + x * pdf - - -@triton.jit -def gelu_approx(x): - """ - GeLU_ activation - Gaussian error linear unit, with tanh approximation - - .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf - """ - return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) - - -@triton.jit -def gelu_approx_grad(x): - # CREDITS: Fast implementation proposed in - # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 - tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py deleted file mode 100644 index 192cee474b160d1876fafd14c5e3d695e8ff237f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py +++ /dev/null @@ -1,1252 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Implement dropout + residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math -from typing import Optional, List - -import torch -import torch.nn.functional as F -from torch import Tensor - -import triton -import triton.language as tl - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.utils.library import triton_op - - -def maybe_contiguous_lastdim(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def maybe_contiguous(x): - return x.contiguous() if x is not None else None - - -def triton_autotune_configs(): - # Return configs with a valid warp count for the current device - configs = [] - # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 - max_threads_per_block = 1024 - # Default to warp size 32 if not defined by device - warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) - # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit - return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] - if warp_count * warp_size <= max_threads_per_block] - # return [triton.Config({}, num_warps=8)] - - -def layer_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - 
if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = F.layer_norm( - x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps - ).to(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def rms_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - zero_centered_weight=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if zero_centered_weight: - weight = weight + 1.0 - if weight1 is not None: - weight1 = weight1 + 1.0 - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( - dtype - ) - return (out, out1) if not prenorm else (out, out1, x) - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], -) -# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - DROPOUT_MASK1, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by 
zero - dropout_p, # Dropout probability - zero_centered_weight, # If true, add 1.0 to the weight - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - residual_dtype: Optional[torch.dtype] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - out: Optional[Tensor] = None, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, 
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library - # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None - # so that _layer_norm_fwd_impl doesn't have to return them. - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - if residual is not None: - residual_dtype = residual.dtype - if residual_out is None and ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - residual_out = torch.empty_like( - x, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - else: - residual_out = None - y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( - x, - weight, - bias, - eps, - out, - residual=residual, - x1=x1, - weight1=weight1, - bias1=bias1, - dropout_p=dropout_p, - rowscale=rowscale, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - residual_out=residual_out, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if residual_out is None: - residual_out = x - return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 - - -# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema -# since we're returning a tuple of tensors -@triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, - schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") -def _layer_norm_fwd_impl( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - out: Tensor, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - residual_out: Optional[Tensor] = None -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - assert out.shape == x.shape - assert out.stride(-1) == 1 - if residual_out is not None: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - if weight1 is not None: - y1 = torch.empty_like(out) - assert y1.stride(-1) == 1 - else: - y1 = None - mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if dropout_p > 0.0: - seeds = torch.randint( - 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 - ) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) - if x1 is not None: - dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) - else: - dropout_mask1 = None - else: - dropout_mask, dropout_mask1 = None, None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( - x, - out, - weight, - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - dropout_mask1, - mean, - rstd, - x.stride(0), - out.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - x1.stride(0) if x1 is not None else 0, - y1.stride(0) if y1 is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - HAS_X1=x1 is not None, - HAS_W1=weight1 is not None, - HAS_B1=bias1 is not None, - ) - return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 - - -@triton.autotune( - configs=triton_autotune_configs(), - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], -) -# torch compile doesn't like 
triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - W1, - DY1, - DX1, - DW1, - DB1, - DRESIDUAL_IN, - ROWSCALE, - SEEDS, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dy1_row, - stride_dx1_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, - zero_centered_weight, - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_DY1: tl.constexpr, - HAS_DX1: tl.constexpr, - HAS_B1: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. 
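# Work decomposition, as implemented below and in the Python wrapper: the backward
# launch uses a fixed grid of `sm_count` programs, each covering
# `rows_per_program = ceil(M / sm_count)` consecutive rows. Per-row outputs (DX, and
# optionally DRESIDUAL_IN / DX1) are stored as each row is processed, while the
# weight/bias gradients are accumulated in float32 registers and written once per
# program into DW/DB of shape (num_programs, N); the wrapper then reduces them with
# `_dw.sum(0)` / `_db.sum(0)` on the host.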
- row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - # Do not early exit if row_start >= M, because we need to write DW and DB - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if HAS_DY1: - DY1 += row_start * stride_dy1_row - if HAS_DX1: - DX1 += row_start * stride_dx1_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - if HAS_DY1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_DY1: - dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_B1: - db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if HAS_DY1: - dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_DY1: - wdy += w1 * dy1 - dw1 += dy1 * xhat - if HAS_B1: - db1 += dy1 - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - if HAS_DX1: - if HAS_DROPOUT: - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - else: - dx1 = dx - tl.store(DX1 + cols, dx1, mask=mask) - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - dx *= rowscale - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - if HAS_DY1: - DY1 += stride_dy1_row - if HAS_DX1: - DX1 += stride_dx1_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - if HAS_DY1: - tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) - if HAS_B1: - tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) - - -def _layer_norm_bwd( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, 
- weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, - # which makes torch.library unhappy - dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - dropout_p, - rowscale, - has_residual, - has_x1, - zero_centered_weight, - is_rms_norm, - x_dtype=x_dtype, - recompute_output=recompute_output, - ) - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - - -@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={}, - schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", - allow_decomposition=False, # Don't let torch.compile trace inside - ) -def _layer_norm_bwd_impl( - dy: Tensor, - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - mean: Tensor, - rstd: Tensor, - dresidual: Optional[Tensor] = None, - dy1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - seeds: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - has_residual: bool = False, - has_x1: bool = False, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - x_dtype: Optional[torch.dtype] = None, - recompute_output: bool = False, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - dy = maybe_contiguous_lastdim(dy) - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - dresidual = maybe_contiguous_lastdim(dresidual) - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if dy1 is not None: - dy1 = maybe_contiguous_lastdim(dy1) - assert weight1 is not None - assert dy1.shape == dy.shape - assert dy1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if seeds is not None: - assert seeds.is_contiguous() - assert seeds.shape == (M if not has_x1 else M * 2,) - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = ( - torch.empty_like(x) - if has_residual - and (dx.dtype != x.dtype or dropout_p > 0.0 or 
rowscale is not None or has_x1) - else None - ) - dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None - y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None - if recompute_output: - assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the - # latency of the gmem reads/writes, but will increase the time of summing up dw / db. - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - _dw1 = torch.empty_like(_dw) if weight1 is not None else None - _db1 = torch.empty_like(_db) if bias1 is not None else None - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - weight1, - dy1, - dx1, - _dw1, - _db1, - dresidual_in, - rowscale, - seeds, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dy1.stride(0) if dy1 is not None else 0, - dx1.stride(0) if dx1 is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - dropout_p > 0.0, - HAS_ROWSCALE=rowscale is not None, - HAS_DY1=dy1 is not None, - HAS_DX1=dx1 is not None, - HAS_B1=bias1 is not None, - RECOMPUTE_OUTPUT=y is not None, - ) - dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None - db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx - return dx, dw, db, dresidual_in, dx1, dw1, db1, y - - -class LayerNormFn(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) - weight = weight.contiguous() - bias = maybe_contiguous(bias) - weight1 = maybe_contiguous(weight1) 
- bias1 = maybe_contiguous(bias1) - if rowscale is not None: - rowscale = rowscale.reshape(-1).contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - if out is not None: - out = out.reshape(-1, out.shape[-1]) - if residual_out is not None: - residual_out = residual_out.reshape(-1, residual_out.shape[-1]) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - out_dtype=out_dtype, - residual_dtype=residual_dtype, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - out=out, - residual_out=residual_out, - ) - ctx.save_for_backward( - residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd - ) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.dropout_p = dropout_p - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.has_x1 = x1 is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.zero_centered_weight = zero_centered_weight - y = y.reshape(x_shape_og) - y1 = y1.reshape(x_shape_og) if y1 is not None else None - residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None - dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None - dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None - if not return_dropout_mask: - if weight1 is None: - return y if not prenorm else (y, residual_out) - else: - return (y, y1) if not prenorm else (y, y1, residual_out) - else: - if weight1 is None: - return ( - (y, dropout_mask, dropout_mask1) - if not prenorm - else (y, residual_out, dropout_mask, dropout_mask1) - ) - else: - return ( - (y, y1, dropout_mask, dropout_mask1) - if not prenorm - else (y, y1, residual_out, dropout_mask, dropout_mask1) - ) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if weight1 is not None: - dy1, args = args[0], args[1:] - dy1 = dy1.reshape(-1, dy1.shape[-1]) - assert dy1.shape == x.shape - else: - dy1 = None - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - ctx.dropout_p, - rowscale, - ctx.has_residual, - ctx.has_x1, - ctx.zero_centered_weight, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=False, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, - dw1, - db1, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - 
residual_in_fp32, - zero_centered_weight, - is_rms_norm, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - zero_centered_weight, - True, - return_dropout_mask, - out_dtype, - out, - residual_out - ) - - -class RMSNorm(torch.nn.Module): - - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, - device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - if dropout_p > 0.0: - self.drop = torch.nn.Dropout(dropout_p) - else: - self.drop = None - self.zero_centered_weight = zero_centered_weight - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - if not self.zero_centered_weight: - torch.nn.init.ones_(self.weight) - else: - torch.nn.init.zeros_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - zero_centered_weight=self.zero_centered_weight, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - norm_weight = norm_weight.contiguous() - norm_bias = maybe_contiguous(norm_bias) - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to(dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, 
norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - dy = maybe_contiguous_lastdim(dy) - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1])) - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual=dresidual, - has_residual=ctx.has_residual, - is_rms_norm=ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/linear.py deleted file mode 100644 index a8966dbc345ab0e593df0124451ee7be3dae131a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/linear.py +++ /dev/null @@ -1,594 +0,0 @@ -# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py -# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py -from typing import Optional - -import torch -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -from flash_attn.ops.triton.k_activations import ( - gelu, - gelu_approx, - gelu_approx_grad, - gelu_grad, - squared_relu, - squared_relu_grad, -) - -# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k not used - # for split_k in [2, 4, 8, 16]: - # configs.append(triton.Config( - # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, 
num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_fwd( - C, # Pointers to matrices - ACT_INPUT, - A, - B, - bias, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bn, - stride_bk, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - A_ROWMAJOR: tl.constexpr, - B_COLMAJOR: tl.constexpr, - BIAS: tl.constexpr, - SAVE_ACT_INPUT: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Bias has shape (N,) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - if A_ROWMAJOR: - A = A + (ram[:, None] * stride_am + rk[None, :]) - else: - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - if B_COLMAJOR: - B = B + (rk[:, None] + rbn[None, :] * stride_bn) - else: - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - if A_ROWMAJOR: - A += BLOCK_K - else: - A += BLOCK_K * stride_ak - if B_COLMAJOR: - B += BLOCK_K - else: - B += BLOCK_K * stride_bk - - # Putting bias after the matmul (instead of before) is faster, idk why - if BIAS: - bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) - acc += bias[None, :] - - # optional: save the activation inputs - if SAVE_ACT_INPUT: - # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - tl.store(act_in_ptrs, acc) - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION == "gelu": - acc = gelu(acc) - elif ACTIVATION == "gelu_approx": - acc = gelu_approx(acc) - elif ACTIVATION == "squared_relu": - acc = squared_relu(acc) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc) - - -def triton_linear_act( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - 
activation: str = "id", - save_act_input: bool = False, -) -> torch.Tensor: - """ - Compute e = activation(x @ weight.T + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param x: input tensor - :param weight: weight matrix - :param bias: an optional bias tensor - :param activation: Activation name. Needs to be a Triton kernel. - :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - # if torch.is_autocast_enabled(): - # dtype = torch.get_autocast_gpu_dtype() - # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] - - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - x_reshaped = x.reshape(batch_dim, n) - - if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: - x_reshaped = x_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - bias = bias.contiguous() if bias is not None else None - - assert ( - x.dtype == weight.dtype - ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" - if bias is not None: - assert ( - x.dtype == bias.dtype - ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" - assert ( - x_reshaped.shape[1] == weight.shape[1] - ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" - - assert ( - bias is None or bias.shape[0] == weight.shape[0] - ), "Incompatible dimensions in between weight and bias" - - M, K = x_reshaped.shape - N, K = weight.shape - - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - act_input = torch.empty_like(output) if save_act_input else None - - # 1D launch kernel where each block gets its own program. 
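# The grid below has cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs; BLOCK_M / BLOCK_N
# are chosen by the autotuner, whose cache key is (M // 32, N // 32, K // 32) so that
# nearby problem sizes reuse the same compiled configuration. For example (illustrative
# numbers only, not from this source): M = N = 4096 with a selected
# BLOCK_M = BLOCK_N = 128 launches 32 * 32 = 1024 programs, which kernel_fwd then
# remaps into groups of GROUP_M = 8 along M for better L2 reuse.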
- grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_fwd[grid]( - output, - act_input, - x_reshaped, - weight, # data ptrs - bias if bias is not None else x, # auto skip bias if not present - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=output.stride(0), # strides - # stride_cn=output.stride(1), - stride_am=x_reshaped.stride(0), - stride_ak=x_reshaped.stride(1), - stride_bk=weight.stride(1), - stride_bn=weight.stride(0), - BIAS=bias is not None, # optional fused bias - SAVE_ACT_INPUT=save_act_input, # optional save activation inputs - ACTIVATION=activation, # optional fused activation - A_ROWMAJOR=x_reshaped.stride(1) == 1, - B_COLMAJOR=weight.stride(1) == 1, - GROUP_M=8, # speed optimization: group the programs - ) - - if not save_act_input: - return output.reshape(*batch_shape, output.shape[-1]) - else: - return ( - output.reshape(*batch_shape, output.shape[-1]), - act_input.reshape(*batch_shape, act_input.shape[-1]), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_bwd( - C, # Pointers to matrices - 
ACT_INPUT, - A, - B, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bk, - stride_bn, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION != "id": - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - act_input = tl.load(act_in_ptrs).to(acc.dtype) - if ACTIVATION == "gelu": - acc *= gelu_grad(act_input) - elif ACTIVATION == "gelu_approx": - acc *= gelu_approx_grad(act_input) - elif ACTIVATION == "squared_relu": - acc *= squared_relu_grad(act_input) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) - - -def triton_dgrad_act( - grad_output: torch.Tensor, - weight: torch.Tensor, - activation: str = "id", - act_input: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Compute e = activation(grad_output @ weight + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param grad_output: input tensor - :param weight: weight matrix - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] - batch_dim = batch_shape.numel() - grad_output_reshaped = grad_output.reshape(batch_dim, n) - - if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: - grad_output_reshaped = grad_output_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - - assert ( - grad_output.dtype == weight.dtype - ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" - assert ( - grad_output_reshaped.shape[1] == weight.shape[0] - ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": - assert act_input is not None, f"act_input is required for activation {activation}" - - # M, N, K in bwd are different from M, N, K in fwd - M, K = grad_output_reshaped.shape - K, N = weight.shape - - grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_bwd[grid]( - grad_input, - act_input, - grad_output_reshaped, - weight, # data ptrs - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), - ACTIVATION=activation, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py deleted file mode 100644 index 059f4f8a5e174c1f4824e43d313fca18eaa799b8..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py +++ /dev/null @@ -1,149 +0,0 @@ -# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared -# to naive implementation. 
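# FusedDenseSqreluDenseFunc below fuses fc1 -> squared ReLU -> fc2. The fc1 + squared-ReLU
# part uses triton_linear_act / triton_dgrad_act (activation="squared_relu") for non-bf16
# inputs and falls back to fused_dense_cuda plus sqrelu_fwd / sqrelu_bwd for bf16; fc2
# always goes through fused_dense_cuda.linear_bias_forward. checkpoint_lvl trades memory
# for recomputation in backward (0: save everything, 1: recompute the activation output,
# 2: recompute both the pre-activation and the activation output).
# Minimal usage sketch; the shapes and dtype are illustrative assumptions only:
#   mlp = FusedDenseSqreluDense(1024, hidden_features=4096, device="cuda", dtype=torch.float16)
#   y = mlp(x)  # x: CUDA tensor of shape (batch, seqlen, 1024)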
-import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F - -from flash_attn.utils.torch import custom_fwd, custom_bwd -from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd -from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act - - -class FusedDenseSqreluDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): - """checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute act_input and gelu_out in the bwd - """ - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - x, weight1, bias1, weight2, bias2 = [ - a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] - ] - is_bf16 = x.dtype == torch.bfloat16 - assert checkpoint_lvl in [0, 1, 2] - x = x.contiguous() - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - save_act_input = checkpoint_lvl != 2 - result = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=save_act_input, - ) - if save_act_input: - output1, act_input = result - else: - output1 = result - output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl == 0: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, bias1, weight2) - return output2.reshape(*batch_shape, output2.shape[-1]) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - x, weight1, bias1, weight2, *rest = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - is_bf16 = x.dtype == torch.bfloat16 - if checkpoint_lvl == 0: - act_input, output1 = rest - elif checkpoint_lvl == 1: - (act_input,) = rest - output1 = sqrelu_fwd(act_input) - elif checkpoint_lvl == 2: - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - output1, act_input = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=True, - ) - - if is_bf16: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_output1 = grad_output @ weight2 - grad_act_input = sqrelu_bwd(grad_output1, act_input) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - ) - else: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_act_input = triton_dgrad_act( - grad_output, weight2, activation="squared_relu", act_input=act_input - ) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - 
) - return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None - - -fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply - - -class FusedDenseSqreluDense(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - checkpoint_lvl=0, - device=None, - dtype=None, - ): - """ - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd - """ - assert checkpoint_lvl in [0, 1, 2] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" - assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" - self.checkpoint_lvl = checkpoint_lvl - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - assert x.is_cuda - return fused_dense_sqrelu_dense_function( - x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py deleted file mode 100644 index ff4017fda3e4a6e18cf3a51b34f0fb073d8f678a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) 2025, Tri Dao. -# As of 2025-04-23, we require triton >= 3.0 - -from typing import Optional, Union - -import torch - -import triton -import triton.language as tl - - -@triton.jit -def rotary_kernel( - OUT, # Pointers to matrices - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer - # Matrix dimensions - seqlen, - nheads, - seqlen_ro, - # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - # Meta-parameters - # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that - # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 - ROTARY_DIM: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, -): - BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) - ROTARY_DIM_HALF = ROTARY_DIM // 2 - pid_head = tl.program_id(axis=0) - pid_m = tl.program_id(axis=1) - pid_batch = tl.program_id(axis=2) - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch - OUT = OUT + pid_batch * stride_out_batch - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen - OUT = OUT + start_idx * stride_out_seqlen - - if pid_m * BLOCK_M >= seqlen: - return - - rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - - rk_half = tl.arange(0, BLOCK_K // 2) - COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + 
rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) - mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) - cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) - sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) - if CONJUGATE: - sin = -sin - - if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) - x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - tl.store(OUT, o0, mask=mask) - tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) - else: - rk = tl.arange(0, BLOCK_K) - X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) - OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) - mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) - x = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) - tl.store(OUT, o, mask=mask) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). 
- cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in [torch.int32, torch.int64] - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa - BLOCK_M = 8 if rotary_dim <= 128 else 4 - - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(rotary_kernel)[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - seqlen_ro, - output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - rotary_dim, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M=BLOCK_M, - BLOCK_H=2, - ) - return output diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7e55fc6b442d6812792ba058339097c95a75b3f2 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:201b532a74bf5aeefdd6cfe7db479c2d089392ab53a34c699c56d78e225cd09a +size 445273568 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index e5359f13e8e98604d7f19262fea8a1a744f4bf27..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3b6ba1f39054da689353599c7f1ca8ef381edef2e6d3e027b1fdac48d1218267 -size 445302240 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py 
b/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py index a9819140ce922d5d25722ffeb3c2416285a9d068..579b6b9e94e9fa8ba86c3aec59c22dda67588306 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ca85316d32eab8ae416c1b49384d75afd220114f --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:829746a9a9848d76f837c85613b9af3c367d51023134e99470bd208cddb0ba96 +size 448639320 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index 36160a68829ccb8f3527ae7c394d9c3c7842c82d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0d242b0c0229ab2049a369cb3837f7336779a712dd85c6ac7751e43593272664 -size 448651608 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py index a9819140ce922d5d25722ffeb3c2416285a9d068..579b6b9e94e9fa8ba86c3aec59c22dda67588306 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e6633d5f13eb650f995222965e26372eb84d0599 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b84930b4e7778dcf476fef8dda3ae1ba3ceee449d2d7815fc69360329d5de63c +size 1037635064 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so deleted file mode 100755 index 5d4b60a2e2ecd21e030e5c9eff8cce8734947440..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_876ac68_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:23970c38f8d06013dd98797428f662f4ddef0840e5598d38920bf350bfcf7cd3 -size 1037708816 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py index a9819140ce922d5d25722ffeb3c2416285a9d068..579b6b9e94e9fa8ba86c3aec59c22dda67588306 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _flash_attn_876ac68_dirty -ops = torch.ops._flash_attn_876ac68_dirty +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_flash_attn_876ac68_dirty::{op_name}" \ No newline at end of file + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/__init__.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/__init__.py diff --git a/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bd67cb661befcea136898f63fe39fa981a080cc3 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72ab37410a26028f25d7f5e1f2ab5895b6240c4497aed4dbc625a7fb3e3679cf +size 448643608 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..579b6b9e94e9fa8ba86c3aec59c22dda67588306 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/bert_padding.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/bert_padding.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/bert_padding.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/bert_padding.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/flash_attn_interface.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/flash_attn_interface.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/flash_attn_interface.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/__init__.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/__init__.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/patch_embed.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/patch_embed.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/patch_embed.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/rotary.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/rotary.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/layers/rotary.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/layers/rotary.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/__init__.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/__init__.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/__init__.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/activations.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/activations.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/activations.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/activations.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/fused_dense.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/fused_dense.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/fused_dense.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/layer_norm.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/layer_norm.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/rms_norm.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/rms_norm.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/rms_norm.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/__init__.py 
b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/__init__.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/cross_entropy.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/k_activations.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/layer_norm.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/linear.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/linear.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/linear.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/mlp.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/flash_attn/ops/triton/rotary.py rename to build/torch28-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/__init__.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/__init__.py diff --git a/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..41d96ea13c10bdae5290cdd132b2abaa42e8216c --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbff361d529649c9096acb4af58f94658b270815dfb910b17c8dee77d8ac8887 +size 1037639488 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..579b6b9e94e9fa8ba86c3aec59c22dda67588306 --- /dev/null +++ 
b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/bert_padding.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/bert_padding.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/bert_padding.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/bert_padding.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/flash_attn_interface.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/flash_attn_interface.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/flash_attn_interface.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/__init__.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/__init__.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/patch_embed.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/patch_embed.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/patch_embed.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/rotary.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/rotary.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/layers/rotary.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/layers/rotary.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/__init__.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/__init__.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/__init__.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/activations.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/activations.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/activations.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/activations.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/fused_dense.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/fused_dense.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/fused_dense.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/layer_norm.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/layer_norm.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/rms_norm.py similarity index 100% rename from 
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/rms_norm.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/rms_norm.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/__init__.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/__init__.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/__init__.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/cross_entropy.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/cross_entropy.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/cross_entropy.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/k_activations.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/k_activations.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/k_activations.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/layer_norm.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/layer_norm.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/linear.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/linear.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/linear.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/mlp.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/mlp.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/mlp.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/rotary.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/flash_attn/ops/triton/rotary.py rename to build/torch28-cxx11-cu128-x86_64-linux/flash_attn/ops/triton/rotary.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/__init__.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/__init__.py diff --git a/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fde8c28af72d5e6f916fe23278d36bbe9de5a601 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca81cb44ea1aa87de6db7778b446b959a44e315d587a39a4b8f5495e4602cbd6 +size 1042939024 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_ops.py 
b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..579b6b9e94e9fa8ba86c3aec59c22dda67588306 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_attn_56449c1_dirty +ops = torch.ops._flash_attn_56449c1_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_attn_56449c1_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/bert_padding.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/bert_padding.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/bert_padding.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/bert_padding.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/flash_attn_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/flash_attn_interface.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/flash_attn_interface.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/flash_attn_interface.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/__init__.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/__init__.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/patch_embed.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/patch_embed.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/patch_embed.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/patch_embed.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/rotary.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/rotary.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/layers/rotary.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/layers/rotary.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/__init__.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/__init__.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/__init__.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/activations.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/activations.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/activations.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/activations.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/fused_dense.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/fused_dense.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/fused_dense.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/fused_dense.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/layer_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/layer_norm.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/layer_norm.py diff --git 
a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/rms_norm.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/rms_norm.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/rms_norm.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/__init__.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/__init__.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/__init__.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/cross_entropy.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/cross_entropy.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/cross_entropy.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/k_activations.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/k_activations.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/k_activations.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/layer_norm.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/layer_norm.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/layer_norm.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/linear.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/linear.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/linear.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/linear.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/mlp.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/mlp.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/mlp.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py b/build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/rotary.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/flash_attn/ops/triton/rotary.py rename to build/torch28-cxx11-cu129-x86_64-linux/flash_attn/ops/triton/rotary.py
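Note on the hunks above: each per-variant _ops.py shim now imports the rebuilt _flash_attn_56449c1_dirty extension in place of the old _flash_attn_876ac68_dirty one, and new torch28-cxx11-cu126/cu128/cu129 build directories are populated largely by renaming the unchanged Python sources from the torch26 variants. As a minimal sketch only: the snippet below shows how a caller could derive the matching build/ directory name for the running interpreter, assuming the naming pattern visible in this diff (torch<MAJOR><MINOR>-<ABI>-cu<CUDA>-<arch>-<os>). The helper name guess_build_variant is hypothetical and is not part of this repository; the loader actually used by the package may resolve variants differently.

import platform
import torch

def guess_build_variant() -> str:
    # Hypothetical helper: reconstruct a directory name such as
    # "torch28-cxx11-cu128-x86_64-linux" from the running PyTorch build.
    if torch.version.cuda is None:
        raise RuntimeError("expected a CUDA build of PyTorch")
    # torch.__version__ looks like "2.8.0+cu128"; keep only the "28" part.
    torch_tag = "".join(torch.__version__.split("+")[0].split(".")[:2])
    # torch.version.cuda looks like "12.8"; the directory names drop the dot.
    cuda_tag = torch.version.cuda.replace(".", "")
    # The variants touched in this diff are all built against the cxx11 ABI.
    abi_tag = "cxx11" if torch._C._GLIBCXX_USE_CXX11_ABI else "cxx98"
    return (
        f"torch{torch_tag}-{abi_tag}-cu{cuda_tag}-"
        f"{platform.machine()}-{platform.system().lower()}"
    )

# Example: on PyTorch 2.8 with CUDA 12.8 on Linux x86_64 this prints
# "torch28-cxx11-cu128-x86_64-linux", matching one of the directories added above.
print(guess_build_variant())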