---
license: bsd-3-clause
tags:
- kernel
---
# Flash Attention
Flash Attention is a fast and memory-efficient implementation of the attention mechanism, designed for large models and long sequences. This repository is a Hugging Face kernels-compliant build of Flash Attention.
The original implementation lives at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
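The example below declares its dependencies as inline script metadata (so `uv` can resolve them automatically); if you prefer a plain Python environment, the packages it needs are available on PyPI:
```bash
pip install kernels torch numpy
```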
[`scripts/readme_example.py`](scripts/readme_example.py) provides a simple example of how to use the Flash Attention kernel in PyTorch. It demonstrates standard attention, causal attention, and variable-length sequences.
```python
# /// script
# dependencies = [
# "numpy",
# "torch",
# "kernels"
# ]
# ///
import torch
from kernels import get_kernel
# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")
# Create test tensors
B, S, H, D = 2, 5, 4, 8 # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
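# Note the layout: the flash-attn kernel consumes (batch, seq_len, heads, head_dim),
# while PyTorch SDPA expects (batch, heads, seq_len, head_dim) - hence the transposes below.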
# Reference implementation using PyTorch SDPA
def reference_attention(query, key, value, causal=False):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
    return out.transpose(1, 2).contiguous()
# 1. Standard attention
print("\n1. Standard attention:")
out_ref = reference_attention(q, k, v)
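# fwd returns multiple tensors; the first element is the attention output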
out_flash = flash_attn.fwd(
    q=q,
    k=k,
    v=v,
    is_causal=False,
)[0]
print(f"Reference output: {out_ref.shape}")
print(f"Flash output: {out_flash.shape}")
print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}")
# 2. Causal attention (for autoregressive models)
print("\n2. Causal attention:")
out_ref_causal = reference_attention(q, k, v, causal=True)
out_causal = flash_attn.fwd(
    q=q,
    k=k,
    v=v,
    is_causal=True,
)[0]
print(f"Reference causal output: {out_ref_causal.shape}")
print(f"Flash causal output: {out_causal.shape}")
print(f"Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}")
def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
    batch_size = cu_seqlens_q.shape[0] - 1
    # Return output in packed format (same as flash attention)
    total_tokens_q = q.shape[0]
    out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
    for b in range(batch_size):
        start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
        # Extract slices for this batch
        q_slice = q[start_q:end_q]  # Shape: (seq_len_q, H, D)
        k_slice = k[start_k:end_k]  # Shape: (seq_len_k, H, D)
        v_slice = v[start_k:end_k]  # Shape: (seq_len_k, H, D)
        # Add batch dimension for reference_attention
        q_slice = q_slice.unsqueeze(0)  # Shape: (1, seq_len_q, H, D)
        k_slice = k_slice.unsqueeze(0)  # Shape: (1, seq_len_k, H, D)
        v_slice = v_slice.unsqueeze(0)  # Shape: (1, seq_len_k, H, D)
        # Compute attention and remove batch dimension
        attn_out = reference_attention(q_slice, k_slice, v_slice, causal=causal)
        attn_out = attn_out.squeeze(0)  # Shape: (seq_len_q, H, D)
        # Place result in output tensor (packed format)
        out[start_q:end_q] = attn_out
    return out
# 3. Variable length sequences (packed format)
print("\n3. Variable length sequences:")
# Pack sequences of lengths [3,4,3] for q and [4,5,3] for k into single tensors
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16) # total_q=10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16) # total_k=12
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32) # cumulative sequence lengths
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
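# Each cu_* tensor holds cumulative offsets into the packed token dimension (int32, shape (batch + 1,)):
# sequence b occupies rows cu[b]:cu[b + 1] of the packed tensor.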
out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
# Variable-length forward pass over the packed tensors
out_var = flash_attn.varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,
    max_seqlen_k=5,
)[0]
print(f"Variable length output: {out_var.shape}")
print(f"Reference variable length output: {out_var_ref.shape}")
print(f"Outputs close: {torch.allclose(out_var, out_var_ref, atol=1e-2, rtol=1e-3)}")
```
Run it with the following command:
```bash
uv run scripts/readme_example.py
```
```txt
Reading inline script metadata from `scripts/readme_example.py`
Fetching 20 files: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 16371.21it/s]
1. Standard attention:
Reference output: torch.Size([2, 5, 4, 8])
Flash output: torch.Size([2, 5, 4, 8])
Outputs close: True
2. Causal attention:
Reference causal output: torch.Size([2, 5, 4, 8])
Flash causal output: torch.Size([2, 5, 4, 8])
Outputs close: True
3. Variable length sequences:
Variable length output: torch.Size([10, 4, 8])
Reference variable length output: torch.Size([10, 4, 8])
Outputs close: True
```
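The script above only checks correctness against SDPA. To get a rough sense of relative speed on longer sequences, you could time both paths yourself. The sketch below is not part of the repository's scripts; it reuses the `flash_attn.fwd` call from the example on hypothetical tensor sizes, compares against PyTorch SDPA with its default backend, and measures GPU time with CUDA events. Absolute numbers depend on your GPU.
```python
# /// script
# dependencies = ["torch", "kernels"]
# ///
import torch
from kernels import get_kernel

flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Longer sequence so the kernel has something to chew on (hypothetical sizes).
B, S, H, D = 4, 4096, 16, 64
q = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
k = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)

def time_ms(fn, iters=20):
    # Warm up, then measure average GPU time per call with CUDA events.
    for _ in range(3):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# SDPA wants (B, H, S, D); the flash kernel consumes (B, S, H, D) directly.
qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
sdpa_ms = time_ms(lambda: torch.nn.functional.scaled_dot_product_attention(qt, kt, vt, is_causal=True))
flash_ms = time_ms(lambda: flash_attn.fwd(q=q, k=k, v=v, is_causal=True)[0])
print(f"SDPA: {sdpa_ms:.2f} ms/iter  flash-attn kernel: {flash_ms:.2f} ms/iter")
```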