danieldk (HF Staff) committed
Commit 5d398ae · 1 Parent(s): 66b8aff

Remove sources

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. README.md +6 -5
  2. build.toml +0 -593
  3. flake.lock +0 -168
  4. flake.nix +0 -51
  5. flash-attn/block.h +0 -94
  6. flash-attn/copy_sm90_bulk_reduce.hpp +0 -49
  7. flash-attn/cuda_check.h +0 -19
  8. flash-attn/epilogue_bwd.hpp +0 -523
  9. flash-attn/epilogue_fwd.hpp +0 -484
  10. flash-attn/flash.h +0 -220
  11. flash-attn/flash_api.cpp +0 -1623
  12. flash-attn/flash_bwd_kernel_sm80.h +0 -173
  13. flash-attn/flash_bwd_kernel_sm90.h +0 -282
  14. flash-attn/flash_bwd_launch_template.h +0 -377
  15. flash-attn/flash_bwd_postprocess_kernel.h +0 -256
  16. flash-attn/flash_bwd_preprocess_kernel.h +0 -252
  17. flash-attn/flash_fwd_combine.cu +0 -13
  18. flash-attn/flash_fwd_combine_kernel.h +0 -702
  19. flash-attn/flash_fwd_combine_launch_template.h +0 -88
  20. flash-attn/flash_fwd_kernel_sm80.h +0 -215
  21. flash-attn/flash_fwd_kernel_sm90.h +0 -468
  22. flash-attn/flash_fwd_launch_template.h +0 -231
  23. flash-attn/flash_prepare_scheduler.cu +0 -124
  24. flash-attn/heuristics.h +0 -65
  25. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +0 -18
  26. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +0 -12
  27. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +0 -18
  28. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +0 -12
  29. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +0 -6
  30. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +0 -18
  31. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +0 -12
  32. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +0 -18
  33. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +0 -12
  34. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +0 -6
  35. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +0 -18
  36. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +0 -12
  37. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +0 -18
  38. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +0 -12
  39. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +0 -6
  40. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +0 -18
  41. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +0 -12
  42. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +0 -18
  43. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +0 -12
  44. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +0 -6
  45. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +0 -18
  46. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +0 -12
  47. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +0 -18
  48. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +0 -12
  49. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +0 -6
  50. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +0 -18
README.md CHANGED
@@ -1,13 +1,14 @@
---
license: apache-2.0
tags:
- - kernel
+ - kernel
---
-
+
# vllm-flash-attn3

This is an implementation of Flash Attention 3 CUDA kernels with support for attention sinks. The attention sinks implementation was contributed to Flash Attention by the [vLLM team](https://huggingface.co/vllm-project). The [transformers team](https://huggingface.co/transformers-community) packaged the implementation and pre-built it for use with the [kernels library](https://github.com/huggingface/kernels).

+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/vllm-flash-attn3

## Quickstart

@@ -43,7 +44,7 @@ torch.cuda.manual_seed(42)
# Parameters
batch_size = 2
seqlen_q = 128 # Query sequence length
- seqlen_k = 256 # Key sequence length
+ seqlen_k = 256 # Key sequence length
nheads = 8 # Number of attention heads
d = 64 # Head dimension

@@ -65,7 +66,6 @@ print(f"\nAttention computation successful!")
print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")
```

-
## How to Use

When loading your model with transformers, provide this repository id as the source of the attention implementation:
@@ -91,4 +91,5 @@ This will automatically resolve and download the appropriate code for your architecture

- [Tri Dao](https://huggingface.co/tridao) and team for Flash Attention and [Flash Attention 3](https://tridao.me/blog/2024/flash3/).
- The [vLLM team](https://huggingface.co/vllm-project) for their implementation and their contribution of attention sinks.
- - The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels).
+ - The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels).
+
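Since this commit removes the CUDA sources but keeps the pre-built kernel, the README's two usage paths are worth restating here. The sketch below is ours, not part of the diff: it assumes the Hub repository id `kernels-community/vllm-flash-attn3`, assumes the packaged extension exposes a `flash_attn_func` entry point, and uses a placeholder model id for the transformers path.

```python
import torch
from kernels import get_kernel

# Path 1: call the kernel directly (mirrors the Quickstart parameters above).
vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")  # assumed repo id

batch_size, seqlen_q, seqlen_k, nheads, d = 2, 128, 256, 8, 64
q = torch.randn(batch_size, seqlen_q, nheads, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)

# flash_attn_func is assumed here; depending on the build it may also return the
# log-sum-exp alongside the attention output.
output = vllm_flash_attn3.flash_attn_func(q, k, v, causal=True)

# Path 2: let transformers resolve this repository as its attention implementation,
# as described in the README's "How to Use" section.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "org/model-with-attention-sinks",  # placeholder model id
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="kernels-community/vllm-flash-attn3",
)
```

Either path downloads only the pre-built binaries that remain in this repository; the sources deleted below are not needed at runtime.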
build.toml DELETED
@@ -1,593 +0,0 @@
- [general]
- name = "vllm_flash_attn3"
- universal = false
- cuda-minver = "12.4"
- cuda-maxver = "12.4"
-
- [torch]
- src = [
-   "torch-ext/pytorch_shim.h",
-   "torch-ext/torch_binding.cpp",
-   "torch-ext/torch_binding.h",
- ]
-
- [kernel.flash_attn]
- backend = "cuda"
- cuda-capabilities = ["8.0", "9.0a"]
- cuda-flags = [
-   "-O3",
-   "-std=c++17",
-   "--ftemplate-backtrace-limit=0", # To debug template code
-   "--use_fast_math",
-   "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
-   "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
-   "-DCUTLASS_ENABLE_GDC_FOR_SM90",
-   "--expt-relaxed-constexpr",
-   "--expt-extended-lambda",
-   "--use_fast_math",
-   "-DNDEBUG",
- ]
- cxx-flags = ["-DFLASHATTENTION_DISABLE_PYBIND"]
- src = [
-   "flash-attn/cuda_check.h",
-   "flash-attn/flash_api.cpp",
-   "flash-attn/flash_fwd_combine.cu",
-   "flash-attn/flash_fwd_combine_kernel.h",
-   "flash-attn/flash_fwd_combine_launch_template.h",
-   "flash-attn/flash.h",
-   "flash-attn/flash_prepare_scheduler.cu",
-   "flash-attn/heuristics.h",
-   "flash-attn/seqlen.h",
-   "flash-attn/static_switch.h",
-   "flash-attn/tile_size.h",
-   "flash-attn/utils.h",
- ]
- depends = ["torch", "cutlass_3_9"]
-
47
- [kernel.flash_attn_sm80]
48
- backend = "cuda"
49
- cuda-capabilities = ["8.0", "9.0a"]
50
- cuda-flags = [
51
- "-O3",
52
- "-std=c++17",
53
- "--ftemplate-backtrace-limit=0", # To debug template code
54
- "--use_fast_math",
55
- "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
56
- "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
57
- "-DCUTLASS_ENABLE_GDC_FOR_SM90",
58
- "--expt-relaxed-constexpr",
59
- "--expt-extended-lambda",
60
- "--use_fast_math",
61
- "-DNDEBUG",
62
- ]
63
- src = [
64
- "flash-attn/block.h",
65
- "flash-attn/copy_sm90_bulk_reduce.hpp",
66
- "flash-attn/epilogue_bwd.hpp",
67
- "flash-attn/epilogue_fwd.hpp",
68
- "flash-attn/flash.h",
69
- "flash-attn/flash_bwd_kernel_sm80.h",
70
- "flash-attn/flash_bwd_kernel_sm90.h",
71
- "flash-attn/flash_bwd_launch_template.h",
72
- "flash-attn/flash_bwd_postprocess_kernel.h",
73
- "flash-attn/flash_bwd_preprocess_kernel.h",
74
- "flash-attn/flash_fwd_launch_template.h",
75
- "flash-attn/flash_fwd_kernel_sm80.h",
76
- "flash-attn/flash_fwd_kernel_sm90.h",
77
- "flash-attn/heuristics.h",
78
- "flash-attn/mainloop_bwd_sm80.hpp",
79
- "flash-attn/mainloop_fwd_sm80.hpp",
80
- "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
81
- "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
82
- "flash-attn/mask.h",
83
- "flash-attn/named_barrier.hpp",
84
- "flash-attn/pack_gqa.h",
85
- "flash-attn/paged_kv.h",
86
- "flash-attn/rotary.h",
87
- "flash-attn/sm90_pipeline_no_cluster.hpp",
88
- "flash-attn/softmax.h",
89
- "flash-attn/tile_size.h",
90
- "flash-attn/tile_scheduler.hpp",
91
-
92
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
93
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
94
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
95
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
96
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
97
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
98
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
99
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
100
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
101
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
102
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
103
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
104
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
105
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
106
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
107
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
108
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
109
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
110
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
111
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
112
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
113
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
114
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
115
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
116
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
117
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
118
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
119
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
120
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
121
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
122
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
123
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
124
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
125
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
126
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
127
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
128
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
129
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
130
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
131
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
132
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
133
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
134
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
135
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
136
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
137
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
138
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
139
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
140
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
141
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
142
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
143
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
144
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
145
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
146
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
147
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
148
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
149
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
150
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
151
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
152
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
153
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
154
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
155
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
156
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
157
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
158
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
159
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
160
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
161
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
162
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
163
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
164
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
165
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
166
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
167
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
168
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
169
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
170
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
171
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
172
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
173
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
174
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
175
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
176
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
177
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
178
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
179
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
180
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
181
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
182
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
183
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
184
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
185
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
186
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
187
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
188
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
189
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
190
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
191
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
192
- ]
193
- include = ["flash-attn"]
194
- depends = ["torch", "cutlass_3_9"]
195
-
196
- [kernel.flash_attn_sm90]
197
- backend = "cuda"
198
- cuda-capabilities = ["8.0", "9.0a"]
199
- cuda-flags = [
200
- "-O3",
201
- "-std=c++17",
202
- "--ftemplate-backtrace-limit=0", # To debug template code
203
- "--use_fast_math",
204
- "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
205
- "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
206
- "-DCUTLASS_ENABLE_GDC_FOR_SM90",
207
- "--expt-relaxed-constexpr",
208
- "--expt-extended-lambda",
209
- "--use_fast_math",
210
- "-DNDEBUG",
211
- ]
212
- src = [
213
- "flash-attn/block.h",
214
- "flash-attn/copy_sm90_bulk_reduce.hpp",
215
- "flash-attn/epilogue_bwd.hpp",
216
- "flash-attn/epilogue_fwd.hpp",
217
- "flash-attn/flash.h",
218
- "flash-attn/flash_bwd_kernel_sm80.h",
219
- "flash-attn/flash_bwd_kernel_sm90.h",
220
- "flash-attn/flash_bwd_launch_template.h",
221
- "flash-attn/flash_bwd_postprocess_kernel.h",
222
- "flash-attn/flash_bwd_preprocess_kernel.h",
223
- "flash-attn/flash_fwd_launch_template.h",
224
- "flash-attn/flash_fwd_kernel_sm80.h",
225
- "flash-attn/flash_fwd_kernel_sm90.h",
226
- "flash-attn/heuristics.h",
227
- "flash-attn/mainloop_bwd_sm80.hpp",
228
- "flash-attn/mainloop_fwd_sm80.hpp",
229
- "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
230
- "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
231
- "flash-attn/mask.h",
232
- "flash-attn/named_barrier.hpp",
233
- "flash-attn/pack_gqa.h",
234
- "flash-attn/paged_kv.h",
235
- "flash-attn/rotary.h",
236
- "flash-attn/sm90_pipeline_no_cluster.hpp",
237
- "flash-attn/softmax.h",
238
- "flash-attn/tile_size.h",
239
- "flash-attn/tile_scheduler.hpp",
240
-
241
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
242
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
243
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
244
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
245
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
246
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
247
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
248
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
249
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
250
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
251
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
252
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
253
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
254
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
255
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
256
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
257
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
258
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
259
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
260
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
261
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
262
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
263
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
264
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
265
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
266
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
267
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
268
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
269
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
270
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
271
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
272
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
273
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
274
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
275
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
276
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
277
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
278
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
279
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
280
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
281
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
282
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
283
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
284
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
285
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
286
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
287
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
288
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
289
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
290
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
291
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
292
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
293
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
294
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
295
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
296
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
297
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
298
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
299
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
300
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
301
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
302
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
303
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
304
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
305
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
306
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
307
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
308
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
309
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
310
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
311
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
312
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
313
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
314
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
315
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
316
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
317
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
318
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
319
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
320
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
321
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
322
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
323
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
324
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
325
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
326
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
327
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
328
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
329
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
330
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
331
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
332
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
333
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
334
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
335
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
336
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
337
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
338
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
339
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
340
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
341
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
342
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
343
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
344
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
345
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
346
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
347
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
348
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
349
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
350
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
351
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
352
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
353
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
354
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
355
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
356
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
357
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
358
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
359
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
360
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
361
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
362
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
363
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
364
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
365
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
366
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
367
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
368
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
369
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
370
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
371
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
372
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
373
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
374
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
375
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
376
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
377
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
378
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
379
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
380
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
381
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
382
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
383
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
384
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
385
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
386
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
387
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
388
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
389
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
390
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
391
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
392
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
393
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
394
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
395
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
396
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
397
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
398
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
399
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
400
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
401
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
402
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
403
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
404
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
405
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
406
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
407
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
408
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
409
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
410
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
411
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
412
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
413
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
414
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
415
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
416
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
417
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
418
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
419
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
420
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
421
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
422
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
423
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
424
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
425
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
426
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
427
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
428
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
429
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
430
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
431
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
432
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
433
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
434
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
435
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
436
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
437
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
438
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
439
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
440
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
441
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
442
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
443
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
444
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
445
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
446
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
447
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
448
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
449
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
450
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
451
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
452
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
453
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
454
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
455
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
456
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
457
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
458
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
459
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
460
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
461
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
462
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
463
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
464
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
465
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
466
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
467
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
468
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
469
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
470
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
471
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
472
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
473
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
474
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
475
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
476
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
477
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
478
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
479
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
480
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
481
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
482
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
483
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
484
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
485
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
486
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
487
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
488
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
489
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
490
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
491
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
492
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
493
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
494
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
495
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
496
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
497
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
498
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
499
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
500
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
501
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
502
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
503
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
504
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
505
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
506
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
507
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
508
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
509
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
510
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
511
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
512
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
513
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
514
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
515
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
516
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
517
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
518
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
519
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
520
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
521
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
522
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
523
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
524
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
525
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
526
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
527
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
528
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
529
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
530
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
531
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
532
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
533
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
534
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
535
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
536
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
537
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
538
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
539
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
540
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
541
- ]
542
- include = ["flash-attn"]
543
- depends = ["torch", "cutlass_3_9"]
544
-
545
- # [kernel.flash_attn_sm100]
546
- # backend = "cuda"
547
- # cuda-capabilities = ["8.0", "9.0a", "10.0"]
548
- # cuda-flags = [
549
- # "-O3",
550
- # "-std=c++17",
551
- # "--ftemplate-backtrace-limit=0", # To debug template code
552
- # "--use_fast_math",
553
- # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
554
- # "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
555
- # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
556
- # "--expt-relaxed-constexpr",
557
- # "--expt-extended-lambda",
558
- # "--use_fast_math",
559
- # "-DNDEBUG",
560
- # ]
561
- # src = [
562
- # "flash-attn/block.h",
563
- # "flash-attn/copy_sm90_bulk_reduce.hpp",
564
- # "flash-attn/epilogue_bwd.hpp",
565
- # "flash-attn/epilogue_fwd.hpp",
566
- # "flash-attn/flash.h",
567
- # "flash-attn/flash_bwd_kernel_sm80.h",
568
- # "flash-attn/flash_bwd_kernel_sm90.h",
569
- # "flash-attn/flash_bwd_launch_template.h",
570
- # "flash-attn/flash_bwd_postprocess_kernel.h",
571
- # "flash-attn/flash_bwd_preprocess_kernel.h",
572
- # "flash-attn/flash_fwd_launch_template.h",
573
- # "flash-attn/flash_fwd_kernel_sm80.h",
574
- # "flash-attn/flash_fwd_kernel_sm90.h",
575
- # "flash-attn/heuristics.h",
576
- # "flash-attn/mainloop_bwd_sm80.hpp",
577
- # "flash-attn/mainloop_fwd_sm80.hpp",
578
- # "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
579
- # "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
580
- # "flash-attn/mask.h",
581
- # "flash-attn/named_barrier.hpp",
582
- # "flash-attn/pack_gqa.h",
583
- # "flash-attn/paged_kv.h",
584
- # "flash-attn/rotary.h",
585
- # "flash-attn/sm90_pipeline_no_cluster.hpp",
586
- # "flash-attn/softmax.h",
587
- # "flash-attn/tile_size.h",
588
- # "flash-attn/tile_scheduler.hpp",
589
- #
590
- # "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
591
- # ]
592
- # include = ["flash-attn"]
593
- # depends = ["torch", "cutlass_3_9"]
 
flake.lock DELETED
@@ -1,168 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1733328505,
21
- "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1750234878,
77
- "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1751014803,
102
- "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "repo": "kernel-builder",
111
- "type": "github"
112
- }
113
- },
114
- "nixpkgs": {
115
- "locked": {
116
- "lastModified": 1747820358,
117
- "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
- "owner": "danieldk",
119
- "repo": "nixpkgs",
120
- "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
- "type": "github"
122
- },
123
- "original": {
124
- "owner": "danieldk",
125
- "ref": "cudatoolkit-12.9-kernel-builder",
126
- "repo": "nixpkgs",
127
- "type": "github"
128
- }
129
- },
130
- "root": {
131
- "inputs": {
132
- "kernel-builder": "kernel-builder"
133
- }
134
- },
135
- "systems": {
136
- "locked": {
137
- "lastModified": 1681028828,
138
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
- "owner": "nix-systems",
140
- "repo": "default",
141
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
- "type": "github"
143
- },
144
- "original": {
145
- "owner": "nix-systems",
146
- "repo": "default",
147
- "type": "github"
148
- }
149
- },
150
- "systems_2": {
151
- "locked": {
152
- "lastModified": 1681028828,
153
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
- "owner": "nix-systems",
155
- "repo": "default",
156
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
- "type": "github"
158
- },
159
- "original": {
160
- "owner": "nix-systems",
161
- "repo": "default",
162
- "type": "github"
163
- }
164
- }
165
- },
166
- "root": "root",
167
- "version": 7
168
- }
 
flake.nix DELETED
@@ -1,51 +0,0 @@
- {
-   description = "Flake for Hopper Flash Attention kernel";
-
-   inputs = {
-     kernel-builder.url = "github:huggingface/kernel-builder";
-   };
-
-   outputs =
-     {
-       self,
-       kernel-builder,
-     }:
-     kernel-builder.lib.genFlakeOutputs {
-       path = ./.;
-       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-       # Building with CUDA later than 12.4 fails with:
-       #
-       #   error: 'ptxas' died due to signal 11 (Invalid memory reference)
-       #
-       # So, build for 12.4 only and copy to all the other build variants
-       # by hand (which works fine thanks to backward compat).
-       #
-       # Still need to check if upstream FA3 has the same issue.
-       torchVersions = [
-         {
-           torchVersion = "2.6";
-           cudaVersion = "12.4";
-           cxx11Abi = false;
-           systems = [ "x86_64-linux" ];
-           upstreamVariant = true;
-         }
-         {
-           torchVersion = "2.6";
-           cudaVersion = "12.4";
-           cxx11Abi = true;
-           systems = [ "x86_64-linux" ];
-           upstreamVariant = true;
-         }
-         {
-           torchVersion = "2.7";
-           cudaVersion = "12.4";
-           cxx11Abi = true;
-           systems = [
-             "x86_64-linux"
-             "aarch64-linux"
-           ];
-           upstreamVariant = true;
-         }
-       ];
-     };
- }
 
flash-attn/block.h DELETED
@@ -1,94 +0,0 @@
- /******************************************************************************
-  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-  ******************************************************************************/
-
- #pragma once
-
- namespace flash {
-
- template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
- struct BlockMN {
-
-     static
-     CUTLASS_DEVICE
-     cute::tuple<int, int> get_n_block_min_max(
-         SeqlenInfo_t const& seqlen_info,
-         int const m_block, int const bidb, int const split_idx, int const num_splits,
-         int const window_size_left, int const window_size_right,
-         cutlass::FastDivmod const& qhead_per_khead_divmod) {
-
-         int const seqlen_k = seqlen_info.seqlen_k;
-         int const seqlen_q = seqlen_info.seqlen_q;
-         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
-         if constexpr (Is_causal || Is_local) {
-             int m_idx_max = (m_block + 1) * kBlockM;
-             // TODO: check off-by-1 error
-             if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
-             n_block_max = std::min(n_block_max,
-                                    cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
-         }
-         int n_block_min = 0;
-         if constexpr (Is_local) {
-             int m_idx_min = m_block * kBlockM;
-             if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
-             n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
-         }
-         // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-         if constexpr (Split) {
-             uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
-             int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
-             int split_idx_actual = split_idx & 0x0000FFFF;
-             int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
-             int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
-             n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
-             n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
-             // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
-         }
-         // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-         return {n_block_min, n_block_max};
-     }
-
-     static
-     CUTLASS_DEVICE
-     cute::tuple<int, int> get_n_block_k_new_min_max(
-         SeqlenInfo_t const& seqlen_info,
-         int const m_block, int const bidb, int const split_idx, int const num_splits,
-         int const window_size_left, int const window_size_right,
-         cutlass::FastDivmod const& qhead_per_khead_divmod) {
-
-         auto [n_block_min, n_block_max] = get_n_block_min_max(
-             seqlen_info, m_block, bidb, split_idx, num_splits,
-             window_size_left, window_size_right, qhead_per_khead_divmod);
-         int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
-         int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
-         int const n_block_new_min = idx_k_new_min / kBlockN;
-         int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
-         // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
-         return {n_block_new_min, n_block_new_max};
-     }
-
-     static
-     CUTLASS_DEVICE
-     cute::tuple<int, int> get_m_block_min_max(
-         SeqlenInfo_t const& seqlen_info,
-         int const n_block, int const bidb,
-         int const window_size_left, int const window_size_right, int const sink_token_length) {
-
-         int const seqlen_q = seqlen_info.seqlen_q;
-         int const seqlen_k = seqlen_info.seqlen_k;
-         int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
-         if constexpr (Is_local) {
-             if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
-                 m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
-             }
-         }
-         int m_block_min = 0;
-         if constexpr (Is_causal || Is_local) {
-             m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
-         }
-         return {m_block_min, m_block_max};
-     }
-
- };
-
- } // namespace flash
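To make the `Split` branch in `get_n_block_min_max` above easier to follow: when split-KV decoding is enabled, each split processes a contiguous range of K/V blocks, and `split_idx` packs a dynamic split count into its upper 16 bits with the actual split index in the lower 16 bits. The Python sketch below is our host-side illustration of that arithmetic, assuming a non-negative `split_idx`; the helper name `split_kv_range` is ours, not part of the kernel.

```python
# Host-side mirror of the Split branch in BlockMN::get_n_block_min_max (block.h above).
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def split_kv_range(n_block_min: int, n_block_max: int, split_idx: int, num_splits: int):
    # Upper 16 bits of split_idx: dynamic split count (0 means "use num_splits").
    # Lower 16 bits: the index of this split.
    num_splits_dynamic = (split_idx >> 16) & 0xFFFF
    split_idx_actual = split_idx & 0xFFFF
    num_splits_actual = num_splits_dynamic if num_splits_dynamic > 0 else num_splits
    per_split = 0 if n_block_max <= n_block_min else ceil_div(n_block_max - n_block_min, num_splits_actual)
    lo = n_block_min + split_idx_actual * per_split
    hi = min(lo + per_split, n_block_max)
    return lo, hi

# Example: 10 K/V blocks split 3 ways -> block ranges (0, 4), (4, 8), (8, 10).
print([split_kv_range(0, 10, (3 << 16) | s, num_splits=1) for s in range(3)])
```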
 
flash-attn/copy_sm90_bulk_reduce.hpp DELETED
@@ -1,49 +0,0 @@
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
- #pragma once
-
- #include<cute/arch/copy_sm90_tma.hpp>
-
- namespace cute
- {
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- struct SM90_BULK_REDUCE_ADD
- {
- CUTE_HOST_DEVICE static void
- copy(float const* smem_ptr,
- float * gmem_ptr, int32_t store_bytes)
- {
- #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
- uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
- asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
- :
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
- : "memory");
- #else
- CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
- #endif
- }
-
- CUTE_HOST_DEVICE static void
- copy(float const* smem_ptr,
- float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
- {
- #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
- uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
- asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
- :
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
- : "memory");
- #else
- CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
- #endif
- }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- } // end namespace cute
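
`SM90_BULK_REDUCE_ADD` wraps the Hopper `cp.reduce.async.bulk` instruction, which adds a contiguous block of fp32 values from shared memory into global memory. The deleted `epilogue_bwd.hpp` later in this diff issues it from a single thread and then waits on the bulk-async group with `tma_store_arrive()` / `tma_store_wait<0>()`. The kernel below is only a hedged illustration of that pattern, not code from the removed sources: the kernel name and fill values are made up, and it assumes an sm90a build with 16-byte-aligned buffers and a size that is a multiple of 16 bytes.

```cuda
#include <cute/arch/copy_sm90_tma.hpp>
#include "copy_sm90_bulk_reduce.hpp"   // the header removed in this commit

// Each CTA stages partial fp32 results in shared memory, then one thread
// reduce-adds the whole buffer into a global accumulator.
__global__ void bulk_reduce_add_example(float* gmem_accum, int num_floats) {
    extern __shared__ float smem_buf[];
    for (int i = threadIdx.x; i < num_floats; i += blockDim.x) {
        smem_buf[i] = 1.0f;   // placeholder for real partial results
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        // store_bytes must describe the full contiguous staging buffer.
        cute::SM90_BULK_REDUCE_ADD::copy(smem_buf, gmem_accum,
                                         num_floats * sizeof(float));
        cute::tma_store_arrive();   // commit the bulk-async group ...
        cute::tma_store_wait<0>();  // ... and wait for it to complete
    }
}
```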
flash-attn/cuda_check.h DELETED
@@ -1,19 +0,0 @@
- /******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
- #pragma once
-
- #include <assert.h>
- #include <stdlib.h>
-
- #define CHECK_CUDA(call) \
- do { \
- cudaError_t status_ = call; \
- if (status_ != cudaSuccess) { \
- fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
- exit(1); \
- } \
- } while(0)
-
- #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
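
A quick usage sketch for these two macros, not taken from the removed sources: `my_kernel` and the buffer size are placeholders, and `<cstdio>` is included because the macro body calls `fprintf`.

```cuda
#include <cstdio>
#include <cuda_runtime.h>
#include "cuda_check.h"   // the header removed in this commit

__global__ void my_kernel(float* out) { out[threadIdx.x] = threadIdx.x; }

int main() {
    float* d_out = nullptr;
    CHECK_CUDA(cudaMalloc(&d_out, 256 * sizeof(float)));
    my_kernel<<<1, 256>>>(d_out);
    CHECK_CUDA_KERNEL_LAUNCH();           // catches launch-time errors
    CHECK_CUDA(cudaDeviceSynchronize());  // surfaces asynchronous kernel errors
    CHECK_CUDA(cudaFree(d_out));
    return 0;
}
```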
flash-attn/epilogue_bwd.hpp DELETED
@@ -1,523 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cutlass/cutlass.h"
8
- #include "cutlass/barrier.h"
9
- #include "cute/tensor.hpp"
10
-
11
- #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
-
13
- #include "seqlen.h"
14
- #include "named_barrier.hpp"
15
- #include "utils.h"
16
-
17
- namespace flash {
18
-
19
- using namespace cute;
20
-
21
- template <class TileShape_MNK_, class Element_, class ArchTag_,
22
- int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
23
- struct CollectiveEpilogueBwd {
24
-
25
- using TileShape_MNK = TileShape_MNK_;
26
- using Element = Element_;
27
- using ArchTag = ArchTag_;
28
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
29
- static constexpr bool Varlen = Varlen_;
30
- static constexpr bool dKV_swapAB = dKV_swapAB_;
31
- static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
32
-
33
- static_assert(ArchTag::kMinComputeCapability >= 80);
34
-
35
- using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
36
-
37
- // These are for storing the output tensor without TMA (e.g., for setting output to zero)
38
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
- static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
- static constexpr int kHeadDim = get<2>(TileShape_MNK{});
41
- static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
42
- static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
43
- using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
44
- Stride<Int<kGmemThreadsPerRow>, _1>>;
45
- using GmemTiledCopydKV = decltype(
46
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
47
- GmemLayoutAtom{},
48
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
49
-
50
- using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
51
- // TODO: do we have to change this if dKV_swapAB is true?
52
- decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
53
- using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
54
- using SmemLayoutdKVtTMA =
55
- decltype(cute::composition(SmemLayoutdKVTMA{},
56
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
57
- make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
58
-
59
- // If we don't use TMA
60
- static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
61
- static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
62
- using SmemLayoutAtomdKVSTG =
63
- decltype(composition(Swizzle<kSwizzle, 3, 3>{},
64
- Layout<Shape<Int<8>, Int<kBlockKSmem>>,
65
- Stride<Int<kBlockKSmem>, _1>>{}));
66
-
67
- using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
68
- using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
69
- using SmemLayoutdKVt =
70
- decltype(cute::composition(SmemLayoutdKV{},
71
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
72
- make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
73
-
74
- using SmemCopyAtomdKV = Copy_Atom<
75
- std::conditional_t<
76
- ArchTag::kMinComputeCapability >= 90,
77
- std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
78
- AutoVectorizingCopyWithAssumedAlignment<128>
79
- >,
80
- Element>;
81
-
82
- static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
83
- static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
84
-
85
- struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
86
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
87
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
88
- };
89
-
90
- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
91
- using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
92
-
93
- using TMA_dKV = std::conditional_t<
94
- Use_TMA,
95
- decltype(make_tma_copy(
96
- GmemTiledCopydKVTMA{},
97
- make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
98
- SmemLayoutdKVTMA{},
99
- select<1, 2>(TileShape_MNK{}),
100
- _1{})), // no mcast for dKV
101
- std::nullptr_t
102
- >;
103
-
104
- // Host side kernel arguments
105
- struct Arguments {
106
- Element* ptr_dK;
107
- ShapedKV const shape_dK;
108
- StridedKV const stride_dK;
109
- Element* ptr_dV;
110
- StridedKV const stride_dV;
111
- int const num_heads_q;
112
- int* dk_semaphore;
113
- int* dv_semaphore;
114
- int const* cu_seqlens;
115
- int const* seqused;
116
- };
117
-
118
- // Device side kernel params
119
- struct Params {
120
- Element* ptr_dK;
121
- ShapedKV const shape_dK;
122
- StridedKV const stride_dK;
123
- Element* ptr_dV;
124
- StridedKV const stride_dV;
125
- TMA_dKV tma_store_dK, tma_store_dV;
126
- int const* cu_seqlens = nullptr;
127
- int const* seqused = nullptr;
128
- };
129
-
130
- static Params
131
- to_underlying_arguments(Arguments const& args) {
132
- Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
133
- Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
134
- TMA_dKV tma_store_dK = [&] {
135
- if constexpr (Use_TMA) {
136
- return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
137
- } else {
138
- return nullptr;
139
- }
140
- }();
141
- TMA_dKV tma_store_dV = [&] {
142
- if constexpr (Use_TMA) {
143
- return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
144
- } else {
145
- return nullptr;
146
- }
147
- }();
148
- return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
149
- tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
150
- }
151
-
152
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
153
- CUTLASS_DEVICE
154
- static void prefetch_tma_descriptors(Params const& params) {
155
- if constexpr (Use_TMA) {
156
- cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
157
- cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
158
- }
159
- }
160
-
161
- template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
162
- CUTLASS_DEVICE void
163
- store(Params const& params,
164
- FrgTensorO const& tdKrdK,
165
- FrgTensorO const& tdVrdV,
166
- SharedStorage& shared_storage,
167
- TiledMma tiled_mma,
168
- int thread_idx,
169
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
170
- ) {
171
-
172
- auto [n_block, bidh, bidb] = block_coord;
173
- Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
174
- Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
175
- Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
176
- Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
177
- auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
178
- auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
179
-
180
- Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
181
- flash::convert_type_out(tdVrdV, tdVrdV_out);
182
- Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
183
- flash::convert_type_out(tdKrdK, tdKrdK_out);
184
- Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
185
- Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
186
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
187
- Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
188
- Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
189
-
190
- // Make sure all WGs have finished reading K and V
191
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
192
- cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
193
- cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
194
- if constexpr (Use_TMA) {
195
- cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
196
- cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
197
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
198
-
199
- Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
200
- Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
201
- Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
202
- Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
203
- auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
204
- auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
205
- Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
206
- Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
207
- Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
208
- Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
209
- int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
210
- if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
211
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
212
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
213
- if (cute::elect_one_sync()) {
214
- cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
215
- cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
216
- tma_store_arrive();
217
- }
218
- }
219
- tma_store_wait<0>();
220
- // // Tell warp 0 that smem_k and smem_v are ready
221
- // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
222
-
223
- } else {
224
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
225
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
226
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
227
- bool const is_varlen = Varlen && params.cu_seqlens;
228
- Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
229
- Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
230
- Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
231
- Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
232
-
233
- GmemTiledCopydKV gmem_tiled_copy_dKV;
234
- auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
235
- Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
236
- Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
237
- Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
238
- Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
239
- Tensor tdKVrdV = make_fragment_like(tdKVgdV);
240
- Tensor tdKVrdK = make_fragment_like(tdKVgdK);
241
- Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
242
- // Repeat the partitioning with identity layouts
243
- Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
244
- Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
245
- #pragma unroll
246
- for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
247
- // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
248
- static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
249
- flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
250
- gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN);
251
- flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
252
- gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN);
253
- // // Tell warp 0 that smem_k and smem_v are ready
254
- // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
255
- // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
256
- // Construct identity layout for gdKV
257
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
258
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
259
- gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
260
- );
261
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
262
- gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
263
- );
264
- }
265
- }
266
-
267
- CUTLASS_DEVICE void
268
- store_tail() {
269
- // if constexpr (Use_TMA) { tma_store_wait<0>(); }
270
- }
271
-
272
- // Write 0 to dK and dV
273
- CUTLASS_DEVICE void
274
- store_zero(
275
- Params const& params,
276
- int thread_idx,
277
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
278
- ) {
279
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
280
- auto [n_block, bidh, bidb] = block_coord;
281
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
282
- bool const is_varlen = Varlen && params.cu_seqlens;
283
- Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
284
- Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
285
- Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
286
- Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
287
-
288
- GmemTiledCopydKV gmem_tiled_copy_dKV;
289
- auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
290
- Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
291
- Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
292
- Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
293
- clear(tdKVrdKV);
294
- // Construct identity layout for gdKV
295
- Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
296
- // Repeat the partitioning with identity layouts
297
- Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
298
- Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
299
- #pragma unroll
300
- for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
301
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
302
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
303
- gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
304
- );
305
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
306
- gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
307
- );
308
- }
309
-
310
- };
311
-
312
- template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
313
- int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
314
- struct CollectiveEpilogueBwdGQA {
315
-
316
- using TileShape_MNK = TileShape_MNK_;
317
- using Element = ElementAccum;
318
- using ArchTag = ArchTag_;
319
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
320
- static constexpr bool Varlen = Varlen_;
321
- static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
322
-
323
- static_assert(ArchTag::kMinComputeCapability >= 80);
324
-
325
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
326
- static constexpr int kHeadDim = get<2>(TileShape_MNK{});
327
- static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
328
- static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
329
- // Thread layout, 256 or 384 threads per row
330
- // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
331
- using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
332
- using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
333
- Layout<Shape < _4>>{})); // Val layout, 4 vals per store
334
- // For Sm80
335
- using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
336
- using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
337
- Layout<Shape < _1>>{})); // Val layout, 1 vals per store
338
-
339
- using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
340
- using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
341
-
342
- // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
343
- // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
344
- static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
345
- struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
346
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
347
- };
348
- struct TensorStorageSTG {
349
- cute::array<ElementAccum, 0> smem_dkv;
350
- };
351
- using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
352
-
353
- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
354
- using StridedKV = cute::Stride<_1, int64_t, int64_t>;
355
-
356
- // Host side kernel arguments
357
- struct Arguments {
358
- ElementAccum* ptr_dKaccum;
359
- ShapedKV const shape_dKaccum;
360
- StridedKV const stride_dKaccum;
361
- ElementAccum* ptr_dVaccum;
362
- StridedKV const stride_dVaccum;
363
- int num_heads_q;
364
- int* dk_semaphore;
365
- int* dv_semaphore;
366
- int const* cu_seqlens;
367
- int const* seqused;
368
- };
369
-
370
- // Device side kernel params
371
- struct Params {
372
- ElementAccum* ptr_dKaccum;
373
- ShapedKV const shape_dKaccum;
374
- StridedKV const stride_dKaccum;
375
- ElementAccum* ptr_dVaccum;
376
- StridedKV const stride_dVaccum;
377
- cutlass::FastDivmod qhead_per_khead_divmod;
378
- int* dk_semaphore;
379
- int* dv_semaphore;
380
- int const* cu_seqlens = nullptr;
381
- int const* seqused = nullptr;
382
- };
383
-
384
- static Params
385
- to_underlying_arguments(Arguments const& args) {
386
- if constexpr (Deterministic) {
387
- assert(args.dk_semaphore != nullptr);
388
- assert(args.dv_semaphore != nullptr);
389
- }
390
- return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum,
391
- cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
392
- args.dk_semaphore, args.dv_semaphore,
393
- args.cu_seqlens, args.seqused};
394
- }
395
-
396
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
397
- CUTLASS_DEVICE
398
- static void prefetch_tma_descriptors(Params const& params) {
399
- }
400
-
401
- template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
402
- CUTLASS_DEVICE void
403
- store(Params const& params,
404
- FrgTensorO const& tdKrdK,
405
- FrgTensorO const& tdVrdV,
406
- SharedStorage& shared_storage,
407
- TiledMma tiled_mma,
408
- int thread_idx,
409
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
410
- ) {
411
-
412
- auto [n_block, bidh, bidb] = block_coord;
413
- int bidh_idx_in_group;
414
- int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
415
- Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
416
- Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
417
- static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
418
-
419
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
420
- bool const is_varlen = Varlen && params.cu_seqlens;
421
- Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
422
- Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
423
- Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
424
- Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
425
-
426
- R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
427
- auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
428
- Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
429
-
430
- // Only used if !Use_TMA
431
- R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
432
- auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
433
-
434
- // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
435
- // because smem_q could be changed.
436
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
437
- if constexpr (Use_TMA) {
438
- Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
439
- cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
440
- }
441
-
442
- // int const num_batch = params.num_batch;
443
- int const num_batch = get<2>(params.shape_dKaccum);
444
- int const num_head_kv = get<1>(params.shape_dKaccum);
445
- int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
446
- using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
447
-
448
- // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
449
-
450
- if constexpr (Deterministic) {
451
- Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
452
- }
453
- // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
454
- if constexpr (Use_TMA) {
455
- cutlass::arch::fence_view_async_shared();
456
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
457
- if (thread_idx == 0) {
458
- SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
459
- tma_store_arrive();
460
- tma_store_wait<0>();
461
- }
462
- } else {
463
- Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
464
- Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
465
- static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
466
- #pragma unroll
467
- for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
468
- }
469
- if constexpr (Deterministic) {
470
- Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
471
- }
472
-
473
- if constexpr (Use_TMA) {
474
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
475
- Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
476
- cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
477
- }
478
- lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
479
- // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
480
-
481
- if constexpr (Deterministic) {
482
- Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
483
- }
484
- // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
485
- if constexpr (Use_TMA) {
486
- cutlass::arch::fence_view_async_shared();
487
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
488
- if (thread_idx == 0) {
489
- SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
490
- tma_store_arrive();
491
- tma_store_wait<0>();
492
- }
493
- } else {
494
- Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
495
- Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
496
- static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
497
- #pragma unroll
498
- for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
499
- }
500
- if constexpr (Deterministic) {
501
- Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
502
- }
503
- // // Tell warp 0 that smem_k and smem_v are ready
504
- // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
505
- }
506
-
507
- CUTLASS_DEVICE void
508
- store_tail() {
509
- }
510
-
511
- // Write 0 to dK and dV
512
- CUTLASS_DEVICE void
513
- store_zero(
514
- Params const& params,
515
- int thread_idx,
516
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
517
- ) {
518
- // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
519
- }
520
-
521
- };
522
-
523
- } // namespace flash
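
When `Deterministic` is set, the GQA epilogue above serializes dK/dV accumulation across the query heads that share one K/V head: each CTA waits on a per-(batch, kv-head) semaphore via `cutlass::GenericBarrier::wait_eq`, accumulates, then calls `arrive_inc` to hand the turn to the next head, so the floating-point summation order is fixed across runs. The device function below is only a simplified sketch of that turn-taking idea with a plain integer semaphore (assumed zero-initialized in global memory); the removed code relies on `GenericBarrier` for the actual synchronization and memory-ordering details.

```cuda
// Sketch: CTAs for query heads 0, 1, 2, ... of the same K/V head accumulate in order.
__device__ void ordered_accumulate(float* gmem_accum, float const* partial,
                                   int n, int* semaphore, int my_turn) {
    // Acquire: wait until the previous query head has released the turn.
    if (threadIdx.x == 0) {
        while (atomicAdd(semaphore, 0) != my_turn) { /* spin */ }
    }
    __syncthreads();
    // Add this CTA's partial dK/dV values.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        gmem_accum[i] += partial[i];
    }
    // Release: make the writes visible, then pass the turn to the next head.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) { atomicAdd(semaphore, 1); }
}
```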
flash-attn/epilogue_fwd.hpp DELETED
@@ -1,484 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cutlass/cutlass.h>
8
- #include <cutlass/fast_math.h> // For FastDivMod
9
- #include "cute/tensor.hpp"
10
-
11
- #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
- #include "cutlass/epilogue/collective/builders/sm90_common.inl"
13
-
14
- #include "seqlen.h"
15
- #include "named_barrier.hpp"
16
- #include "pack_gqa.h"
17
- #include "utils.h"
18
-
19
- namespace flash {
20
-
21
- using namespace cute;
22
-
23
- template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
24
- int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
25
- struct CollectiveEpilogueFwd {
26
-
27
- using TileShape_MNK_PV = TileShape_MNK_PV_;
28
- using ClusterShape = ClusterShape_;
29
- using Element = Element_;
30
- using ElementPartial = float;
31
- using ArchTag = ArchTag_;
32
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
33
- static constexpr bool Varlen = Varlen_;
34
- static constexpr bool PackGQA = PackGQA_;
35
- static constexpr bool Split = Split_;
36
- static constexpr bool Use_smem = !(Split && !Varlen);
37
- static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
38
-
39
- static_assert(ArchTag::kMinComputeCapability >= 80);
40
- static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
41
- static_assert(sizeof(Element) <= 2);
42
-
43
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
44
- static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
45
-
46
- static constexpr bool LargeHeadDimV = kHeadDimV > 256;
47
-
48
- using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
49
-
50
- // These are for storing the output tensor without TMA (e.g., for setting output to zero)
51
- static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
52
- static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
53
- // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
54
- // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
55
- // we need to call divmod.
56
- static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
57
- static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
58
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
59
- // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
60
- static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
61
- static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
62
- using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
63
- Stride<Int<kGmemThreadsPerRow>, _1>>;
64
- static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
65
- using GmemTiledCopyO = decltype(
66
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
67
- GmemLayoutAtom{},
68
- Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
69
-
70
- using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
71
- decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
72
- using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
73
- static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
74
- static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
75
- using SmemLayoutAtomO = decltype(
76
- composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
77
- Layout<Shape<_8, Int<kBlockKGmem>>,
78
- Stride<Int<kBlockKGmem>, _1>>{}));
79
- using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
80
- using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
81
-
82
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
83
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
84
- using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
85
- // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
86
- using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
87
- using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
88
- // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
89
- using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
90
- using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
91
-
92
- using CopyOpR2S = std::conditional_t<
93
- ArchTag::kMinComputeCapability >= 90,
94
- // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
95
- decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
96
- AutoVectorizingCopyWithAssumedAlignment<128>
97
- >;
98
- using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
99
-
100
- // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
101
- // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
102
- // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
103
- // cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
104
- // };
105
- struct TensorStorage : cute::aligned_struct<128> {
106
- cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
107
- };
108
-
109
- using TMA_O = std::conditional_t<
110
- Use_TMA_O,
111
- decltype(make_tma_copy(
112
- GmemTiledCopyOTMA{},
113
- make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
114
- SmemLayoutOTMA{},
115
- select<0, 1>(TileShape_MNK_PV{}),
116
- _1{})), // no mcast for O
117
- std::nullptr_t
118
- >;
119
-
120
- // Host side kernel arguments
121
- struct Arguments {
122
- Element* ptr_O;
123
- ShapeO const shape_O;
124
- StrideO const stride_O;
125
- ElementPartial* ptr_O_partial;
126
- StrideO const stride_O_partial;
127
- float* ptr_LSE;
128
- StrideLSE const stride_LSE;
129
- float* ptr_LSE_partial;
130
- StrideLSE const stride_LSE_partial;
131
- int32_t const nheads_kv;
132
- int const* cu_seqlens = nullptr;
133
- int const* seqused = nullptr;
134
- };
135
-
136
- // Device side kernel params
137
- struct Params {
138
- Element* ptr_O;
139
- ShapeO const shape_O;
140
- StrideO const stride_O;
141
- ShapeOPacked const shape_O_packed;
142
- StrideOPacked const stride_O_packed;
143
- ElementPartial* ptr_O_partial;
144
- StrideO const stride_O_partial;
145
- StrideOPacked const stride_O_partial_packed;
146
- float* ptr_LSE;
147
- StrideLSE const stride_LSE;
148
- ShapeLSEPacked const shape_LSE_packed;
149
- StrideLSEPacked const stride_LSE_packed;
150
- float* ptr_LSE_partial;
151
- StrideLSE const stride_LSE_partial;
152
- StrideLSEPacked const stride_LSE_partial_packed;
153
- cutlass::FastDivmod qhead_per_khead_divmod;
154
- TMA_O tma_store_O;
155
- int const* cu_seqlens = nullptr;
156
- int const* seqused = nullptr;
157
- };
158
-
159
- static Params
160
- to_underlying_arguments(Arguments const& args) {
161
- Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
162
- TMA_O tma_store_O = [&]{
163
- if constexpr (Use_TMA_O) {
164
- return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
165
- } else {
166
- return nullptr;
167
- }
168
- }();
169
- // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
170
- int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
171
- auto const shape_O_packed = cute::conditional_return<!PackGQA>(
172
- args.shape_O,
173
- make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
174
- );
175
- auto const stride_O_packed = cute::conditional_return<!PackGQA>(
176
- args.stride_O,
177
- make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
178
- );
179
- auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
180
- args.stride_O_partial,
181
- make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
182
- );
183
- // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
184
- auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
185
- select<0, 2, 3, 4>(args.shape_O),
186
- make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
187
- );
188
- auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
189
- args.stride_LSE,
190
- make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
191
- );
192
- auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
193
- args.stride_LSE_partial,
194
- make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
195
- );
196
- return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
197
- args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
198
- args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
199
- args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
200
- cutlass::FastDivmod(qhead_per_khead),
201
- tma_store_O, args.cu_seqlens, args.seqused};
202
- }
203
-
204
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
205
- CUTLASS_DEVICE
206
- static void prefetch_tma_descriptors(Params const& params) {
207
- if constexpr (Use_TMA_O) {
208
- cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
209
- }
210
- }
211
-
212
- template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
213
- CUTLASS_DEVICE void
214
- store(Params const& params,
215
- FrgTensorO& tOrO,
216
- FrgTensorLSE const& lse,
217
- SharedStorage& shared_storage,
218
- TiledMma tiled_mma,
219
- int thread_idx,
220
- cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
221
- ) {
222
-
223
- auto [m_block, bidh, bidb, split_idx] = block_coord;
224
- int num_splits = get<4>(params.shape_O_packed);
225
- if constexpr (Split && Varlen) {
226
- uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
227
- int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
228
- num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
229
- split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
230
- }
231
- bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
232
-
233
- Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
234
- // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
235
-
236
- static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
237
- // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
238
- // Otherwise we can permute after conversion.
239
- if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
240
- Tensor tOrO_out = make_tensor_like<Element>(tOrO);
241
- flash::convert_type_out(tOrO, tOrO_out);
242
- if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
243
-
244
- // Make sure all WGs have finished reading V
245
- // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
246
- // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
247
- // cp.async if we need).
248
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
249
-
250
- // Step 1: Write O from rmem -> smem
251
- if constexpr (Use_smem) {
252
- auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
253
- auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
254
- Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
255
- Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
256
- // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
257
- cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
258
- if constexpr (Use_TMA_O) {
259
- cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
260
- cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
261
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
262
- } else {
263
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
264
- }
265
- } else {
266
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
267
- #pragma unroll
268
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
269
- shared_storage.pipelines.barrier_O.arrive(cta_id);
270
- }
271
- }
272
- }
273
-
274
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
275
- bool is_varlen = Varlen && params.cu_seqlens;
276
- int offset_o = seqlen_info.offset;
277
- int seqlen_o = seqlen_info.seqlen;
278
- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
279
-
280
- // Step 2: Write LSE from rmem -> gmem
281
- auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
282
- // (MMA,MMA_M,MMA_K)
283
- Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
284
- static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
285
- static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
286
- Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
287
- Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
288
- CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
289
-
290
- using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
291
- using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
292
-
293
- Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
294
- params.shape_LSE_packed,
295
- !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
296
- // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
297
- if (!LargeHeadDimV || warp_group_idx == 0) {
298
- if constexpr (!PackGQA) {
299
- #pragma unroll
300
- for (int mi = 0; mi < size(lse); ++mi) {
301
- int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
302
- if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
303
- }
304
- } else {
305
- PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
306
- }
307
- }
308
-
309
- // Step 3: Write O from smem -> gmem
310
- if constexpr (Use_TMA_O) {
311
- Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
312
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
313
- auto block_tma_O = params.tma_store_O.get_slice(_0{});
314
- Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
315
- Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
316
- int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
317
- if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
318
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
319
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
320
- if (cute::elect_one_sync()) {
321
- cute::copy(params.tma_store_O, tOsO, tOgO);
322
- tma_store_arrive();
323
- tma_store_wait<0>();
324
- #pragma unroll
325
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
326
- shared_storage.pipelines.barrier_O.arrive(cta_id);
327
- }
328
- }
329
- }
330
- } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
331
- if (!is_split) {
332
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
333
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
334
- // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
335
- GmemTiledCopyO gmem_tiled_copy_O;
336
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
337
- Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
338
- // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
339
- Tensor tOrO = make_fragment_like(tOsO);
340
- cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
341
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
342
- cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
343
- #pragma unroll
344
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
345
- shared_storage.pipelines.barrier_O.arrive(cta_id);
346
- }
347
- }
348
- if constexpr (!PackGQA) {
349
- // (BLK_M,BLK_K) -> (blk_m,blk_k)
350
- Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
351
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
352
- #pragma unroll
353
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
354
- Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
355
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
356
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
357
- gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
358
- );
359
- } else {
360
- // If PackGQA, we split the work of compute O_ptr among threads in the same row
361
- PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
362
- }
363
- } else {
364
- Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
365
- Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
366
- // We already arrived on barrier_O earlier if !Use_smem
367
- if constexpr (Use_smem) {
368
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
369
- #pragma unroll
370
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
371
- shared_storage.pipelines.barrier_O.arrive(cta_id);
372
- }
373
- }
374
- }
375
- if constexpr (!PackGQA) {
376
- static constexpr int kGmemElemsPerStoreDirect = 2;
377
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
378
- // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
379
- Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
380
- Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
381
- Tensor tOgO = thread_mma.partition_C(gOpartial);
382
- Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
383
- Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
384
- Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
385
- #pragma unroll
386
- for (int m = 0; m < size(taccOcO_row); ++m) {
387
- if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
388
- #pragma unroll
389
- for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
390
- if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
391
- cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
392
- }
393
- }
394
- }
395
- }
396
- } else {
397
- PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
398
- }
399
- }
400
- }
401
- }
402
-
403
- CUTLASS_DEVICE void
404
- store_tail() {
405
- // Don't need to do tma_store_wait<0>() here since we already did in @store
406
- }
407
-
408
- // Write 0 to output and -inf to LSE
409
- CUTLASS_DEVICE void
410
- store_zero(
411
- Params const& params,
412
- int thread_idx,
413
- cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
414
- ) {
415
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
416
- auto [m_block, bidh, bidb, split_idx] = block_coord;
417
- int num_splits = get<4>(params.shape_O_packed);
418
- if constexpr (Split && Varlen) {
419
- uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
420
- int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
421
- num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
422
- split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
423
- }
424
- bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
425
-
426
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
427
- bool const is_varlen = Varlen && params.cu_seqlens;
428
- int offset_o = seqlen_info.offset;
429
- int seqlen_o = seqlen_info.seqlen;
430
- int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
431
- Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
432
- params.shape_LSE_packed,
433
- !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
434
- Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
435
-
436
- static_assert(kBlockM <= NumEpilogueThreads);
437
- if (thread_idx < kBlockM) {
438
- const int row = m_block * kBlockM + thread_idx;
439
- if constexpr (!PackGQA) {
440
- if (row < seqlen_o) { mLSE(row) = -INFINITY; }
441
- } else {
442
- if (row < seqlen_o * qhead_per_khead) {
443
- int m_idx, h_idx;
444
- m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
445
- // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
446
- mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
447
- }
448
- }
449
- }
450
-
451
- // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
452
- // since it will not use the value of O if LSE is -inf.
453
- if (!is_split) {
454
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
455
-
456
- GmemTiledCopyO gmem_tiled_copy_O;
457
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
458
- Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
459
- if constexpr (!PackGQA) {
460
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
461
- #pragma unroll
462
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
463
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
464
- Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
465
- Tensor tOrO = make_fragment_like(tOgO);
466
- cute::clear(tOrO);
467
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
468
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
469
- gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
470
- );
471
- } else {
472
- // If PackGQA, we split the work of compute O_ptr among threads in the same row
473
- using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
474
- Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
475
- cute::clear(tOrO);
476
- PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
477
- }
478
- }
479
-
480
- }
481
-
482
- };
483
-
484
- } // namespace flash
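Note on the `store_zero` epilogue above: rows with no valid keys get `-INFINITY` written to their LSE, so the split/combine pass can discard the corresponding partial output (as the comment says, the mha_combine kernel ignores O wherever LSE is -inf). As a reference for that convention, here is a minimal host-side C++ sketch of the log-sum-exp combine — illustrative only, not the actual combine kernel; the function and variable names are hypothetical:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Sketch of the split-KV combine for one output row: each split s produced a
// partial output o_partial[s] (already normalized within the split) and its
// log-sum-exp lse[s]. A split whose epilogue took the store_zero path carries
// lse = -inf, so its weight exp(lse - lse_max) is exactly 0 and the value of
// its partial output never matters. Assumes at least one split has a finite LSE.
std::vector<float> combine_splits(const std::vector<std::vector<float>>& o_partial,
                                  const std::vector<float>& lse) {
    const size_t d = o_partial.front().size();
    float lse_max = -std::numeric_limits<float>::infinity();
    for (float l : lse) lse_max = std::max(lse_max, l);

    std::vector<float> out(d, 0.0f);
    float denom = 0.0f;
    for (size_t s = 0; s < lse.size(); ++s) {
        const float w = std::exp(lse[s] - lse_max);  // 0 when lse[s] == -inf
        denom += w;
        for (size_t i = 0; i < d; ++i) { out[i] += w * o_partial[s][i]; }
    }
    for (size_t i = 0; i < d; ++i) { out[i] /= denom; }
    return out;  // the combined LSE would be lse_max + std::log(denom)
}
```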
flash-attn/flash.h DELETED
@@ -1,220 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cuda.h>
8
- #include <vector>
9
-
10
- ////////////////////////////////////////////////////////////////////////////////////////////////////
11
-
12
- struct Qkv_params {
13
- using index_t = int64_t;
14
- // The QKV matrices.
15
- void *__restrict__ q_ptr;
16
- void *__restrict__ k_ptr;
17
- void *__restrict__ v_ptr;
18
-
19
- // The stride between rows of the Q, K and V matrices.
20
- index_t q_batch_stride;
21
- index_t k_batch_stride;
22
- index_t v_batch_stride;
23
- index_t q_row_stride;
24
- index_t k_row_stride;
25
- index_t v_row_stride;
26
- index_t q_head_stride;
27
- index_t k_head_stride;
28
- index_t v_head_stride;
29
- index_t v_dim_stride;
30
-
31
- // The number of heads.
32
- int h, h_k;
33
- };
34
-
35
- ////////////////////////////////////////////////////////////////////////////////////////////////////
36
-
37
- struct Flash_fwd_params : public Qkv_params {
38
- using index_t = int64_t;
39
-
40
- // The O matrix (output).
41
- void * __restrict__ o_ptr;
42
- void * __restrict__ oaccum_ptr;
43
-
44
- // The stride between rows of O.
45
- index_t o_batch_stride;
46
- index_t o_row_stride;
47
- index_t o_head_stride;
48
-
49
- // The pointer to the softmax sum.
50
- void * __restrict__ softmax_lse_ptr;
51
- void * __restrict__ softmax_lseaccum_ptr;
52
-
53
- // For FP8 scaling
54
- float * __restrict__ q_descale_ptr;
55
- float * __restrict__ k_descale_ptr;
56
- float * __restrict__ v_descale_ptr;
57
- index_t q_descale_batch_stride;
58
- index_t q_descale_head_stride;
59
- index_t k_descale_batch_stride;
60
- index_t k_descale_head_stride;
61
- index_t v_descale_batch_stride;
62
- index_t v_descale_head_stride;
63
-
64
- // The dimensions.
65
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
66
- int total_q, total_k, total_knew;
67
- int b_k; // When there is a KV cache and cache_batch_idx is used, K & V might have a larger batch size than Q
68
- int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
69
-
70
- // The scaling factors for the kernel.
71
- float scale_softmax;
72
- float softcap;
73
-
74
- // array of length b+1 holding starting offset of each sequence.
75
- int * __restrict__ cu_seqlens_q;
76
- int * __restrict__ cu_seqlens_k;
77
- int * __restrict__ cu_seqlens_knew;
78
- int * __restrict__ leftpad_k;
79
-
80
- // If provided, the actual length of each q/k sequence.
81
- int *__restrict__ seqused_q;
82
- int *__restrict__ seqused_k;
83
-
84
- // The stride between rows of Oaccum.
85
- index_t oaccum_split_stride;
86
- index_t oaccum_batch_stride;
87
- index_t oaccum_row_stride;
88
- index_t oaccum_head_stride;
89
-
90
- // The stride between rows of LSEaccum.
91
- index_t lseaccum_split_stride;
92
- index_t lseaccum_batch_stride;
93
- index_t lseaccum_head_stride;
94
-
95
- // The K_new and V_new matrices.
96
- void * __restrict__ knew_ptr;
97
- void * __restrict__ vnew_ptr;
98
-
99
- // The stride between rows of the Q, K and V matrices.
100
- index_t knew_batch_stride;
101
- index_t vnew_batch_stride;
102
- index_t knew_row_stride;
103
- index_t vnew_row_stride;
104
- index_t knew_head_stride;
105
- index_t vnew_head_stride;
106
-
107
- void *__restrict__ qv_ptr;
108
- index_t qv_batch_stride;
109
- index_t qv_row_stride;
110
- index_t qv_head_stride;
111
-
112
- // The cos and sin matrices for rotary embedding.
113
- void * __restrict__ rotary_cos_ptr;
114
- void * __restrict__ rotary_sin_ptr;
115
- int *__restrict__ seqlens_rotary;
116
-
117
- // The indices to index into the KV cache.
118
- int * __restrict__ kv_batch_idx;
119
-
120
- // Paged KV cache
121
- int * __restrict__ page_table;
122
- index_t page_table_batch_stride;
123
- int page_size;
124
- int num_pages;
125
- bool pagedkv_tma;
126
-
127
- // The dropout probability (probability of keeping an activation).
128
- float p_dropout;
129
- // uint32_t p_dropout_in_uint;
130
- // uint16_t p_dropout_in_uint16_t;
131
- uint8_t p_dropout_in_uint8_t;
132
-
133
- // Scale factor of 1 / (1 - p_dropout).
134
- float rp_dropout;
135
-
136
- // Local window size
137
- int window_size_left, window_size_right;
138
-
139
- // Pointer to the RNG seed (idx 0) and offset (idx 1).
140
- uint64_t * rng_state;
141
-
142
- bool is_bf16;
143
- bool is_fp32;
144
- bool is_e4m3;
145
- bool is_causal;
146
- bool is_local;
147
-
148
- bool is_rotary_interleaved;
149
-
150
- int num_splits; // For split-KV version
151
- bool pack_gqa;
152
-
153
- int * __restrict__ tile_count_semaphore;
154
- // int * __restrict__ num_m_blocks_ptr;
155
- // int * __restrict__ num_n_blocks_ptr;
156
- int * __restrict__ num_splits_dynamic_ptr;
157
- bool skip_scheduler_metadata_computation;
158
-
159
- int arch;
160
- int num_sm;
161
-
162
- // The S extra matrix, (num_heads)
163
- void *__restrict__ s_aux_ptr;
164
- };
165
-
166
- ////////////////////////////////////////////////////////////////////////////////////////////////////
167
-
168
- struct Flash_bwd_params : public Flash_fwd_params {
169
- using index_t = int64_t;
170
-
171
- // The dO and dQKV matrices.
172
- void *__restrict__ do_ptr;
173
- void *__restrict__ dq_ptr;
174
- void *__restrict__ dk_ptr;
175
- void *__restrict__ dv_ptr;
176
-
177
- // To accumulate dQ
178
- void *__restrict__ dq_accum_ptr;
179
- void *__restrict__ dk_accum_ptr;
180
- void *__restrict__ dv_accum_ptr;
181
-
182
- // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
183
- // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
184
- // dv_accum_ptr;
185
-
186
- // The stride between rows of the dO, dQ, dK and dV matrices.
187
- index_t do_batch_stride;
188
- index_t do_row_stride;
189
- index_t do_head_stride;
190
- index_t dq_batch_stride;
191
- index_t dk_batch_stride;
192
- index_t dv_batch_stride;
193
- index_t dq_row_stride;
194
- index_t dk_row_stride;
195
- index_t dv_row_stride;
196
- index_t dq_head_stride;
197
- index_t dk_head_stride;
198
- index_t dv_head_stride;
199
-
200
- // The pointer to the softmax d sum.
201
- void *__restrict__ dsoftmax_sum;
202
- void *__restrict__ softmax_lse_log2_ptr;
203
-
204
- int *__restrict__ dq_semaphore;
205
- int *__restrict__ dk_semaphore;
206
- int *__restrict__ dv_semaphore;
207
-
208
- bool deterministic;
209
- index_t dq_accum_split_stride;
210
- };
211
-
212
- ////////////////////////////////////////////////////////////////////////////////////////////////////
213
-
214
- template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
215
- void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
216
- void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
217
- template <int Arch, typename T, int kHeadDim, bool Has_softcap>
218
- void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
219
- template <typename T, typename Tpartial, int kBlockK>
220
- void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
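All stride fields in `Qkv_params` / `Flash_fwd_params` above are measured in elements, not bytes, and `set_params_fprop` in the `flash_api.cpp` listing below fills them directly from `torch::Tensor::stride()`. As a minimal sketch, here is what those values work out to for a contiguous row-major `(batch, seqlen, heads, headdim)` Q tensor — the helper struct and function are hypothetical, for illustration only:

```cpp
#include <cstdint>

// Illustrative only: element strides of a contiguous (batch, seqlen, heads, headdim)
// tensor, i.e. what q.stride(0), q.stride(-3) and q.stride(-2) return for a
// freshly allocated, non-varlen Q.
struct QStrides {
    int64_t batch_stride;  // q_batch_stride
    int64_t row_stride;    // q_row_stride  (step between sequence positions)
    int64_t head_stride;   // q_head_stride (step between heads within one position)
};

QStrides contiguous_q_strides(int64_t seqlen, int64_t num_heads, int64_t headdim) {
    return QStrides{
        /*batch_stride=*/seqlen * num_heads * headdim,
        /*row_stride=*/num_heads * headdim,
        /*head_stride=*/headdim,
    };
}
```

For varlen inputs the batch stride is left unset: `set_params_fprop` below only assigns `q_batch_stride` / `o_batch_stride` when `cu_seqlens_q_d == nullptr`, since `cu_seqlens_q` then provides the per-sequence offsets.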
flash-attn/flash_api.cpp DELETED
@@ -1,1623 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
6
- #include <torch/nn/functional.h>
7
- #include <torch/version.h> // For TORCH_VERSION* macros
8
- #include <ATen/cuda/CUDAContext.h>
9
- #include <c10/cuda/CUDAGuard.h>
10
-
11
- #include <cutlass/numeric_types.h>
12
-
13
- #include "flash.h"
14
- #include "static_switch.h"
15
- #include "tile_size.h"
16
- #include "heuristics.h"
17
- #include "cuda_check.h"
18
-
19
- // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
20
- // This is so that we can pass in torch.dtype as a parameter to the function.
21
- #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
22
-
23
- #include <pybind11/pybind11.h>
24
- #include <pybind11/stl.h>
25
-
26
- namespace pybind11::detail {
27
-
28
- template <>
29
- struct type_caster<at::ScalarType> {
30
- public:
31
- // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
32
- PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
33
- // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
34
- // cannot be default-initialized, we provide this constructor to explicitly
35
- // initialize that field. The value doesn't matter as it will be overwritten
36
- // after a successful call to load.
37
- type_caster() : value(at::kFloat) {}
38
- bool load(handle src, bool) {
39
- PyObject* obj = src.ptr();
40
- if (THPDtype_Check(obj)) {
41
- value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
42
- return true;
43
- }
44
- return false;
45
- }
46
- static handle cast(
47
- const at::ScalarType& src,
48
- return_value_policy /* policy */,
49
- handle /* parent */) {
50
- return Py_NewRef(torch::getTHPDtype(src));
51
- }
52
- };
53
-
54
- } // namespace pybind11::detail
55
-
56
- #endif
57
-
58
- #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
59
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
60
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
61
-
62
- void set_params_fprop(Flash_fwd_params &params,
63
- // sizes
64
- const size_t b,
65
- const size_t seqlen_q,
66
- const size_t seqlen_k,
67
- const size_t seqlen_q_rounded,
68
- const size_t seqlen_k_rounded,
69
- const size_t h,
70
- const size_t h_k,
71
- const size_t d,
72
- const size_t d_rounded,
73
- // device pointers
74
- const at::Tensor q,
75
- const at::Tensor k,
76
- const at::Tensor v,
77
- at::Tensor out,
78
- void *cu_seqlens_q_d,
79
- void *cu_seqlens_k_d,
80
- void *seqused_q,
81
- void *seqused_k,
82
- void *softmax_lse_d,
83
- float p_dropout,
84
- float softmax_scale,
85
- int window_size_left,
86
- int window_size_right,
87
- const float softcap=0.f,
88
- const int sm_margin=0) {
89
-
90
- // Reset the parameters
91
- params = {};
92
-
93
- params.is_bf16 = q.dtype() == torch::kBFloat16;
94
- params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
95
-
96
- // Set the pointers and strides.
97
- params.q_ptr = q.data_ptr();
98
- params.k_ptr = k.data_ptr();
99
- params.v_ptr = v.data_ptr();
100
- // All strides are in elements, not bytes.
101
- params.q_row_stride = q.stride(-3);
102
- params.k_row_stride = k.stride(-3);
103
- params.v_row_stride = v.stride(-3);
104
- params.q_head_stride = q.stride(-2);
105
- params.k_head_stride = k.stride(-2);
106
- params.v_head_stride = v.stride(-2);
107
- params.v_dim_stride = v.stride(-1);
108
- params.o_ptr = out.data_ptr();
109
- params.o_row_stride = out.stride(-3);
110
- params.o_head_stride = out.stride(-2);
111
-
112
- if (cu_seqlens_q_d == nullptr) {
113
- params.q_batch_stride = q.stride(0);
114
- params.o_batch_stride = out.stride(0);
115
- }
116
- if (cu_seqlens_k_d == nullptr) {
117
- params.k_batch_stride = k.stride(0);
118
- params.v_batch_stride = v.stride(0);
119
- }
120
-
121
- params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
122
- params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
123
- params.seqused_q = static_cast<int *>(seqused_q);
124
- params.seqused_k = static_cast<int *>(seqused_k);
125
-
126
- // Softmax sum
127
- params.softmax_lse_ptr = softmax_lse_d;
128
-
129
- // Set the dimensions.
130
- params.b = b;
131
- params.h = h;
132
- params.h_k = h_k;
133
- params.seqlen_q = seqlen_q;
134
- params.seqlen_k = seqlen_k;
135
- params.seqlen_q_rounded = seqlen_q_rounded;
136
- params.seqlen_k_rounded = seqlen_k_rounded;
137
- params.d = d;
138
- params.d_rounded = d_rounded;
139
-
140
- // Set the different scale values.
141
- params.scale_softmax = softmax_scale;
142
- params.softcap = softcap;
143
-
144
- // Set this to probability of keeping an element to simplify things.
145
- params.p_dropout = 1.f - p_dropout;
146
- // Convert p from float to int so we don't have to convert the random uint to float to compare.
147
- // [Minor] We want to round down since when we do the comparison we use <= instead of <
148
- // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
149
- // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
150
- params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
151
- params.rp_dropout = 1.f / params.p_dropout;
152
- TORCH_CHECK(p_dropout < 1.f);
153
- #ifdef FLASHATTENTION_DISABLE_DROPOUT
154
- TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
155
- #endif
156
-
157
- // Causal is the special case where window_size_right == 0 and window_size_left < 0.
158
- // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
159
- params.is_causal = window_size_left < 0 && window_size_right == 0;
160
- params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
161
-
162
- // TODO: check this
163
- if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
164
- if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
165
- params.window_size_left = window_size_left;
166
- params.window_size_right = window_size_right;
167
-
168
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
169
- params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
170
-
171
- #ifdef FLASHATTENTION_DISABLE_LOCAL
172
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
173
- #endif
174
- }
175
-
176
- void set_params_dgrad(Flash_bwd_params &params,
177
- // sizes
178
- const size_t b,
179
- const size_t seqlen_q,
180
- const size_t seqlen_k,
181
- const size_t seqlen_q_rounded,
182
- const size_t seqlen_k_rounded,
183
- const size_t h,
184
- const size_t h_k,
185
- const size_t d,
186
- const size_t d_rounded,
187
- // device pointers
188
- const at::Tensor q,
189
- const at::Tensor k,
190
- const at::Tensor v,
191
- const at::Tensor out,
192
- const at::Tensor dout,
193
- at::Tensor dq,
194
- at::Tensor dk,
195
- at::Tensor dv,
196
- void *cu_seqlens_q_d,
197
- void *cu_seqlens_k_d,
198
- void *seqused_q,
199
- void *seqused_k,
200
- void *dq_accum_d,
201
- void *dk_accum_d,
202
- void *dv_accum_d,
203
- void *softmax_lse_d,
204
- void *dsoftmax_sum_d,
205
- float p_dropout,
206
- float softmax_scale,
207
- int window_size_left,
208
- int window_size_right,
209
- const float softcap=0.f,
210
- bool deterministic=false,
211
- int const sm_margin=0) {
212
-
213
- set_params_fprop(params,
214
- b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
215
- q, k, v, out,
216
- cu_seqlens_q_d,
217
- cu_seqlens_k_d,
218
- seqused_q,
219
- seqused_k,
220
- softmax_lse_d,
221
- p_dropout,
222
- softmax_scale,
223
- window_size_left,
224
- window_size_right,
225
- softcap,
226
- sm_margin);
227
-
228
- // Set the pointers and strides.
229
- params.do_ptr = dout.data_ptr();
230
- params.do_row_stride = dout.stride(-3);
231
- params.do_head_stride = dout.stride(-2);
232
- params.dq_ptr = dq.data_ptr();
233
- params.dk_ptr = dk.data_ptr();
234
- params.dv_ptr = dv.data_ptr();
235
- params.dq_row_stride = dq.stride(-3);
236
- params.dk_row_stride = dk.stride(-3);
237
- params.dv_row_stride = dv.stride(-3);
238
- params.dq_head_stride = dq.stride(-2);
239
- params.dk_head_stride = dk.stride(-2);
240
- params.dv_head_stride = dv.stride(-2);
241
-
242
- if (cu_seqlens_q_d == nullptr) {
243
- params.do_batch_stride = dout.stride(0);
244
- params.dq_batch_stride = dq.stride(0);
245
- params.dk_batch_stride = dk.stride(0);
246
- params.dv_batch_stride = dv.stride(0);
247
- }
248
-
249
- params.dq_accum_ptr = dq_accum_d;
250
- params.dk_accum_ptr = dk_accum_d;
251
- params.dv_accum_ptr = dv_accum_d;
252
-
253
- // Softmax sum
254
- params.dsoftmax_sum = dsoftmax_sum_d;
255
-
256
- params.deterministic = deterministic;
257
- }
258
-
259
- void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
260
- // HEADDIM_SWITCH(params.d, [&] {
261
- // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
262
- // });
263
- TORCH_CHECK(params.num_splits >= 1);
264
- ARCH_SWITCH(params.arch, Arch, [&] {
265
- SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
266
- PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
267
- PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
268
- // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
269
- static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
270
- SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
271
- if (!params.is_e4m3) {
272
- if (params.is_bf16) {
273
- #ifndef FLASHATTENTION_DISABLE_HDIM64
274
- if (params.d <= 64) {
275
- if (params.dv > 256 && Arch == 90) {
276
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
277
- } else if (params.dv > 64 && Arch == 90) {
278
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
279
- } else {
280
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
281
- }
282
- }
283
- #endif
284
- #ifndef FLASHATTENTION_DISABLE_HDIM96
285
- if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
286
- #endif
287
- #ifndef FLASHATTENTION_DISABLE_HDIM128
288
- if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
289
- #endif
290
- #ifndef FLASHATTENTION_DISABLE_HDIM192
291
- if (params.d <= 192) {
292
- if (params.dv <= 128 && Arch == 90) {
293
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
294
- } else {
295
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
296
- }
297
- }
298
- #endif
299
- #ifndef FLASHATTENTION_DISABLE_HDIM256
300
- if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
301
- #endif
302
- } else {
303
- #ifndef FLASHATTENTION_DISABLE_FP16
304
- #ifndef FLASHATTENTION_DISABLE_HDIM64
305
- if (params.d <= 64) {
306
- if (params.dv > 256 && Arch == 90) {
307
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
308
- } else if (params.dv > 64 && Arch == 90) {
309
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310
- } else {
311
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
312
- }
313
- }
314
- #endif
315
- #ifndef FLASHATTENTION_DISABLE_HDIM96
316
- if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
317
- #endif
318
- #ifndef FLASHATTENTION_DISABLE_HDIM128
319
- if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
320
- #endif
321
- #ifndef FLASHATTENTION_DISABLE_HDIM192
322
- if (params.d <= 192) {
323
- if (params.dv <= 128 && Arch == 90) {
324
- return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
325
- } else {
326
- return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
327
- }
328
- }
329
- #endif
330
- #ifndef FLASHATTENTION_DISABLE_HDIM256
331
- if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
332
- #endif
333
- #else
334
- TORCH_CHECK(false, "This flash attention build does not support FP16.");
335
- #endif
336
- }
337
- } else {
338
- #ifndef FLASHATTENTION_DISABLE_FP8
339
- #ifndef FLASHATTENTION_DISABLE_HDIM64
340
- if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
341
- #endif
342
- #ifndef FLASHATTENTION_DISABLE_HDIM96
343
- if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
344
- #endif
345
- #ifndef FLASHATTENTION_DISABLE_HDIM128
346
- if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
347
- #endif
348
- #ifndef FLASHATTENTION_DISABLE_HDIM192
349
- if (params.d <= 192) {
350
- if (params.dv <= 128 && Arch == 90) {
351
- return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
352
- } else {
353
- return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
354
- }
355
- }
356
- #endif
357
- #ifndef FLASHATTENTION_DISABLE_HDIM256
358
- if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
359
- #endif
360
- #else
361
- TORCH_CHECK(false, "This flash attention build does not support FP8.");
362
- #endif
363
- }
364
- });
365
- });
366
- });
367
- });
368
- });
369
- }
370
-
371
- void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
372
- #ifndef FLASHATTENTION_DISABLE_SPLIT
373
- // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
374
- // so that kBlockM is smaller and we have more parallelism.
375
- if (params.is_fp32) {
376
- if (params.dv <= 64) {
377
- run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
378
- } else {
379
- run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
380
- }
381
- } else if (params.is_bf16) {
382
- if (params.dv <= 64) {
383
- run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
384
- } else {
385
- run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
386
- }
387
- } else {
388
- if (params.dv <= 64) {
389
- run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
390
- } else {
391
- run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
392
- }
393
- }
394
- #else
395
- TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
396
- #endif
397
- }
398
-
399
- inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
400
- if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
401
- // This needs to match the kernel configs
402
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
403
- int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
404
- int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
405
- // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
406
- // at least for MLA.
407
- return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
408
- }
409
-
410
- inline bool get_pack_gqa(Flash_fwd_params const& params) {
411
- // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
412
- // Has little effect on speed.
413
- if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
414
- #ifdef FLASHATTENTION_DISABLE_PACKGQA
415
- return false;
416
- #else
417
- // params.page_table must already be set
418
- if (params.h == params.h_k) { return false; }
419
- // This needs to match the kernel configs
420
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
421
- int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
422
- return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
423
- #endif
424
- }
425
-
426
- inline int get_num_splits(Flash_fwd_params const& params) {
427
- #ifdef FLASHATTENTION_DISABLE_SPLIT
428
- return 1;
429
- #else
430
- // Always enable PackGQA for Split
431
- // params.page_table must already be set
432
- // This needs to match the kernel configs
433
- bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
434
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
435
- // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
436
- // has not been set here. It's OK though because we might just underestimate kBlockN a bit
437
- auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
438
- int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
439
- int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
440
- int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
441
- // If is_local, we're not going to load all of seqlen_k
442
- int const seqlen_k_loaded = !params.is_local
443
- ? params.seqlen_k
444
- : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
445
- int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
446
- int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
447
- int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
448
- // Always enable PackGQA for Split
449
- // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
450
- // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
451
- // that batch = 1.
452
- int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
453
- return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
454
- #endif
455
- }
456
-
457
- inline int get_max_headdim() {
458
- #ifndef FLASHATTENTION_DISABLE_HDIM256
459
- return 256;
460
- #endif
461
- #ifndef FLASHATTENTION_DISABLE_HDIM192
462
- return 192;
463
- #endif
464
- #ifndef FLASHATTENTION_DISABLE_HDIM128
465
- return 128;
466
- #endif
467
- #ifndef FLASHATTENTION_DISABLE_HDIM96
468
- return 96;
469
- #endif
470
- #ifndef FLASHATTENTION_DISABLE_HDIM64
471
- return 64;
472
- #endif
473
- return 0;
474
- }
475
-
476
- inline int round_up_headdim(int head_size) {
477
- #ifndef FLASHATTENTION_DISABLE_HDIM64
478
- if (head_size <= 64) { return 64; }
479
- #endif
480
- #ifndef FLASHATTENTION_DISABLE_HDIM96
481
- if (head_size <= 96) { return 96; }
482
- #endif
483
- #ifndef FLASHATTENTION_DISABLE_HDIM128
484
- if (head_size <= 128) { return 128; }
485
- #endif
486
- #ifndef FLASHATTENTION_DISABLE_HDIM192
487
- if (head_size <= 192) { return 192; }
488
- #endif
489
- #ifndef FLASHATTENTION_DISABLE_HDIM256
490
- if (head_size <= 256) { return 256; }
491
- #endif
492
- return 256;
493
- }
494
-
495
- inline int round_up_headdimv(int head_size) {
496
- if (head_size <= 64) { return 64; }
497
- if (head_size <= 96) { return 96; }
498
- if (head_size <= 128) { return 128; }
499
- if (head_size <= 192) { return 192; }
500
- if (head_size <= 256) { return 256; }
501
- return 512;
502
- }
503
-
504
- // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
505
- at::Tensor
506
- mha_fwd_get_scheduler_metadata(
507
- int batch_size,
508
- int max_seqlen_q,
509
- int max_seqlen_k,
510
- int num_heads,
511
- int num_heads_k,
512
- int headdim,
513
- int headdim_v,
514
- at::ScalarType qkv_dtype,
515
- const at::Tensor &seqused_k, // b
516
- std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
517
- std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
518
- std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
519
- std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
520
- std::optional<const at::Tensor> &leftpad_k_, // b
521
- std::optional<int> page_size,
522
- int max_seqlen_k_new, // 0 means we're not appending new KV
523
- bool is_causal,
524
- int window_size_left,
525
- int window_size_right,
526
- bool has_softcap,
527
- int num_splits,
528
- std::optional<bool> pack_gqa_,
529
- int const sm_margin
530
- ) {
531
-
532
- TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
533
- "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
534
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
535
-
536
- // Reset the parameters
537
- Flash_fwd_params params{};
538
- params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
539
- params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
540
- params.b = batch_size;
541
- params.seqlen_q = max_seqlen_q;
542
- params.seqlen_k = max_seqlen_k;
543
- params.h = num_heads;
544
- params.h_k = num_heads_k;
545
- params.d = headdim;
546
- params.dv = headdim_v;
547
- params.d_rounded = round_up_headdim(headdim);
548
- params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
549
- params.seqlen_knew = max_seqlen_k_new;
550
-
551
- bool const is_varlen_q = cu_seqlens_q_.has_value();
552
- params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
553
- bool const is_varlen_k = cu_seqlens_k_.has_value();
554
- params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
555
- params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
556
- params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
557
- params.seqused_k = seqused_k.data_ptr<int>();
558
- params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
559
- params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
560
- if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
561
- if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
562
- // causal=true is the same as causal=false in this case
563
- if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
564
- // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
565
- if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
566
- is_causal = false;
567
- }
568
- }
569
- if (is_causal) { window_size_right = 0; }
570
-
571
- params.is_causal = window_size_left < 0 && window_size_right == 0;
572
- params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
573
- if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; }
574
- if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
575
- params.window_size_left = window_size_left;
576
- params.window_size_right = window_size_right;
577
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
578
- params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
579
- params.softcap = has_softcap ? 1.0f : 0.0f;
580
-
581
- params.page_size = page_size.has_value() ? page_size.value() : 1;
582
- params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
583
-
584
- bool const use_dynamic_split = params.b <= 992;
585
- params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
586
-
587
- params.pagedkv_tma = get_pagedkv_tma(params);
588
- // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
589
- params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
590
- params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
591
- // Always enable PackGQA for Split
592
- params.pack_gqa = params.num_splits > 1;
593
-
594
- bool is_varlen = true;
595
-
596
- // Otherwise the kernel will be launched from cuda:0 device
597
- // Cast to char to avoid compiler warning about narrowing
598
- at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
599
-
600
- auto opts = seqused_k.options();
601
- // This needs to be set after get_num_splits
602
- at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
603
- bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
604
- if (scheduler_needs_semaphore || use_dynamic_split) {
605
- tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
606
- if (scheduler_needs_semaphore) {
607
- if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
608
- params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
609
- } else {
610
- params.tile_count_semaphore = nullptr;
611
- }
612
- params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
613
- }
614
-
615
- if (params.num_splits_dynamic_ptr) {
616
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
617
- auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
618
- int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
619
- int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
620
- auto stream = at::cuda::getCurrentCUDAStream().stream();
621
- prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
622
- CHECK_CUDA_KERNEL_LAUNCH();
623
- }
624
- return tile_count_semaphore;
625
- }
626
-
627
- // b: batch_size
628
- // b_k: batch_size_k
629
- // s_q: seqlen_q
630
- // s_k: seqlen_k
631
- // s_k_new: seqlen_k_new
632
- // h: num_heads
633
- // h_k: num_heads_k
634
- // d: head_size
635
- std::vector<at::Tensor>
636
- mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
637
- const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
638
- const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
639
- std::optional<const at::Tensor> &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
640
- std::optional<const at::Tensor> &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
641
- std::optional<const at::Tensor> &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
642
- std::optional<at::Tensor> &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
643
- std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
644
- std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
645
- std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
646
- std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
647
- std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
648
- std::optional<int> max_seqlen_q_,
649
- // TODO: check if we need max_seqlen_k
650
- std::optional<int> max_seqlen_k_,
651
- std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
652
- std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
653
- std::optional<const at::Tensor> &leftpad_k_, // b
654
- std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
655
- std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
656
- std::optional<const at::Tensor> &seqlens_rotary_, // b
657
- std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
658
- std::optional<at::Tensor> &k_descale_, // (b, h_k)
659
- std::optional<at::Tensor> &v_descale_, // (b, h_k)
660
- float const softmax_scale,
661
- bool is_causal,
662
- int window_size_left,
663
- int window_size_right,
664
- float const softcap,
665
- bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
666
- std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
667
- int num_splits,
668
- std::optional<bool> pack_gqa_,
669
- int const sm_margin,
670
- std::optional<const at::Tensor> &s_aux_ // (h)
671
- ) {
672
-
673
- auto dprops = at::cuda::getCurrentDeviceProperties();
674
- bool is_sm8x = dprops->major >= 8;
675
- TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
676
-
677
- auto q_type = q.scalar_type();
678
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
679
- "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
680
- if (dprops->major < 9) {
681
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
682
- "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
683
- }
684
- TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
685
- TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
686
-
687
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
688
-
689
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
690
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
691
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
692
-
693
- at::Tensor page_table;
694
- const bool paged_KV = page_table_.has_value();
695
- if (paged_KV) {
696
- page_table = page_table_.value();
697
- CHECK_DEVICE(page_table);
698
- TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
699
- TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
700
- }
701
-
702
- at::Tensor cu_seqlens_q;
703
- bool const is_varlen_q = cu_seqlens_q_.has_value();
704
- if (is_varlen_q) {
705
- cu_seqlens_q = cu_seqlens_q_.value();
706
- CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
707
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
708
- TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
709
- }
710
- at::Tensor cu_seqlens_k;
711
- bool const is_varlen_k = cu_seqlens_k_.has_value();
712
- if (is_varlen_k) {
713
- cu_seqlens_k = cu_seqlens_k_.value();
714
- CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
715
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
716
- TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
717
- TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
718
- TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
719
- }
720
-
721
- auto const sizes = q.sizes();
722
- const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
723
- int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
724
- int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
725
- int num_heads = q.size(-2);
726
- int const head_size = q.size(-1);
727
- int const head_size_v = v.size(-1);
728
- int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
729
- int const num_pages = !paged_KV ? 0 : k.size(0);
730
- int const page_size = !paged_KV ? 1 : k.size(1);
731
- int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
732
- int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
733
- int const num_heads_k = k.size(-2);
734
- int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
735
- if (!kv_batch_idx_.has_value()) {
736
- TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
737
- }
738
- int const max_headdim = get_max_headdim();
739
- TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
740
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
741
- if (head_size_v != head_size) {
742
- TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
743
- (head_size <= 64 && head_size_v <= 512),
744
- "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
745
- "or (Q/K <= 64 and V <= 512).");
746
- TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
747
- if (head_size_v > 256) {
748
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
749
- "HeaddimV > 256 requires fp16 and bf16 data type");
750
- }
751
- }
752
-
753
- // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
754
- // TODO: check this
755
- if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
756
- if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
757
- // causal=true is the same as causal=false in this case
758
- if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
759
- // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
760
- if ((head_size <= 64 || head_size > 128) || !paged_KV) {
761
- is_causal = false;
762
- }
763
- }
764
- if (is_causal) { window_size_right = 0; }
765
- // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
766
- // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
767
- is_causal = window_size_left < 0 && window_size_right == 0;
768
-
769
- if (!is_varlen_q) {
770
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
771
- } else {
772
- CHECK_SHAPE(q, total_q, num_heads, head_size);
773
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
774
- }
775
- if (!paged_KV) {
776
- if (!is_varlen_k) {
777
- CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
778
- CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
779
- } else {
780
- CHECK_SHAPE(k, total_k, num_heads_k, head_size);
781
- CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
782
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
783
- }
784
- } else {
785
- CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
786
- CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
787
- CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
788
- }
789
-
790
- if (seqused_q_.has_value()){
791
- auto seqused_q = seqused_q_.value();
792
- TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
793
- CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
794
- CHECK_SHAPE(seqused_q, batch_size);
795
- }
796
- if (seqused_k_.has_value()) {
797
- auto seqused_k = seqused_k_.value();
798
- TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
799
- CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
800
- CHECK_SHAPE(seqused_k, batch_size);
801
- }
802
-
803
- if (leftpad_k_.has_value()) {
804
- auto leftpad_k = leftpad_k_.value();
805
- TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
806
- CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
807
- CHECK_SHAPE(leftpad_k, batch_size);
808
- }
809
-
810
- // This is what we will template on
811
- bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
812
- #ifdef FLASHATTENTION_DISABLE_VARLEN
813
- TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
814
- #endif
815
-
816
- int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
817
- TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
818
- TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
819
-
820
- auto opts = q.options();
821
- auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
822
- at::Tensor out;
823
- if (out_.has_value()) {
824
- out = out_.value();
825
- TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
826
- CHECK_DEVICE(out);
827
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
828
- if (!is_varlen_q) {
829
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
830
- } else {
831
- CHECK_SHAPE(out, total_q, num_heads, head_size_v);
832
- }
833
- } else {
834
- out = !is_varlen_q
835
- ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
836
- : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
837
- }
838
-
839
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
840
- int const head_size_rounded = round_up_headdim(head_size);
841
- int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
842
- int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
843
- int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
844
-
845
- // Otherwise the kernel will be launched from cuda:0 device
846
- // Cast to char to avoid compiler warning about narrowing
847
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
848
-
849
- at::Tensor softmax_lse;
850
- if (!is_varlen_q) {
851
- softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
852
- } else {
853
- softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
854
- }
855
-
856
- Flash_fwd_params params;
857
- set_params_fprop(params,
858
- batch_size,
859
- seqlen_q, seqlen_k,
860
- seqlen_q_rounded, seqlen_k_rounded,
861
- num_heads, num_heads_k,
862
- head_size, head_size_rounded,
863
- q, k, v, out,
864
- !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
865
- !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
866
- seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
867
- seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
868
- softmax_lse.data_ptr(),
869
- /*p_dropout=*/0.f,
870
- softmax_scale,
871
- window_size_left,
872
- window_size_right,
873
- softcap,
874
- sm_margin);
875
- params.total_q = total_q;
876
- params.total_k = total_k;
877
- params.b_k = batch_size_k;
878
- params.dv = head_size_v;
879
- params.dv_rounded = head_size_v_rounded;
880
- if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
881
- params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
882
- }
883
- if (paged_KV) {
884
- params.page_table = page_table.data_ptr<int>();
885
- params.page_table_batch_stride = page_table.stride(0);
886
- }
887
- params.page_size = page_size;
888
- params.num_pages = num_pages;
889
-
890
- if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
891
- at::Tensor k_new, v_new;
892
- TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
893
- TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
894
- TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
895
- at::Tensor cu_seqlens_k_new;
896
- bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
897
- if (is_varlen_k_new) {
898
- cu_seqlens_k_new = cu_seqlens_k_new_.value();
899
- CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
900
- TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
901
- }
902
- k_new = k_new_.value();
903
- v_new = v_new_.value();
904
- TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
905
- TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
906
- CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
907
- TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
908
- TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
909
- // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
910
- int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
911
- int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
912
- if (!is_varlen_k_new) {
913
- CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
914
- CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
915
- } else {
916
- CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
917
- CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
918
- CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
919
- }
920
- params.seqlen_knew = seqlen_k_new;
921
- params.total_knew = total_k_new;
922
- params.knew_ptr = k_new.data_ptr();
923
- params.vnew_ptr = v_new.data_ptr();
924
- // All strides are in elements, not bytes.
925
- params.knew_row_stride = k_new.stride(-3);
926
- params.vnew_row_stride = v_new.stride(-3);
927
- params.knew_head_stride = k_new.stride(-2);
928
- params.vnew_head_stride = v_new.stride(-2);
929
- if (!is_varlen_k_new) {
930
- params.knew_batch_stride = k_new.stride(0);
931
- params.vnew_batch_stride = v_new.stride(0);
932
- }
933
- if (is_varlen_k_new) {
934
- params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
935
- }
936
- }
937
-
938
- // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
939
- bool const use_dynamic_split = is_varlen && params.b <= 992;
940
- // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
941
- params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
942
-
943
- params.pagedkv_tma = get_pagedkv_tma(params);
944
- // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
945
- params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
946
- params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
947
- // Always enable PackGQA for Split
948
- params.pack_gqa = params.num_splits > 1;
949
-
950
- // This needs to be set after get_num_splits
951
- at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
952
- // We don't use the persistent scheduler if Split and not Varlen
953
- bool const scheduler_needs_semaphore = params.arch >= 90
954
- ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
955
- : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
956
- if (scheduler_needs_semaphore || use_dynamic_split) {
957
- int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
958
- params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
959
- if (scheduler_metadata_.has_value()) {
960
- at::Tensor scheduler_metadata = scheduler_metadata_.value();
961
- CHECK_DEVICE(scheduler_metadata);
962
- CHECK_SHAPE(scheduler_metadata, metadata_size);
963
- CHECK_CONTIGUOUS(scheduler_metadata);
964
- TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
965
- tile_count_semaphore = scheduler_metadata;
966
- } else {
967
- tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
968
- }
969
- if (scheduler_needs_semaphore && !use_dynamic_split) {
970
- tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
971
- }
972
- params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
973
- params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
974
- }
975
-
976
- if (q_v_.has_value()) {
977
- TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
978
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
979
- "q_v is only supported for fp16 and bf16 data type");
980
- TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
981
- at::Tensor q_v = q_v_.value();
982
- TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
983
- CHECK_DEVICE(q_v);
984
- TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
985
- if (!is_varlen_q) {
986
- CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
987
- } else {
988
- CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
989
- }
990
- params.qv_ptr = q_v.data_ptr();
991
- // All stride are in elements, not bytes.
992
- params.qv_row_stride = q_v.stride(-3);
993
- params.qv_head_stride = q_v.stride(-2);
994
- if (!is_varlen_q) {
995
- params.qv_batch_stride = q_v.stride(0);
996
- }
997
- }
998
-
999
- if (rotary_cos_.has_value()) {
1000
- TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1001
- auto rotary_cos = rotary_cos_.value();
1002
- CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
1003
- params.rotary_dim = rotary_cos.size(1) * 2;
1004
- TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1005
- TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1006
- const int seqlen_ro = rotary_cos.size(0);
1007
- if (paged_KV) {
1008
- TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1009
- }
1010
- CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1011
- TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1012
-
1013
- TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1014
- auto rotary_sin = rotary_sin_.value();
1015
- CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
1016
- CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1017
- TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1018
- params.rotary_cos_ptr = rotary_cos.data_ptr();
1019
- params.rotary_sin_ptr = rotary_sin.data_ptr();
1020
- params.is_rotary_interleaved = is_rotary_interleaved;
1021
- if (seqlens_rotary_.has_value()) {
1022
- at::Tensor seqlens_rotary = seqlens_rotary_.value();
1023
- CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
1024
- TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
1025
- CHECK_SHAPE(seqlens_rotary, batch_size);
1026
- params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
1027
- }
1028
- } else {
1029
- params.rotary_dim = 0;
1030
- }
1031
-
1032
- if (kv_batch_idx_.has_value()) {
1033
- auto kv_batch_idx = kv_batch_idx_.value();
1034
- CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
1035
- TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
1036
- params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
1037
- }
1038
-
1039
- at::Tensor out_accum, softmax_lse_accum;
1040
- auto outaccum_type = at::ScalarType::Float;
1041
- if (params.num_splits > 1) {
1042
- TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
1043
- if (!is_varlen_q) {
1044
- out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
1045
- softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1046
- params.oaccum_batch_stride = out_accum.stride(1);
1047
- params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
1048
- } else {
1049
- out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
1050
- softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
1051
- }
1052
- params.is_fp32 = false;
1053
- params.oaccum_ptr = out_accum.data_ptr();
1054
- params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
1055
- params.oaccum_split_stride = out_accum.stride(0);
1056
- params.oaccum_row_stride = out_accum.stride(-2);
1057
- params.oaccum_head_stride = out_accum.stride(-3);
1058
- params.lseaccum_split_stride = softmax_lse_accum.stride(0);
1059
- params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
1060
- }
1061
-
1062
- if (q_type == at::ScalarType::Float8_e4m3fn) {
1063
- if (q_descale_.has_value()) {
1064
- auto q_descale = q_descale_.value();
1065
- CHECK_DEVICE(q_descale);
1066
- CHECK_SHAPE(q_descale, batch_size, num_heads_k);
1067
- params.q_descale_ptr = q_descale.data_ptr<float>();
1068
- params.q_descale_batch_stride = q_descale.stride(0);
1069
- params.q_descale_head_stride = q_descale.stride(1);
1070
- } else {
1071
- params.q_descale_ptr = nullptr;
1072
- }
1073
- if (k_descale_.has_value()) {
1074
- auto k_descale = k_descale_.value();
1075
- CHECK_DEVICE(k_descale);
1076
- CHECK_SHAPE(k_descale, batch_size, num_heads_k);
1077
- params.k_descale_ptr = k_descale.data_ptr<float>();
1078
- params.k_descale_batch_stride = k_descale.stride(0);
1079
- params.k_descale_head_stride = k_descale.stride(1);
1080
- } else {
1081
- params.k_descale_ptr = nullptr;
1082
- }
1083
- if (v_descale_.has_value()) {
1084
- auto v_descale = v_descale_.value();
1085
- CHECK_DEVICE(v_descale);
1086
- CHECK_SHAPE(v_descale, batch_size, num_heads_k);
1087
- params.v_descale_ptr = v_descale.data_ptr<float>();
1088
- params.v_descale_batch_stride = v_descale.stride(0);
1089
- params.v_descale_head_stride = v_descale.stride(1);
1090
- } else {
1091
- params.v_descale_ptr = nullptr;
1092
- }
1093
- }
1094
-
1095
- if(s_aux_.has_value()) {
1096
- auto s_aux = s_aux_.value();
1097
- TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16,
1098
- "We only support bf16 dtype for S extra.");
1099
- CHECK_DEVICE(s_aux);
1100
- CHECK_SHAPE(s_aux, num_heads);
1101
- CHECK_CONTIGUOUS(s_aux);
1102
- params.s_aux_ptr = s_aux.data_ptr();
1103
- } else {
1104
- params.s_aux_ptr = nullptr;
1105
- }
1106
-
1107
- #ifdef FLASHATTENTION_DISABLE_LOCAL
1108
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1109
- #endif
1110
- #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1111
- TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1112
- #endif
1113
- #ifdef FLASHATTENTION_DISABLE_SPLIT
1114
- TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
1115
- #endif
1116
- #ifdef FLASHATTENTION_DISABLE_PACKGQA
1117
- TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
1118
- #endif
1119
- #ifdef FLASHATTENTION_DISABLE_PAGEDKV
1120
- TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
1121
- #endif
1122
- #ifdef FLASHATTENTION_DISABLE_APPENDKV
1123
- TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
1124
- #endif
1125
-
1126
- if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
1127
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1128
- run_mha_fwd(params, stream);
1129
- if (params.num_splits > 1) {
1130
- if (out_type == at::ScalarType::BFloat16) {
1131
- // Since we want output in BF16. Otherwise fwd_combine will output to FP16
1132
- params.is_bf16 = true;
1133
- }
1134
- // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
1135
- // and seqlen = total_q, and don't need to dispatch to Varlen there.
1136
- // However, with dynamic split, each row needs to know which batch it belongs to
1137
- // to read the number of splits, so we just use the varlen version of combine kernel.
1138
- // if (is_varlen_q && !seqused_q_.has_value()) {
1139
- // if (is_varlen_q) {
1140
- // params.b = 1;
1141
- // params.seqlen_q = total_q;
1142
- // }
1143
- // This will zero out the semaphore if needed
1144
- run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
1145
- } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
1146
- // need to zero out the semaphore in this case
1147
- tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
1148
- }
1149
- } else if (total_q > 0 && num_heads_k > 0) {
1150
- // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
1151
- out.zero_();
1152
- softmax_lse.fill_(std::numeric_limits<float>::infinity());
1153
- }
1154
-
1155
- // return {out, softmax_lse};
1156
- return {out, softmax_lse, out_accum, softmax_lse_accum};
1157
- }
1158
-
1159
- void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1160
- #ifndef FLASHATTENTION_DISABLE_BACKWARD
1161
- // FP16_SWITCH(!params.is_bf16, [&] {
1162
- // HEADDIM_SWITCH(params.d, [&] {
1163
- // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
1164
- // });
1165
- // });
1166
- ARCH_SWITCH(params.arch, Arch, [&] {
1167
- SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
1168
- if (!params.is_bf16) {
1169
- #ifndef FLASHATTENTION_DISABLE_FP16
1170
- #ifndef FLASHATTENTION_DISABLE_HDIM64
1171
- if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
1172
- #endif
1173
- #ifndef FLASHATTENTION_DISABLE_HDIM96
1174
- if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
1175
- #endif
1176
- #ifndef FLASHATTENTION_DISABLE_HDIM128
1177
- if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
1178
- #endif
1179
- #ifndef FLASHATTENTION_DISABLE_HDIM192
1180
- if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
1181
- #endif
1182
- #ifndef FLASHATTENTION_DISABLE_HDIM256
1183
- if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
1184
- #endif
1185
- #else
1186
- TORCH_CHECK(false, "This flash attention build does not support FP16.");
1187
- #endif
1188
- } else {
1189
- #ifndef FLASHATTENTION_DISABLE_HDIM64
1190
- if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
1191
- #endif
1192
- #ifndef FLASHATTENTION_DISABLE_HDIM96
1193
- if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
1194
- #endif
1195
- #ifndef FLASHATTENTION_DISABLE_HDIM128
1196
- if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
1197
- #endif
1198
- #ifndef FLASHATTENTION_DISABLE_HDIM192
1199
- if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
1200
- #endif
1201
- #ifndef FLASHATTENTION_DISABLE_HDIM256
1202
- if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
1203
- #endif
1204
- }
1205
- });
1206
- });
1207
- #endif
1208
- }
1209
-
1210
-
1211
- // b: batch_size
1212
- // s_q: seqlen_q
1213
- // s_k: seqlen_k
1214
- // h: num_heads
1215
- // h_k: num_heads_k
1216
- // d: head_size
1217
- std::vector<at::Tensor> mha_bwd(
1218
- const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1219
- const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1220
- const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1221
- const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1222
- const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1223
- const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
1224
- std::optional<at::Tensor> &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1225
- std::optional<at::Tensor> &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1226
- std::optional<at::Tensor> &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1227
- std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
1228
- std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
1229
- std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
1230
- std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1231
- std::optional<int> max_seqlen_q_,
1232
- std::optional<int> max_seqlen_k_,
1233
- float const softmax_scale,
1234
- bool is_causal,
1235
- int window_size_left,
1236
- int window_size_right,
1237
- float const softcap,
1238
- bool const deterministic,
1239
- int const sm_margin) {
1240
-
1241
- #ifdef FLASHATTENTION_DISABLE_BACKWARD
1242
- TORCH_CHECK(false, "This flash attention build does not support backward.");
1243
- #endif
1244
-
1245
- auto dprops = at::cuda::getCurrentDeviceProperties();
1246
- bool is_sm8x = dprops->major >= 8;
1247
- TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1248
-
1249
- auto q_type = q.dtype();
1250
- TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
1251
- "FlashAttention only support fp16 and bf16 data type");
1252
- TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
1253
- TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
1254
- TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
1255
- TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
1256
-
1257
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1258
- CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1259
-
1260
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1261
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1262
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1263
- TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1264
- TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1265
-
1266
- at::Tensor cu_seqlens_q;
1267
- bool const is_varlen_q = cu_seqlens_q_.has_value();
1268
- if (is_varlen_q) {
1269
- cu_seqlens_q = cu_seqlens_q_.value();
1270
- CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
1271
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
1272
- TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
1273
- }
1274
- at::Tensor cu_seqlens_k;
1275
- bool const is_varlen_k = cu_seqlens_k_.has_value();
1276
- if (is_varlen_k) {
1277
- cu_seqlens_k = cu_seqlens_k_.value();
1278
- CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
1279
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
1280
- TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
1281
- }
1282
- // This is what we will template on
1283
- bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
1284
- #ifdef FLASHATTENTION_DISABLE_VARLEN
1285
- TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
1286
- #endif
1287
-
1288
- auto const sizes = q.sizes();
1289
- int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
1290
- int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
1291
- int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
1292
- int const num_heads = q.size(-2);
1293
- int const head_size = q.size(-1);
1294
- int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
1295
- int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
1296
- int const num_heads_k = k.size(-2);
1297
- TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1298
- int const max_headdim = get_max_headdim();
1299
- TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
1300
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1301
-
1302
- // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
1303
- if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
1304
- if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
1305
- if (is_causal) { window_size_right = 0; }
1306
- // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
1307
- // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
1308
- is_causal = window_size_left < 0 && window_size_right == 0;
1309
-
1310
- int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1311
- int const head_size_rounded = round_up_headdim(head_size);
1312
- // Very important that these match the kernel configs
1313
- bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
1314
- int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
1315
- : (head_size_rounded <= 96 ? 64
1316
- : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
1317
- : 64));
1318
- int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
1319
- int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
1320
- int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
1321
- int const kBlockN_sm90 = head_size_rounded <= 128
1322
- ? 128
1323
- : (head_size_rounded <= 192 ? 96 : 80);
1324
- int const kBlockN_sm80 = head_size_rounded <= 128
1325
- ? 128
1326
- : (head_size_rounded <= 192 ? 80 : 64);
1327
- int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
1328
- : (head_size_rounded <= 96 ? 128
1329
- : (head_size_rounded <= 128 ? 96
1330
- : (head_size_rounded <= 192 ? 64 : 64)));
1331
- int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
1332
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1333
- int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
1334
- int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
1335
- int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
1336
- int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
1337
-
1338
- if (!is_varlen_q) {
1339
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
1340
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
1341
- CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
1342
- } else {
1343
- CHECK_SHAPE(q, total_q, num_heads, head_size);
1344
- CHECK_SHAPE(out, total_q, num_heads, head_size);
1345
- CHECK_SHAPE(dout, total_q, num_heads, head_size);
1346
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1347
- }
1348
- if (!is_varlen_k) {
1349
- CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
1350
- CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
1351
- } else {
1352
- CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1353
- CHECK_SHAPE(v, total_k, num_heads_k, head_size);
1354
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1355
- }
1356
-
1357
- if (seqused_q_.has_value()){
1358
- auto seqused_q = seqused_q_.value();
1359
- TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
1360
- CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
1361
- CHECK_SHAPE(seqused_q, batch_size);
1362
- }
1363
- if (seqused_k_.has_value()){
1364
- auto seqused_k = seqused_k_.value();
1365
- TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
1366
- CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
1367
- CHECK_SHAPE(seqused_k, batch_size);
1368
- }
1369
-
1370
- at::Tensor dq, dk, dv;
1371
- if (dq_.has_value()) {
1372
- dq = dq_.value();
1373
- TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
1374
- CHECK_DEVICE(dq);
1375
- TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1376
- if (!is_varlen_q) {
1377
- CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
1378
- } else {
1379
- CHECK_SHAPE(dq, total_q, num_heads, head_size);
1380
- }
1381
- } else {
1382
- dq = torch::empty_like(q);
1383
- }
1384
- if (dk_.has_value()) {
1385
- dk = dk_.value();
1386
- TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
1387
- CHECK_DEVICE(dk);
1388
- TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1389
- if (!is_varlen_k) {
1390
- CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
1391
- } else {
1392
- CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1393
- }
1394
- } else {
1395
- dk = torch::empty_like(k);
1396
- }
1397
- if (dv_.has_value()) {
1398
- dv = dv_.value();
1399
- TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
1400
- CHECK_DEVICE(dv);
1401
- TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1402
- if (!is_varlen_k) {
1403
- CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
1404
- } else {
1405
- CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
1406
- }
1407
- } else {
1408
- dv = torch::empty_like(v);
1409
- }
1410
-
1411
- // Otherwise the kernel will be launched from cuda:0 device
1412
- // Cast to char to avoid compiler warning about narrowing
1413
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1414
-
1415
- auto opts = q.options();
1416
- // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1417
- at::Tensor softmax_d, softmax_lse_log2;
1418
- if (!is_varlen) {
1419
- // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1420
- softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1421
- softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1422
- } else {
1423
- softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1424
- softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1425
- }
1426
- at::Tensor dq_accum, dk_accum, dv_accum;
1427
- if (!is_varlen) {
1428
- dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1429
- } else {
1430
- dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1431
- }
1432
- if (num_heads_k != num_heads) { // MQA / GQA
1433
- if (!is_varlen) {
1434
- dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1435
- dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1436
- } else {
1437
- dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1438
- dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1439
- }
1440
- }
1441
-
1442
- Flash_bwd_params params;
1443
- set_params_dgrad(params,
1444
- batch_size,
1445
- seqlen_q, seqlen_k,
1446
- seqlen_q_rounded, seqlen_k_rounded,
1447
- num_heads, num_heads_k,
1448
- head_size, head_size_rounded,
1449
- q, k, v, out,
1450
- dout, dq, dk, dv,
1451
- !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
1452
- !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
1453
- seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
1454
- seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
1455
- dq_accum.data_ptr(),
1456
- num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
1457
- num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
1458
- softmax_lse.data_ptr(),
1459
- softmax_d.data_ptr(),
1460
- /*p_dropout=*/0.f,
1461
- softmax_scale,
1462
- window_size_left,
1463
- window_size_right,
1464
- softcap,
1465
- deterministic,
1466
- sm_margin);
1467
- params.total_q = total_q;
1468
- params.total_k = total_k;
1469
- params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
1470
- params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now
1471
-
1472
- // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
1473
- // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
1474
- // Will be zero'ed out in the backward preprocess kernel
1475
- at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
1476
- params.dq_semaphore = dq_semaphore.data_ptr<int>();
1477
- if (num_heads_k != num_heads && params.deterministic) {
1478
- // TODO: do we need to zero them out?
1479
- at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1480
- at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1481
- params.dk_semaphore = dk_semaphore.data_ptr<int>();
1482
- params.dv_semaphore = dv_semaphore.data_ptr<int>();
1483
- }
1484
-
1485
- #ifdef FLASHATTENTION_DISABLE_LOCAL
1486
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1487
- #endif
1488
- #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1489
- TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1490
- #endif
1491
-
1492
- if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
1493
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1494
- run_mha_bwd(params, stream);
1495
- } else if (total_k > 0 && num_heads_k > 0) {
1496
- // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1497
- dk.zero_();
1498
- dv.zero_();
1499
- softmax_d.zero_();
1500
- } else if (total_q > 0 && num_heads_k > 0) {
1501
- dq.zero_();
1502
- softmax_d.zero_();
1503
- }
1504
-
1505
- return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
1506
- }
1507
-
1508
- std::vector<at::Tensor>
1509
- mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
1510
- const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads
1511
- std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
1512
- std::optional<at::ScalarType> out_dtype_
1513
- ) {
1514
-
1515
- auto dprops = at::cuda::getCurrentDeviceProperties();
1516
- bool is_sm8x = dprops->major >= 8;
1517
- TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
1518
-
1519
- auto out_partial_type = out_partial.scalar_type();
1520
- TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
1521
- TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
1522
-
1523
- CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
1524
-
1525
- TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1526
- TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
1527
-
1528
- const auto sizes = out_partial.sizes();
1529
-
1530
- const int num_splits = sizes[0];
1531
- const int batch_size = sizes[1];
1532
- const int seqlen = sizes[2];
1533
- const int num_heads = sizes[3];
1534
- const int head_size_og = sizes[4];
1535
- TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
1536
-
1537
- CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
1538
- CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
1539
-
1540
- int const alignment = 4;
1541
- at::Tensor out_partial_padded;
1542
- auto pad = [](at::Tensor x, int alignment) {
1543
- return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
1544
- };
1545
- out_partial_padded = pad(out_partial, alignment);
1546
-
1547
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1548
- const int head_size = round_multiple(head_size_og, alignment);
1549
-
1550
- auto opts = out_partial.options();
1551
- at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
1552
- TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
1553
- at::Tensor out;
1554
- if (out_.has_value()) {
1555
- out = out_.value();
1556
- TORCH_CHECK(out.scalar_type() == out_type);
1557
- CHECK_DEVICE(out);
1558
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1559
- CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
1560
- if (head_size_og % alignment != 0) {
1561
- out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1562
- }
1563
- } else {
1564
- out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1565
- }
1566
-
1567
- // Otherwise the kernel will be launched from cuda:0 device
1568
- // Cast to char to avoid compiler warning about narrowing
1569
- at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
1570
-
1571
- auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
1572
-
1573
- Flash_fwd_params params {}; // Need to reset the params to set everything to zero
1574
- params.is_fp32 = out_type == at::ScalarType::Float;
1575
- params.is_bf16 = out_type == at::ScalarType::BFloat16;
1576
- params.oaccum_ptr = out_partial_padded.data_ptr();
1577
- params.softmax_lseaccum_ptr = lse_partial.data_ptr();
1578
- params.o_ptr = out.data_ptr();
1579
- params.softmax_lse_ptr = softmax_lse.data_ptr();
1580
- params.b = batch_size;
1581
- params.h = num_heads;
1582
- params.seqlen_q = seqlen;
1583
- params.dv = head_size;
1584
- params.num_splits = num_splits;
1585
- params.oaccum_split_stride = out_partial_padded.stride(0);
1586
- params.oaccum_row_stride = out_partial_padded.stride(2);
1587
- params.oaccum_head_stride = out_partial_padded.stride(3);
1588
- params.oaccum_batch_stride = out_partial_padded.stride(1);
1589
- params.lseaccum_split_stride = lse_partial.stride(0);
1590
- params.lseaccum_head_stride = lse_partial.stride(3);
1591
- params.lseaccum_batch_stride = lse_partial.stride(1);
1592
- params.o_row_stride = out.stride(1);
1593
- params.o_head_stride = out.stride(2);
1594
- params.o_batch_stride = out.stride(0);
1595
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1596
-
1597
- if (seqlen > 0 && batch_size > 0) {
1598
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1599
- run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
1600
- }
1601
-
1602
- at::Tensor out_padded = out;
1603
- if (head_size_og % alignment != 0) {
1604
- out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
1605
- // if (out_.has_value()) { out_.value().copy_(out); }
1606
- }
1607
-
1608
- return {out, softmax_lse};
1609
- }
1610
-
1611
- #ifndef FLASHATTENTION_DISABLE_PYBIND
1612
-
1613
- #include <torch/python.h>
1614
-
1615
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1616
- m.doc() = "FlashAttention";
1617
- m.def("fwd", &mha_fwd, "Forward pass");
1618
- m.def("bwd", &mha_bwd, "Backward pass");
1619
- m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
1620
- m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
1621
- }
1622
-
1623
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_kernel_sm80.h DELETED
@@ -1,173 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/kernel_hardware_info.h>
13
-
14
- #include "utils.h"
15
-
16
- namespace flash {
17
-
18
- using namespace cute;
19
-
20
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
21
- class FlashAttnBwdSm80 {
22
-
23
- public:
24
-
25
- // Type Aliases
26
- static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
27
- static constexpr bool Is_local = CollectiveMainloop_::Is_local;
28
- static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
29
- static constexpr bool Varlen = CollectiveMainloop_::Varlen;
30
-
31
- // Mainloop derived types
32
- using CollectiveMainloop = CollectiveMainloop_;
33
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
34
- using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
35
- using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
36
- using ArchTag = typename CollectiveMainloop::ArchTag;
37
- using MainloopArguments = typename CollectiveMainloop::Arguments;
38
- using MainloopParams = typename CollectiveMainloop::Params;
39
- static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
40
-
41
- // Epilogue derived types
42
- using CollectiveEpilogue = CollectiveEpilogue_;
43
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
44
- using EpilogueParams = typename CollectiveEpilogue::Params;
45
-
46
- static_assert(ArchTag::kMinComputeCapability >= 80);
47
-
48
- using TileScheduler = TileScheduler_;
49
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
50
- using TileSchedulerParams = typename TileScheduler::Params;
51
-
52
- static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
53
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
54
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
55
-
56
- // Kernel level shared memory storage
57
- struct SharedStorage {
58
- struct TensorStorage : cute::aligned_struct<128> {
59
- union {
60
- typename CollectiveMainloop::TensorStorage mainloop;
61
- typename CollectiveEpilogue::TensorStorage epilogue;
62
- };
63
- } tensors;
64
-
65
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
66
-
67
- };
68
-
69
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
70
-
71
- // Device side arguments
72
- struct Arguments {
73
- MainloopArguments mainloop{};
74
- EpilogueArguments epilogue{};
75
- cutlass::KernelHardwareInfo hw_info{};
76
- TileSchedulerArguments scheduler{};
77
- };
78
-
79
- // Kernel entry point API
80
- struct Params {
81
- MainloopParams mainloop{};
82
- EpilogueParams epilogue{};
83
- cutlass::KernelHardwareInfo hw_info{};
84
- TileSchedulerParams scheduler{};
85
- };
86
-
87
- //
88
- // Methods
89
- //
90
-
91
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
92
- static
93
- Params
94
- to_underlying_arguments(Arguments const& args) {
95
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
96
-
97
- // Get SM count if needed, otherwise use user supplied SM count
98
- int sm_count = args.hw_info.sm_count;
99
- if (sm_count <= 0) {
100
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
101
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
102
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
103
- }
104
-
105
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
106
-
107
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
108
- return {
109
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
110
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
111
- hw_info,
112
- TileScheduler::to_underlying_arguments(args.scheduler)
113
- };
114
- }
115
-
116
- // Computes the kernel launch grid shape based on runtime parameters
117
- static dim3
118
- get_grid_shape(Params const& params) {
119
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
120
- }
121
-
122
- static dim3
123
- get_block_shape() {
124
- return dim3(MaxThreadsPerBlock, 1, 1);
125
- }
126
-
127
- CUTLASS_DEVICE
128
- void
129
- operator()(Params const& params, char* smem_buf) {
130
-
131
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
132
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
133
-
134
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
135
-
136
- CollectiveMainloop mainloop;
137
- CollectiveEpilogue epilogue;
138
-
139
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
140
- // Initialize matmul objects.
141
- TiledMmadKV tiled_mma_dKV;
142
-
143
- scheduler.init_consumer();
144
-
145
- int warp_idx = cutlass::canonical_warp_idx_sync();
146
- CUTLASS_PRAGMA_NO_UNROLL
147
- for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
148
- work_tile_info.is_valid(params.scheduler);
149
- work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
150
-
151
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
152
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
153
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
154
-
155
- // dK and dV output accumulator.
156
- Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
157
- Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
158
- bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
159
- block_coord, shared_storage);
160
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
161
- if (tile_valid) {
162
- epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
163
- threadIdx.x, block_coord);
164
- } else {
165
- epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
166
- }
167
- }
168
-
169
- }
170
-
171
- };
172
-
173
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_kernel_sm90.h DELETED
@@ -1,282 +0,0 @@
1
-
2
- /******************************************************************************
3
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
4
- ******************************************************************************/
5
-
6
- #pragma once
7
-
8
- #include "cute/tensor.hpp"
9
-
10
- #include <cutlass/cutlass.h>
11
- #include <cutlass/arch/reg_reconfig.h>
12
- #include <cutlass/array.h>
13
- #include <cutlass/numeric_types.h>
14
- #include <cutlass/numeric_conversion.h>
15
- #include <cutlass/kernel_hardware_info.h>
16
- #include "cutlass/pipeline/pipeline.hpp"
17
-
18
- #include "utils.h"
19
-
20
- namespace flash {
21
-
22
- using namespace cute;
23
-
24
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
25
- class FlashAttnBwdSm90 {
26
-
27
- public:
28
-
29
- // Type Aliases
30
- static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
31
- static constexpr bool Is_local = CollectiveMainloop_::Is_local;
32
- static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
33
- static constexpr bool Varlen = CollectiveMainloop_::Varlen;
34
-
35
- // Mainloop derived types
36
- using CollectiveMainloop = CollectiveMainloop_;
37
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
38
- using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
39
- using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
40
- using ArchTag = typename CollectiveMainloop::ArchTag;
41
- using ClusterShape = typename CollectiveMainloop::ClusterShape;
42
- using MainloopArguments = typename CollectiveMainloop::Arguments;
43
- using MainloopParams = typename CollectiveMainloop::Params;
44
- static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
45
-
46
- // Epilogue derived types
47
- using CollectiveEpilogue = CollectiveEpilogue_;
48
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
49
- using EpilogueParams = typename CollectiveEpilogue::Params;
50
-
51
- static_assert(ArchTag::kMinComputeCapability >= 90);
52
-
53
- using TileScheduler = TileScheduler_;
54
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
55
- using TileSchedulerParams = typename TileScheduler::Params;
56
-
57
- static constexpr uint32_t NumLoadWarpGroups = 1;
58
- static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
59
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
60
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
61
- static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
62
-
63
- /// Register requirement for Load and Math WGs
64
- static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
65
- static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
66
- // If you want to print from the producer warp, you'd need to increase the number of registers
67
- // Otherwise you'll get CUDA error.
68
- // static constexpr uint32_t LoadRegisterRequirement = 40;
69
- // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
70
-
71
- // Kernel level shared memory storage
72
- struct SharedStorage {
73
- struct TensorStorage : cute::aligned_struct<128> {
74
- union {
75
- typename CollectiveMainloop::TensorStorage mainloop;
76
- typename CollectiveEpilogue::TensorStorage epilogue;
77
- };
78
- } tensors;
79
-
80
- struct PipelineStorage : cute::aligned_struct<16> {
81
- alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
82
- alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
83
- alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
84
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
- } pipelines;
86
-
87
- };
88
-
89
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
90
-
91
- // Device side arguments
92
- struct Arguments {
93
- MainloopArguments mainloop{};
94
- EpilogueArguments epilogue{};
95
- cutlass::KernelHardwareInfo hw_info{};
96
- TileSchedulerArguments scheduler{};
97
- };
98
-
99
- // Kernel entry point API
100
- struct Params {
101
- MainloopParams mainloop{};
102
- EpilogueParams epilogue{};
103
- cutlass::KernelHardwareInfo hw_info{};
104
- TileSchedulerParams scheduler{};
105
- };
106
-
107
- //
108
- // Methods
109
- //
110
-
111
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
112
- static
113
- Params
114
- to_underlying_arguments(Arguments const& args) {
115
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
116
-
117
- // Get SM count if needed, otherwise use user supplied SM count
118
- int sm_count = args.hw_info.sm_count;
119
- if (sm_count <= 0) {
120
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
121
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
122
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
123
- }
124
-
125
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
126
-
127
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
128
- return {
129
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
130
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
131
- hw_info,
132
- TileScheduler::to_underlying_arguments(args.scheduler)
133
- };
134
- }
135
-
136
- // Computes the kernel launch grid shape based on runtime parameters
137
- static dim3
138
- get_grid_shape(Params const& params) {
139
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
140
- }
141
-
142
- static dim3
143
- get_block_shape() {
144
- return dim3(MaxThreadsPerBlock, 1, 1);
145
- }
146
-
147
- CUTLASS_DEVICE
148
- void
149
- operator()(Params const& params, char* smem_buf) {
150
-
151
- static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
152
- static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
153
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
154
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
155
-
156
- using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
157
- using PipelineParams = typename MainloopPipeline::Params;
158
- using PipelineState = typename MainloopPipeline::PipelineState;
159
- using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
160
- using PipelineParams_dO = typename MainloopPipeline_dO::Params;
161
- using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
162
- static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
163
-
164
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
165
-
166
- int const lane_predicate = cute::elect_one_sync();
167
- int const warp_idx = cutlass::canonical_warp_idx_sync();
168
-
169
- // Issue Tma Descriptor Prefetch from a single thread
170
- if (warp_idx == 0 && lane_predicate) {
171
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
172
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
173
- }
174
-
175
- // Obtain warp index
176
- int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
177
-
178
- PipelineParams pipeline_params;
179
- pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
180
- int warp_group_idx = cutlass::canonical_warp_group_idx();
181
- pipeline_params.role = warp_group_idx == 0
182
- ? MainloopPipeline::ThreadCategory::Producer
183
- : MainloopPipeline::ThreadCategory::Consumer;
184
- pipeline_params.is_leader = warp_group_thread_idx == 0;
185
- pipeline_params.num_consumers = NumMmaThreads;
186
-
187
- if (warp_idx == 0 && lane_predicate) {
188
- shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
189
- }
190
- // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
191
- MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
192
- auto role_dO = warp_group_idx == 0
193
- ? MainloopPipeline_dO::ThreadCategory::Producer
194
- : MainloopPipeline_dO::ThreadCategory::Consumer;
195
- PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
196
- MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
197
-
198
- CollectiveMainloop mainloop;
199
- CollectiveEpilogue epilogue;
200
-
201
- // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
202
- if constexpr (size(ClusterShape{}) > 1) {
203
- cute::cluster_arrive_relaxed();
204
- cute::cluster_wait();
205
- } else {
206
- __syncthreads();
207
- }
208
-
209
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
210
-
211
- if (warp_group_idx == 0) { // Producer
212
- cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
213
-
214
- int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
215
- if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
216
- PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
217
- PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
218
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
219
- work_tile_info.is_valid(params.scheduler);
220
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
221
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
222
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
223
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
224
- auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
225
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
226
- };
227
- mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
228
- smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
229
- }
230
- mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
231
- } else if (warp_idx_in_warpgroup == 1) {
232
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
233
- work_tile_info.is_valid(params.scheduler);
234
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
235
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
236
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
237
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
238
- mainloop.store_dq(params.mainloop, shared_storage, block_coord);
239
- }
240
- }
241
- } else { // Consumer
242
- cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
243
- // Initialize matmul objects.
244
- TiledMmadKV tiled_mma_dKV;
245
-
246
- PipelineState smem_pipe_read;
247
- PipelineState_dO smem_pipe_read_do;
248
-
249
- mainloop.mma_init();
250
- scheduler.init_consumer();
251
-
252
- int work_idx = 0;
253
- CUTLASS_PRAGMA_NO_UNROLL
254
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
255
- work_tile_info.is_valid(params.scheduler);
256
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
257
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
258
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
259
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
260
-
261
- // dK and dV output accumulator.
262
- Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
263
- Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
264
- bool tile_valid = mainloop.mma(
265
- params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
266
- tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
267
- if (tile_valid) {
268
- epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
269
- threadIdx.x - NumCopyThreads, block_coord);
270
- } else {
271
- epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
272
- }
273
-
274
- }
275
- epilogue.store_tail();
276
- }
277
-
278
- }
279
-
280
- };
281
-
282
- } // namespace flash
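
Both the producer and the consumer warp groups in the kernel above drive the same persistent-scheduler loop: each repeatedly asks the `TileScheduler` for the next `(n_block, bidh, bidb)` work tile until no valid tile remains. As a point of reference only, here is a minimal generic sketch of that loop shape in plain CUDA, with a hypothetical global tile counter standing in for the `TileScheduler`; all names and sizes below are illustrative and not part of the deleted sources.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical work descriptor; the real kernel instead decodes
// (n_block, bidh, bidb) from the scheduler's work_tile_info.
struct WorkTile { int idx; bool valid; };

// Grab the next tile index for the whole block via a global counter.
__device__ WorkTile get_next_work(int* counter, int num_tiles) {
    __shared__ int next;
    __syncthreads();                      // previous iteration is done reading `next`
    if (threadIdx.x == 0) { next = atomicAdd(counter, 1); }
    __syncthreads();                      // broadcast the new tile index to the block
    return {next, next < num_tiles};
}

// Each block loops until the global queue is drained, mirroring the
// get_initial_work / get_next_work loops in the kernel above.
__global__ void persistent_loop(int* counter, float* out, int num_tiles) {
    for (WorkTile tile = get_next_work(counter, num_tiles); tile.valid;
         tile = get_next_work(counter, num_tiles)) {
        if (threadIdx.x == 0) { out[tile.idx] = static_cast<float>(tile.idx); }
    }
}

int main() {
    const int num_tiles = 1024;
    int* counter = nullptr;
    float* out = nullptr;
    cudaMalloc(&counter, sizeof(int));
    cudaMalloc(&out, num_tiles * sizeof(float));
    cudaMemset(counter, 0, sizeof(int));
    persistent_loop<<<8, 128>>>(counter, out, num_tiles);
    cudaDeviceSynchronize();
    printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(counter);
    cudaFree(out);
    return 0;
}
```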
flash-attn/flash_bwd_launch_template.h DELETED
@@ -1,377 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/device_kernel.h" // For device_kernel
10
- #include "cutlass/kernel_launch.h" // For kernel_launch
11
- #include "cutlass/cluster_launch.hpp" // For ClusterLauncher
12
-
13
- #include "static_switch.h"
14
- #include "flash.h"
15
- #include "flash_bwd_preprocess_kernel.h"
16
- #include "flash_bwd_postprocess_kernel.h"
17
- #include "tile_scheduler.hpp"
18
- #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
19
- #include "mainloop_bwd_sm80.hpp"
20
- #include "epilogue_bwd.hpp"
21
- #include "flash_bwd_kernel_sm90.h"
22
- #include "flash_bwd_kernel_sm80.h"
23
-
24
- using namespace cute;
25
-
26
- template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
27
- bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
28
- int Stages_dO=2, int Stages_dS_or_QSm80=2,
29
- bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
30
- int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
31
- bool V_in_regs=false>
32
- void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
33
- static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
34
- using ElementAccum = float;
35
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
36
-
37
- int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
38
- int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
39
- bool const is_varlen_q = params.cu_seqlens_q;
40
- bool const is_varlen_k = params.cu_seqlens_k;
41
- int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
42
- int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
43
- int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
44
- int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
45
- int batch_q = !is_varlen_q ? params.b : 1;
46
- int batch_k = !is_varlen_k ? params.b : 1;
47
-
48
- using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
49
- using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
50
- typename PreprocessKernel::Arguments preprocess_args {
51
- static_cast<Element const*>(params.o_ptr),
52
- {seqlen_q, params.d, params.h, batch_q}, // shape_O
53
- {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
54
- static_cast<Element const*>(params.do_ptr),
55
- {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
56
- static_cast<float*>(params.dsoftmax_sum),
57
- {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
58
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
59
- static_cast<float*>(params.softmax_lse_ptr),
60
- {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
61
- static_cast<float*>(params.softmax_lse_log2_ptr),
62
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
63
- static_cast<ElementAccum*>(params.dq_accum_ptr),
64
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
65
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
66
- params.b,
67
- params.dq_semaphore,
68
- params.cu_seqlens_q,
69
- params.seqused_q
70
- };
71
- typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
72
- int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
73
- dim3 grid_m(num_m_block, params.h, params.b);
74
- cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
75
- CHECK_CUDA_KERNEL_LAUNCH();
76
-
77
- using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
78
- using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Clusters are not currently supported
79
- // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
80
- static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
81
- static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
82
- using CollectiveMainloop = std::conditional_t<
83
- Arch >= 90,
84
- flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
85
- Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
86
- SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
87
- flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
88
- Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
89
- SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
90
- >;
91
- using CollectiveEpilogue = std::conditional_t<
92
- !GQA,
93
- flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
94
- flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
95
- >;
96
- using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
97
- using AttnKernel = std::conditional_t<
98
- Arch >= 90,
99
- flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
100
- flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
101
- >;
102
-
103
- typename CollectiveMainloop::Arguments mainloop_args {
104
- static_cast<Element const*>(params.q_ptr),
105
- {seqlen_q, params.d, params.h, batch_q}, // shape_Q
106
- {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
107
- static_cast<Element const*>(params.k_ptr),
108
- {seqlen_k, params.d, params.h_k, batch_k}, // shape_K
109
- {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
110
- static_cast<Element const*>(params.v_ptr),
111
- {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
112
- static_cast<Element const*>(params.do_ptr),
113
- {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
114
- static_cast<ElementAccum*>(params.dq_accum_ptr),
115
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
116
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
117
- static_cast<float*>(params.softmax_lse_log2_ptr),
118
- {seqlen_q_rounded, params.h, batch_q}, // shape_LSE
119
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
120
- static_cast<float*>(params.dsoftmax_sum),
121
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
122
- params.scale_softmax,
123
- params.window_size_left, params.window_size_right,
124
- params.softcap,
125
- params.b,
126
- params.dq_semaphore,
127
- params.cu_seqlens_q, params.cu_seqlens_k,
128
- params.seqused_q, params.seqused_k
129
- };
130
- // Handling the GQA case here is awkward, but a cleaner structure is not obvious.
131
- typename CollectiveEpilogue::Arguments epilogue_args {
132
- static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
133
- [&] {
134
- if constexpr (!GQA) {
135
- return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
136
- } else {
137
- return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
138
- }
139
- }(),
140
- [&] {
141
- if constexpr (!GQA) {
142
- return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
143
- } else {
144
- return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
145
- }
146
- }(),
147
- static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
148
- [&] {
149
- if constexpr (!GQA) {
150
- return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
151
- } else {
152
- return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
153
- }
154
- }(),
155
- params.h,
156
- params.dk_semaphore,
157
- params.dv_semaphore,
158
- params.cu_seqlens_k,
159
- params.seqused_k,
160
- };
161
-
162
- int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
163
- num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
164
- typename flash::TileSchedulerArguments scheduler_args {
165
- num_blocks_n, params.h, params.b, 1 /*num_splits*/,
166
- params.h / params.h_k,
167
- params.seqlen_k,
168
- params.seqlen_q, params.d, params.dv, sizeof(Element),
169
- params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
170
- };
171
-
172
- int device;
173
- cudaGetDevice(&device);
174
- typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
175
- mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
176
- });
177
-
178
- dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
179
- dim3 block_dims = AttnKernel::get_block_shape();
180
- int smem_size = AttnKernel::SharedStorageSize;
181
- // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
182
- // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
183
- // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
184
- // int smem_size_dqacc = [&] {
185
- // if constexpr (Arch >= 90) {
186
- // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
187
- // } else {
188
- // return 0;
189
- // }
190
- // }();
191
- // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
192
- // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
193
- // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
194
- // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
195
- // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
196
- if constexpr (size(ClusterShape{}) > 1) {
197
- void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
198
- if (smem_size >= 48 * 1024) {
199
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
200
- }
201
- dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
202
- cutlass::ClusterLauncher::launch(
203
- grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
204
- } else {
205
- if (smem_size >= 48 * 1024) {
206
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
207
- }
208
- cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
209
- }
210
- CHECK_CUDA_KERNEL_LAUNCH();
211
-
212
- using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
213
- AttnKernel::CollectiveMainloop::NumMmaThreads,
214
- typename AttnKernel::CollectiveMainloop::TiledMmadQ,
215
- AttnKernel::CollectiveMainloop::dQ_swapAB
216
- >;
217
- typename PostprocessKernel::Arguments postprocess_args {
218
- static_cast<ElementAccum const*>(params.dq_accum_ptr),
219
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
220
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
221
- static_cast<Element*>(params.dq_ptr),
222
- {seqlen_q, params.d, params.h, batch_q}, // shape_dQ
223
- {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
224
- params.scale_softmax,
225
- params.cu_seqlens_q,
226
- params.seqused_q
227
- };
228
- typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
229
- int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
230
- dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
231
- int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
232
- if (smem_size_postprocess >= 48 * 1024) {
233
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
234
- }
235
- cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
236
- CHECK_CUDA_KERNEL_LAUNCH();
237
-
238
- if constexpr (GQA) {
239
- using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
240
- using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
241
- AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
242
- typename AttnKernel::CollectiveMainloop::TiledMmadKV,
243
- AttnKernel::CollectiveMainloop::dKV_swapAB
244
- >;
245
- typename PostprocessKerneldKV::Arguments postprocess_dK_args {
246
- static_cast<ElementAccum const*>(params.dk_accum_ptr),
247
- {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
248
- {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
249
- static_cast<Element*>(params.dk_ptr),
250
- {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
251
- {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
252
- 1.f,
253
- params.cu_seqlens_k,
254
- params.seqused_k
255
- };
256
- typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
257
- typename PostprocessKerneldKV::Arguments postprocess_dV_args {
258
- static_cast<ElementAccum const*>(params.dv_accum_ptr),
259
- {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum
260
- {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
261
- static_cast<Element*>(params.dv_ptr),
262
- {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV
263
- {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
264
- 1.f,
265
- params.cu_seqlens_k,
266
- params.seqused_k
267
- };
268
- typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
269
- int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
270
- dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
271
- int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
272
- if (smem_size_postprocess >= 48 * 1024) {
273
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
274
- }
275
- cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
276
- CHECK_CUDA_KERNEL_LAUNCH();
277
- cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
278
- CHECK_CUDA_KERNEL_LAUNCH();
279
- }
280
-
281
- }
282
-
283
- template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
284
- int Stages_dO=2, int Stages_dS_or_QSm80=2,
285
- bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
286
- int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
287
- bool V_in_regs=false>
288
- void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
289
- VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
290
- BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
291
- // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
292
- // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
293
- run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
294
- // });
295
- });
296
- });
297
- }
298
-
299
-
300
- template<int Arch, typename T, bool Has_softcap>
301
- void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
302
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
303
- if constexpr (Arch >= 90) {
304
- if constexpr (Is_causal && Has_softcap) {
305
- // register spill with 128 x 128
306
- run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
307
- } else {
308
- // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
309
- run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
310
- }
311
- } else if constexpr (Arch == 86 || Arch == 89) {
312
- run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
313
- // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
314
- // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
315
- // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
316
- } else {
317
- run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
318
- }
319
- });
320
- }
321
-
322
- template<int Arch, typename T, bool Has_softcap>
323
- void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
324
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
325
- if constexpr (Arch >= 90) {
326
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
327
- } else if constexpr (Arch == 86 || Arch == 89) {
328
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
329
- } else {
330
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
331
- }
332
- });
333
- }
334
-
335
- template<int Arch, typename T, bool Has_softcap>
336
- void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
337
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
338
- if constexpr (Arch >= 90) {
339
- if constexpr (Is_causal || Is_local || Has_softcap) {
340
- run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
341
- } else {
342
- run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
343
- }
344
- } else if constexpr (Arch == 86 || Arch == 89) {
345
- run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
346
- } else {
347
- run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
348
- }
349
- });
350
- }
351
-
352
- template<int Arch, typename T, bool Has_softcap>
353
- void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
354
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
355
- if constexpr (Arch >= 90) {
356
- run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
357
- } else if constexpr (Arch == 86 || Arch == 89) {
358
- run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
359
- } else {
360
- run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
361
- }
362
- });
363
- }
364
-
365
- template<int Arch, typename T, bool Has_softcap>
366
- void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
367
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
368
- if constexpr (Arch >= 90) {
369
- run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
370
- } else if constexpr (Arch == 86 || Arch == 89) {
371
- run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
372
- // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
373
- } else {
374
- run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
375
- }
376
- });
377
- }
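
One detail of `run_flash_bwd` worth calling out: CUDA caps dynamic shared memory at 48 KB per block unless the kernel opts in through `cudaFuncSetAttribute`, which is why the attention and postprocess launches above are guarded by the `smem_size >= 48 * 1024` check before `cutlass::kernel_launch` or `ClusterLauncher::launch`. A standalone sketch of that opt-in pattern follows; the toy kernel and the 96 KB figure are purely illustrative, not what this launcher uses.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel that uses a large dynamic shared memory buffer.
__global__ void big_smem_kernel(float* out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[blockIdx.x] = smem[0]; }
}

int main() {
    // Request 96 KB of dynamic shared memory, above the 48 KB default cap.
    int smem_size = 96 * 1024;
    if (smem_size >= 48 * 1024) {
        // Same opt-in the deleted launcher performs for its kernels;
        // without it, a launch asking for this much smem fails.
        cudaFuncSetAttribute(big_smem_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    float* out = nullptr;
    cudaMalloc(&out, 4 * sizeof(float));
    big_smem_kernel<<<4, 128, smem_size>>>(out);
    cudaDeviceSynchronize();
    printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}
```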
flash-attn/flash_bwd_postprocess_kernel.h DELETED
@@ -1,256 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/numeric_conversion.h>
13
- #include "cutlass/arch/barrier.h"
14
-
15
- #include "seqlen.h"
16
- #include "utils.h"
17
-
18
- namespace flash {
19
-
20
- using namespace cute;
21
-
22
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
23
- class FlashAttnBwdPostprocessConvertdQ {
24
-
25
- public:
26
-
27
- // Type Aliases
28
- using TileShape_MK = TileShape_MK_;
29
- using ArchTag = ArchTag_;
30
-
31
- static_assert(ArchTag::kMinComputeCapability >= 75);
32
- static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
33
-
34
- static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
35
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
-
37
- static constexpr int kBlockM = get<0>(TileShape_MK{});
38
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
39
- static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
40
- static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
41
- using R2SLayoutAtomdQaccum = std::conditional_t<
42
- IsSm90,
43
- Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
44
- Layout<Shape<Int<kNThreads>>>
45
- >;
46
- using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
47
- Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
48
- using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
49
- // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
50
- using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
51
- Layout<Shape<_4>>{})); // Val layout, 4 vals per read
52
- // We don't do bound checking for the gmem -> smem load so we just assert here.
53
- static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
54
- static constexpr int SmemdQaccumSize = size(TileShape_MK{});
55
- using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
56
- using SmemLayoutdQaccum = std::conditional_t<
57
- IsSm90,
58
- Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
59
- Layout<Shape<Int<kBlockM * kHeadDim>>>
60
- >;
61
-
62
- // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
63
- // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
64
- // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
65
- static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
66
- static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
67
- static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
68
- using SmemLayoutAtomdQ =
69
- decltype(composition(Swizzle<kSwizzle, 3, 3>{},
70
- Layout<Shape<Int<8>, Int<kBlockKSmem>>,
71
- Stride<Int<kBlockKSmem>, _1>>{}));
72
- using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
73
- using SmemLayoutdQt =
74
- decltype(cute::composition(SmemLayoutdQ{},
75
- make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
76
- make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
77
-
78
- using SmemCopyAtomdQ = Copy_Atom<
79
- std::conditional_t<
80
- IsSm90,
81
- std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
82
- AutoVectorizingCopyWithAssumedAlignment<128>
83
- >,
84
- Element>;
85
-
86
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
87
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
88
- static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
89
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
90
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
91
- Stride<Int<kGmemThreadsPerRow>, _1>>;
92
- using GmemTiledCopy = decltype(
93
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
94
- GmemLayoutAtom{},
95
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
96
-
97
- struct SharedStorage : cute::aligned_struct<128> {
98
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
99
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
100
- alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
101
- };
102
-
103
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
104
-
105
- using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
106
- using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
107
- using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
108
- using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
109
-
110
- // Device side arguments
111
- struct Arguments {
112
- ElementAccum const* ptr_dQaccum;
113
- ShapedQaccum const shape_dQaccum;
114
- StridedQaccum const stride_dQaccum;
115
- Element* ptr_dQ;
116
- ShapedQ const shape_dQ;
117
- StridedQ const stride_dQ;
118
- float const softmax_scale;
119
- int const* cu_seqlens = nullptr;
120
- int const* seqused = nullptr;
121
- };
122
-
123
- // Kernel entry point API
124
- struct Params {
125
- ElementAccum const* ptr_dQaccum;
126
- ShapedQaccum const shape_dQaccum;
127
- StridedQaccum const stride_dQaccum;
128
- Element* ptr_dQ;
129
- ShapedQ const shape_dQ;
130
- StridedQ const stride_dQ;
131
- float const softmax_scale;
132
- int const* cu_seqlens = nullptr;
133
- int const* seqused = nullptr;
134
- };
135
-
136
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
137
- static
138
- Params
139
- to_underlying_arguments(Arguments const& args) {
140
- return {
141
- args.ptr_dQaccum,
142
- args.shape_dQaccum,
143
- args.stride_dQaccum,
144
- args.ptr_dQ,
145
- args.shape_dQ,
146
- args.stride_dQ,
147
- args.softmax_scale,
148
- args.cu_seqlens,
149
- args.seqused
150
- };
151
- }
152
-
153
- CUTLASS_DEVICE
154
- void
155
- operator()(Params const& params, char* smem_buf) {
156
-
157
- static constexpr int kBlockM = get<0>(TileShape_MK{});
158
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
159
-
160
- Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
161
- Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
162
- Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
163
- Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
164
-
165
- int const thread_idx = threadIdx.x;
166
- int const m_block = blockIdx.x;
167
- int const bidh = blockIdx.y;
168
- int const bidb = blockIdx.z;
169
-
170
- flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
171
- bool const is_varlen = params.cu_seqlens;
172
- if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
173
-
174
- // Step 1: load dQaccum from gmem to smem
175
- Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
176
- params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
177
- Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
178
- if constexpr (IsSm90) { // Use BulkCopy
179
- static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
180
- auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
181
- // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
182
- if (thread_idx == 0) {
183
- shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
184
- shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
185
- copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
186
- }
187
- __syncthreads();
188
- shared_storage.barrier_dQaccum.wait(0);
189
- } else {
190
- G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
191
- auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
192
- Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
193
- Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
194
- cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
195
- __syncthreads();
196
- }
197
-
198
- // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
199
-
200
- // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
201
- R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
202
- auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
203
- Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
204
- TiledMma tiled_mma_dQ;
205
- Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
206
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
207
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
208
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
209
- CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
210
- Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
211
- cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
212
- #pragma unroll
213
- for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
214
- // Convert tdQrdQ from fp32 to fp16
215
- Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
216
- flash::convert_type_out(taccdQrdQaccum, rdQ);
217
-
218
- // Step 3: Copy dQ from register to smem
219
- auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
220
- auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
221
- Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
222
- // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
223
- // if (cute::thread0()) { print(smem_thr_copy_dQ); }
224
- // if (cute::thread0()) { print(sdQ); }
225
- Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
226
- cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
227
- __syncthreads();
228
-
229
- // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
230
- Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
231
- Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
232
- GmemTiledCopy gmem_tiled_copy_dQ;
233
- auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
234
- Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
235
- Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
236
-
237
- Tensor tdQrdQ = make_fragment_like(tdQsdQ);
238
- Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
239
- Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
240
- #pragma unroll
241
- for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
242
- // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
243
- static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
244
- flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
245
- gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
246
-
247
- // Step 5: Copy dQ from register to gmem
248
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
249
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
250
- gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
251
- );
252
- }
253
-
254
- };
255
-
256
- } // namespace flash
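
Stripped of the CuTe tiling and smem staging, the postprocess kernel above does one simple thing per element: read the fp32 `dQaccum` value that the backward mainloop accumulated, multiply it by the softmax scale (1.0 when the same kernel is reused for the GQA dK/dV accumulators), and write it back out narrowed to the IO dtype. A scalar reference of that step, with illustrative names and plain `float` standing in for the fp16/bf16 output type:

```cuda
#include <cstdio>
#include <vector>

// Scalar reference for the dQ postprocess, with illustrative names: the real
// kernel performs the same scale-and-narrow per element, but staged through
// shared memory in a layout matched to the MMA accumulator.
void convert_dq_accum(const std::vector<float>& dq_accum,  // fp32 accumulator, seqlen * headdim
                      std::vector<float>& dq_out,          // fp16/bf16 in the real kernel
                      float softmax_scale) {               // 1.0f when reused for GQA dK/dV
    for (size_t i = 0; i < dq_accum.size(); ++i) {
        dq_out[i] = softmax_scale * dq_accum[i];           // scale, then convert type
    }
}

int main() {
    std::vector<float> acc = {1.f, 2.f, 4.f};
    std::vector<float> dq(acc.size());
    convert_dq_accum(acc, dq, 0.125f);
    printf("%g %g %g\n", dq[0], dq[1], dq[2]);
    return 0;
}
```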
flash-attn/flash_bwd_preprocess_kernel.h DELETED
@@ -1,252 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/numeric_conversion.h>
13
-
14
- #include "seqlen.h"
15
- #include "utils.h"
16
-
17
- namespace flash {
18
-
19
- using namespace cute;
20
-
21
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
22
- class FlashAttnBwdPreprocess {
23
-
24
- public:
25
-
26
- // Type Aliases
27
- using TileShape_MK = TileShape_MK_;
28
- using ArchTag = ArchTag_;
29
-
30
- static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
31
- std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
32
- std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
33
-
34
- static constexpr uint32_t MaxThreadsPerBlock = 256;
35
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
- static constexpr int SharedStorageSize = 0;
37
-
38
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
- static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
- static constexpr int kBlockM = get<0>(TileShape_MK{});
41
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
42
- // We want kBlockKGmem to be a power of 2 so that when we do the summing,
43
- // it's just between threads in the same warp
44
- static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
45
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
46
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
47
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
48
- Stride<Int<kGmemThreadsPerRow>, _1>>;
49
- using GmemTiledCopy = decltype(
50
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
51
- GmemLayoutAtom{},
52
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
53
-
54
- static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
55
- static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
56
- using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
57
- using GmemTiledCopyAccum = decltype(
58
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
59
- GmemLayoutAtomAccum{},
60
- Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
61
-
62
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
63
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
64
- using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
65
- using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
66
- using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
67
- using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
68
-
69
- // Device side arguments
70
- struct Arguments {
71
- Element const* ptr_O;
72
- ShapeO const shape_O;
73
- StrideO const stride_O;
74
- Element const* ptr_dO;
75
- StrideO const stride_dO;
76
- float* ptr_dPsum;
77
- ShapedPsum const shape_dPsum;
78
- StridedPsum const stride_dPsum;
79
- float const* ptr_LSE;
80
- StridedPsum const stride_LSE;
81
- float *ptr_LSE_log2;
82
- StridedPsum const stride_LSE_log2;
83
- ElementAccum* ptr_dQaccum;
84
- ShapedQaccum const shape_dQaccum;
85
- StridedQaccum const stride_dQaccum;
86
- int num_batch; // We need this to know the size of dq_semaphore in case of varlen
87
- int* dq_semaphore;
88
- int const* cu_seqlens = nullptr;
89
- int const* seqused = nullptr;
90
- };
91
-
92
- // Kernel entry point API
93
- struct Params {
94
- Element const* ptr_O;
95
- ShapeO const shape_O;
96
- StrideO const stride_O;
97
- Element const* ptr_dO;
98
- StrideO const stride_dO;
99
- float* ptr_dPsum;
100
- ShapedPsum const shape_dPsum;
101
- StridedPsum const stride_dPsum;
102
- float const* ptr_LSE;
103
- StridedPsum const stride_LSE;
104
- float* ptr_LSE_log2;
105
- StridedPsum const stride_LSE_log2;
106
- ElementAccum* ptr_dQaccum;
107
- ShapedQaccum const shape_dQaccum;
108
- StridedQaccum const stride_dQaccum;
109
- int num_batch;
110
- int* dq_semaphore;
111
- int const* cu_seqlens = nullptr;
112
- int const* seqused = nullptr;
113
- };
114
-
115
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
116
- static
117
- Params
118
- to_underlying_arguments(Arguments const& args) {
119
- return {
120
- args.ptr_O,
121
- args.shape_O,
122
- args.stride_O,
123
- args.ptr_dO,
124
- args.stride_dO,
125
- args.ptr_dPsum,
126
- args.shape_dPsum,
127
- args.stride_dPsum,
128
- args.ptr_LSE,
129
- args.stride_LSE,
130
- args.ptr_LSE_log2,
131
- args.stride_LSE_log2,
132
- args.ptr_dQaccum,
133
- args.shape_dQaccum,
134
- args.stride_dQaccum,
135
- args.num_batch,
136
- args.dq_semaphore,
137
- args.cu_seqlens,
138
- args.seqused
139
- };
140
- }
141
-
142
- CUTLASS_DEVICE
143
- void
144
- operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
145
-
146
- static constexpr int kBlockM = get<0>(TileShape_MK{});
147
-
148
- int const thread_idx = threadIdx.x;
149
- int const m_block = blockIdx.x;
150
- int const bidh = blockIdx.y;
151
- int const bidb = blockIdx.z;
152
-
153
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
154
- bool const is_varlen = Varlen && params.cu_seqlens;
155
- int const seqlen_o = seqlen_info.seqlen;
156
- if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
157
-
158
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
159
- Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
160
- Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
161
- Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
162
-
163
- auto shape_LSE = select<0, 2, 3>(params.shape_O);
164
- Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
165
- Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
166
- static_assert(kBlockM <= MaxThreadsPerBlock);
167
- float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
168
-
169
- GmemTiledCopy gmem_tiled_copy_O;
170
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
171
-
172
- Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
173
- Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
174
- // Construct identity layout for gO
175
- Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
176
- // Repeat the partitioning with identity layouts
177
- Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
178
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
179
- #pragma unroll
180
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
181
-
182
- // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
183
- Tensor tOrO = make_fragment_like(tOgO);
184
- Tensor tOrdO = make_fragment_like(tOgdO);
185
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
186
- gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
187
- );
188
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
189
- gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
190
- );
191
- // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
192
-
193
- // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
194
- Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
195
- Tensor tOrO_l = make_tensor(tOrO.data(), l);
196
- Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
197
- flash::convert_type_out(tOrO_l, o_fp32);
198
- Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
199
- Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
200
- flash::convert_type_out(tOrdO_l, do_fp32);
201
- // Sum across the last dimension
202
- Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
203
- #pragma unroll
204
- for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
205
- float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
206
- #pragma unroll
207
- for (int ni = 1; ni < size<1>(o_fp32); ni++) {
208
- dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
209
- }
210
- flash::SumOp<float> sum_op;
211
- dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
212
- }
213
-
214
- Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
215
- Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
216
- if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
217
- #pragma unroll
218
- for (int mi = 0; mi < size(dP_sum); ++mi) {
219
- int const row = get<0>(tOcO(_0{}, mi, _0{}));
220
- gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
221
- }
222
- }
223
-
224
- int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
225
- Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
226
- Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
227
- if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
228
- gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
229
- }
230
-
231
- if constexpr (Clear_dQaccum) {
232
- Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
233
- Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
234
- GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
235
- auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
236
- Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
237
- Tensor zero = make_fragment_like(tdQgdQaccum);
238
- clear(zero);
239
- cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
240
- }
241
-
242
- if (params.dq_semaphore != nullptr && thread_idx == 0) {
243
- int const num_batch = params.num_batch;
244
- int const num_head = get<2>(params.shape_O);
245
- params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
246
- }
247
-
248
- }
249
-
250
- };
251
-
252
- } // namespace flash
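
The preprocess kernel above likewise reduces to a small amount of per-row math once the tiling is removed: it computes `dPsum[i] = sum_j dO[i,j] * O[i,j]`, rescales the stored LSE by `log2(e)` so the mainloop can use `exp2f` (mapping `-inf` rows to 0), and optionally zero-fills `dQaccum` and the dq semaphores. A scalar reference for one head, with illustrative names, is sketched below.

```cuda
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the backward preprocess of one (seqlen x headdim) head.
// Illustrative sketch: the real kernel works on 64/128-row tiles with
// vectorized gmem loads and a warp allreduce for the row sums.
void bwd_preprocess(const std::vector<float>& O, const std::vector<float>& dO,
                    const std::vector<float>& LSE, int seqlen, int headdim,
                    std::vector<float>& dPsum, std::vector<float>& LSE_log2,
                    std::vector<float>& dQaccum) {
    const float kLog2e = 1.4426950408889634f;  // log2(e), i.e. M_LOG2E
    for (int i = 0; i < seqlen; ++i) {
        float s = 0.f;
        for (int j = 0; j < headdim; ++j) { s += dO[i * headdim + j] * O[i * headdim + j]; }
        dPsum[i] = s;                                               // rowsum(dO * O)
        float lse = LSE[i];
        LSE_log2[i] = lse == -INFINITY ? 0.f : lse * kLog2e;        // pre-scaled for exp2f
    }
    std::fill(dQaccum.begin(), dQaccum.end(), 0.f);                 // Clear_dQaccum=true path
}

int main() {
    int seqlen = 2, headdim = 4;
    std::vector<float> O(seqlen * headdim, 1.f), dO(seqlen * headdim, 0.5f);
    std::vector<float> LSE = {3.f, -INFINITY};
    std::vector<float> dPsum(seqlen), LSE_log2(seqlen), dQaccum(seqlen * headdim, 7.f);
    bwd_preprocess(O, dO, LSE, seqlen, headdim, dPsum, LSE_log2, dQaccum);
    printf("dPsum[0]=%g LSE_log2[0]=%g LSE_log2[1]=%g\n", dPsum[0], LSE_log2[0], LSE_log2[1]);
    return 0;
}
```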
flash-attn/flash_fwd_combine.cu DELETED
@@ -1,13 +0,0 @@
1
- // Copyright (c) 2024, Tri Dao.
2
- // Splitting the different head dimensions to different files to speed up compilation.
3
-
4
- #include "flash_fwd_combine_launch_template.h"
5
-
6
- template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
7
- template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
8
-
9
- template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
10
- template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
11
-
12
- template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
13
- template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
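
This file is a stub translation unit: `run_mha_fwd_combine_` is a template, and each concrete (output dtype, partial dtype, tile size) combination is explicitly instantiated so the heavy template code is compiled once per configuration, with related groups split across .cu files to keep per-file build times down, as the header comment notes. A self-contained sketch of the explicit-instantiation idiom, with a made-up template for illustration:

```cuda
#include <cstdio>

// Stand-in for a heavy templated launcher like run_mha_fwd_combine_<T, Tpartial, k>.
// The name and signature here are made up for illustration.
template <typename T, int kTile>
void run_combine(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i) { out[i] = in[i]; }   // real code would launch a CUDA kernel
    printf("ran combine, kTile=%d\n", kTile);
}

// Explicit instantiations: each line forces the template to be compiled in this
// translation unit. In the deleted sources such lines are spread over several
// small .cu files so the configurations build in parallel.
template void run_combine<float, 64>(const float*, float*, int);
template void run_combine<float, 128>(const float*, float*, int);

int main() {
    float in[4] = {1.f, 2.f, 3.f, 4.f};
    float out[4];
    run_combine<float, 128>(in, out, 4);
    printf("out[3]=%g\n", out[3]);
    return 0;
}
```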
flash-attn/flash_fwd_combine_kernel.h DELETED
@@ -1,702 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/arch/memory.h>
11
- #include <cutlass/array.h>
12
- #include <cutlass/numeric_types.h>
13
- #include <cutlass/numeric_conversion.h>
14
-
15
- #include "cutlass/arch/grid_dependency_control.h"
16
-
17
- #include "seqlen.h"
18
- #include "utils.h"
19
-
20
- namespace flash {
21
-
22
- using namespace cute;
23
-
24
- template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
25
- bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
26
- class FlashAttnFwdCombine {
27
-
28
- public:
29
-
30
- // Type Aliases
31
- using TileShape_MK = TileShape_MK_;
32
- using ArchTag = ArchTag_;
33
- static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
34
- static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
35
- static_assert(AlignmentLSE >= 1);
36
- static constexpr int kStages = 4;
37
-
38
- static_assert(ArchTag::kMinComputeCapability >= 75);
39
- static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
40
-
41
- static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
42
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
43
-
44
- static constexpr int kBlockM = get<0>(TileShape_MK{});
45
- static constexpr int kBlockK = get<1>(TileShape_MK{});
46
-
47
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
48
- static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
49
- static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
50
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
51
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
52
- using GmemCopyAtom = std::conditional_t<
53
- Has_cp_async,
54
- cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
55
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
56
- >;
57
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
58
- Stride<Int<kGmemThreadsPerRow>, _1>>;
59
- static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
60
- using GmemTiledCopyAccum = decltype(
61
- make_tiled_copy(GmemCopyAtom{},
62
- GmemLayoutAtom{},
63
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
64
- using GmemTiledCopy = decltype(
65
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
66
- GmemLayoutAtom{},
67
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
68
-
69
- using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
70
- static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
71
- static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
72
- static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
73
- static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
74
- static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
75
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
76
- using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
77
- Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
78
- static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
79
- using GmemCopyAtomLSE = std::conditional_t<
80
- Has_cp_async,
81
- cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
82
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
83
- >;
84
- using GmemTiledCopyLSE = decltype(
85
- make_tiled_copy(GmemCopyAtomLSE{},
86
- GmemLayoutAtomLSE{},
87
- Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
88
-
89
- // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
90
- static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
91
- // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
92
- using SmemLSESwizzle = std::conditional_t<
93
- kBlockMSmem == 8,
94
- Swizzle<5, 0, 5>,
95
- std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
96
- >;
97
- using SmemLayoutAtomLSE =
98
- decltype(composition(SmemLSESwizzle{},
99
- Layout<Shape<Int<8>, Int<kBlockMSmem>>,
100
- Stride<Int<kBlockMSmem>, _1>>{}));
101
- using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
102
-
103
- using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
104
- Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
105
-
106
- // We want each column (kMaxSplits) to be processed by threads in the same warp.
107
- // To reduce the number of shuffles, we want as few threads on the same column as possible.
108
- // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 3) per column
109
- // and have 64 such quads.
110
- static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
111
- static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
112
- static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
113
- using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
114
- using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
115
-
116
- using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
117
- using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
118
- using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
119
- using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
120
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
121
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
122
- using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
123
- using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
124
-
125
- struct BlockCoord {
126
- int block_m;
127
- int block_k;
128
- int bidb;
129
- };
130
-
131
- struct SharedStorage : cute::aligned_struct<128> {
132
- cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
133
- cute::array_aligned<int, kBlockM> smem_max_valid_split;
134
- cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
135
- BlockCoord block_coord;
136
- };
137
-
138
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
139
-
140
- // Device side arguments
141
- struct Arguments {
142
- int b;
143
- ElementPartial const* const ptr_O_partial;
144
- ShapeOPartial const shape_O_partial;
145
- StrideOPartial const stride_O_partial;
146
- float const* const ptr_LSE_partial;
147
- ShapeLSEPartial const shape_LSE_partial;
148
- StrideLSEPartial const stride_LSE_partial;
149
- Element* const ptr_O;
150
- StrideO const stride_O;
151
- float* const ptr_LSE;
152
- StrideLSE const stride_LSE;
153
- int const* const cu_seqlens = nullptr;
154
- int const* const seqused = nullptr;
155
- int const* const num_splits_dynamic_ptr = nullptr;
156
- int* const semaphore_to_reset = nullptr;
157
- };
158
-
159
- // Kernel entry point API
160
- struct CollectiveParams {
161
- int b;
162
- ElementPartial const* const ptr_O_partial;
163
- ShapeOPartial const shape_O_partial;
164
- StrideOPartial const stride_O_partial;
165
- float const* const ptr_LSE_partial;
166
- ShapeLSEPartial const shape_LSE_partial;
167
- StrideLSEPartial const stride_LSE_partial;
168
- Element* const ptr_O;
169
- StrideO const stride_O;
170
- float* const ptr_LSE;
171
- StrideLSE const stride_LSE;
172
- cutlass::FastDivmod seqlen_divmod, head_divmod;
173
- int const* const cu_seqlens = nullptr;
174
- int const* const seqused = nullptr;
175
- int const* const num_splits_dynamic_ptr = nullptr;
176
- int* const semaphore_to_reset = nullptr;
177
- };
178
-
179
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
180
- static
181
- CollectiveParams
182
- to_underlying_arguments(Arguments const& args) {
183
- assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
184
- return {
185
- args.b,
186
- args.ptr_O_partial,
187
- args.shape_O_partial,
188
- args.stride_O_partial,
189
- args.ptr_LSE_partial,
190
- args.shape_LSE_partial,
191
- args.stride_LSE_partial,
192
- args.ptr_O,
193
- args.stride_O,
194
- args.ptr_LSE,
195
- args.stride_LSE,
196
- cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
197
- args.cu_seqlens,
198
- args.seqused,
199
- args.num_splits_dynamic_ptr,
200
- args.semaphore_to_reset
201
- };
202
- }
203
-
204
- struct SchedulerArguments {
205
- int b;
206
- int seqlen_q;
207
- int total_q;
208
- int num_heads;
209
- int dv;
210
- int const* cu_seqlens_q;
211
- int const* seqused_q;
212
- };
213
-
214
- struct StaticTileScheduler {
215
- struct Params {};
216
- static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }
217
-
218
- SharedStorage& shared_storage;
219
- CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
220
-
221
- static dim3 get_grid_shape(SchedulerArguments const& args) {
222
- unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
223
- unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
224
- return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
225
- }
226
-
227
- CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
228
- int block_m = blockIdx.x;
229
- int block_k = blockIdx.y;
230
- int bidb = blockIdx.z;
231
- return {block_m, block_k, bidb};
232
- }
233
- };
234
-
235
- struct StaticVarlenTileScheduler {
236
- //
237
- // For varlen we have two Scheduling algos:
238
- // 1) STANDARD, same as StaticTileScheduler
239
- // 2) LINEARIZE_M_AND_BATCH, this flattens the tiled M dimension and
240
- // batch dimension into a linear tile index. The grid is then a
241
- // 2D grid of (tile_id, k_block). We then map the linear tile id
242
- // to (m_block, bidb) in the get_block_coord function. This mapping
243
- // is non-trivial since each batch element can have a different
244
- // number of m_blocks. This has overhead when computing the block
245
- // coordinates, but it is more efficient when prefills and decodes
246
- // are mixed since in that case the STANDARD scheduling algo will
247
- // have a lot of empty (no work) blocks in the grid.
248
- //
249
-
250
- enum SchedulingAlgo {
251
- STANDARD, // Same as StaticTileScheduler
252
- LINEARIZE_M_AND_BATCH, // Linearize the M and batch dimensions into a single tile index
253
- };
254
-
255
- struct Params {
256
- int b;
257
- int num_heads;
258
- int const* const cu_seqlens_q;
259
- int const* const seqused_q;
260
- SchedulingAlgo algo;
261
- };
262
-
263
- SharedStorage& shared_storage;
264
- CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
265
-
266
- static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
267
- // Choose the scheduling algorithm based on how dense the grid of tiles that
268
- // do actual work is. If the grid is more than 50% sparse, we linearize the M
269
- // and batch. If the grid is more than 50% dense, we use the standard scheduling
270
- // algorithm since it's more efficient at calculating the block coordinates.
271
- // NOTE: in the varlen case args.seqlen_q is the max seqlen_q across all batches;
272
- // we use a lower bound to estimate when the density is more than 50%.
273
- int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
274
- int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
275
- return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
276
- SchedulingAlgo::STANDARD :
277
- SchedulingAlgo::LINEARIZE_M_AND_BATCH;
278
- }
279
-
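The density heuristic above can be checked on the host. A minimal sketch (plain C++; the sequence lengths below are made up) of how a mixed decode-plus-prefill batch pushes the choice toward LINEARIZE_M_AND_BATCH:

```cpp
// Mirrors choose_scheduling_algo: compare a lower bound on non-empty tiles
// against the dense (b x ceil(max_seqlen_q / kBlockM)) grid size.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    constexpr int kBlockM = 8;
    std::vector<int> seqlens = {1, 1, 1, 4096};   // hypothetical: 3 decodes + 1 prefill
    int b = (int)seqlens.size();
    int total_q = std::accumulate(seqlens.begin(), seqlens.end(), 0);
    int max_seqlen_q = 4096;

    int lower_bound_non_empty = (total_q + kBlockM - 1) / kBlockM;   // 513
    int grid_size = b * ((max_seqlen_q + kBlockM - 1) / kBlockM);    // 2048
    bool standard = 2 * lower_bound_non_empty >= grid_size;
    std::printf("%d of %d tiles are certainly non-empty -> %s\n",
                lower_bound_non_empty, grid_size,
                standard ? "STANDARD" : "LINEARIZE_M_AND_BATCH");
}
```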
280
- static Params to_underlying_arguments(SchedulerArguments const& args) {
281
- return {
282
- args.b,
283
- args.num_heads,
284
- args.cu_seqlens_q,
285
- args.seqused_q,
286
- choose_scheduling_algo(args)
287
- };
288
- }
289
-
290
- static dim3 get_grid_shape(SchedulerArguments const& args) {
291
- unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
292
-
293
- switch (choose_scheduling_algo(args)) {
294
- case SchedulingAlgo::STANDARD: {
295
- unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
296
- unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
297
- return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
298
- }
299
- case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
300
- // rough worst case upper bound on the number of blocks required
301
- // (assuming each batch has an additional partial block)
302
- unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
303
- return {num_blocks_m, num_blocks_k, 1};
304
- }}
305
-
306
- // rough worst case upper bound on the number of blocks required
307
- // (assuming each batch has an additional partial block)
308
- unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
309
- return {num_blocks_m, num_blocks_k, 1};
310
- }
311
-
312
- CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
313
- int num_heads = params.num_heads;
314
- int curr_tile_id = blockIdx.x;
315
-
316
- // Scan through the batches to find the batch that contains the current
317
- // tile_id. Compute using only the first warp of the block.
318
- if (threadIdx.x < 32) {
319
- // We compute linearized tile index start and ends for each batch
320
- // in groups of 32 in parallel
321
- int group_start_bidb = -(cutlass::NumThreadsPerWarp);
322
- int group_end_bidb = 0;
323
- int group_end_tile_id = 0;
324
- int group_start_tile_id = 0;
325
- int group_total_num_tiles = 0;
326
-
327
- int local_num_m_blocks = 0;
328
- int local_num_m_blocks_cumulative = 0;
329
-
330
- do {
331
- group_start_bidb += cutlass::NumThreadsPerWarp;
332
- group_end_bidb += cutlass::NumThreadsPerWarp;
333
-
334
- auto get_num_m_blocks = [&](int bidb) {
335
- if (bidb >= params.b) return 0;
336
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
337
- return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
338
- };
339
-
340
- // Cumulative number of blocks for the next 31 batches
341
- local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
342
- local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
343
- // Total number of blocks for the next 32 batches
344
- group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);
345
-
346
- group_start_tile_id = group_end_tile_id;
347
- group_end_tile_id += group_total_num_tiles;
348
- } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);
349
-
350
- int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
351
- // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`
352
- // these values below are now common to all threads in the warp
353
- int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
354
- int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
355
- int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
356
- warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);
357
-
358
- int bidb = group_start_bidb + batch_idx_in_group;
359
- int block_m = curr_tile_id - batch_m_start_tile_id;
360
- // NOTE(lucas): not sure why this causes a block_k unused warning
361
- // just inlined `blockIdx.y` to suppress the warning
362
- // int block_k = blockIdx.y;
363
- // shared_storage.block_coord = {block_m, block_k, bidb};
364
- BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
365
- if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
366
- }
367
-
368
- __syncthreads();
369
- return shared_storage.block_coord;
370
- }
371
-
372
-
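For reference, a serial sketch (plain C++, with hypothetical cu_seqlens and num_heads values) of the mapping that get_block_coord_linearized_m_and_batch computes with a warp-wide prefix sum: walk the batches, accumulate each batch's m-block count, and locate the batch whose tile range contains the linear tile id.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    constexpr int kBlockM = 8;
    int num_heads = 2;
    std::vector<int> cu_seqlens = {0, 3, 3, 20};  // hypothetical seqlens: 3, 0, 17
    int b = (int)cu_seqlens.size() - 1;

    auto num_m_blocks = [&](int bidb) {
        int seqlen = cu_seqlens[bidb + 1] - cu_seqlens[bidb];
        return (seqlen * num_heads + kBlockM - 1) / kBlockM;
    };

    // The kernel does this scan 32 batches at a time with a warp prefix sum;
    // here it is a plain serial loop.
    auto tile_to_coord = [&](int tile_id) -> std::pair<int, int> {
        int start = 0;
        for (int bidb = 0; bidb < b; ++bidb) {
            int n = num_m_blocks(bidb);
            if (tile_id < start + n) { return {tile_id - start, bidb}; }
            start += n;
        }
        return {-1, -1};  // tile_id falls in the worst-case padding blocks (no work)
    };

    for (int t = 0; t < 6; ++t) {
        auto [m_block, bidb] = tile_to_coord(t);
        std::printf("tile %d -> m_block %d, bidb %d\n", t, m_block, bidb);
    }
}
```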
373
- CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
374
- int block_m = blockIdx.x;
375
- int block_k = blockIdx.y;
376
- int bidb = blockIdx.z;
377
- return {block_m, block_k, bidb};
378
- }
379
-
380
- CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
381
- switch (params.algo) {
382
- case SchedulingAlgo::STANDARD:
383
- return get_block_coord_standard(params);
384
- case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
385
- return get_block_coord_linearized_m_and_batch(params);
386
- }
387
- return {0, 0, 0}; // Should never reach here
388
- }
389
- };
390
-
391
- using TileScheduler = std::conditional_t<
392
- Varlen,
393
- StaticVarlenTileScheduler,
394
- StaticTileScheduler
395
- >;
396
-
397
- using SchedulerParams = typename TileScheduler::Params;
398
-
399
- struct Params {
400
- CollectiveParams params;
401
- SchedulerParams scheduler_params;
402
- };
403
-
404
- CUTLASS_DEVICE
405
- void
406
- operator()(Params const& kernel_params, char* smem_buf) {
407
- CollectiveParams const& params = kernel_params.params;
408
-
409
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
410
- TileScheduler tile_scheduler{shared_storage};
411
-
412
- Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
413
- Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
414
- Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
415
-
416
- int const thread_idx = threadIdx.x;
417
-
418
- BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);
419
-
420
- int const m_block = block_coord.block_m;
421
- int const k_block = block_coord.block_k;
422
- int const batch = block_coord.bidb;
423
-
424
- if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
425
- cutlass::arch::wait_on_dependent_grids();
426
- *params.semaphore_to_reset = 0;
427
- }
428
-
429
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
430
- int const offset = seqlen_info.offset;
431
- int const seqlen = seqlen_info.seqlen;
432
- int max_idx = seqlen * get<2>(params.shape_LSE_partial);
433
-
434
- bool block_coord_valid =
435
- block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
436
- block_coord.bidb < params.b;
437
- if (!block_coord_valid) { return; }
438
-
439
- int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
440
- if (num_splits <= 1) { return; }
441
-
442
- cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
443
-
444
- // Step 1: load LSE_partial from gmem -> smem
445
- Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
446
- select<1, 0, 2, 3>(params.shape_LSE_partial),
447
- select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
448
- Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
449
- GmemTiledCopyLSE gmem_tiled_copy_LSE;
450
- auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
451
- Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
452
-
453
- // Construct identity layout for sLSE
454
- Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
455
- // Repeat the partitioning with identity layouts
456
- Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
457
-
458
- cutlass::arch::wait_on_dependent_grids();
459
-
460
- #pragma unroll
461
- for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
462
- int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
463
- int idx = m_block * kBlockM + mi;
464
- if (idx < max_idx) {
465
- int m_idx, bidh;
466
- if constexpr (!Varlen) {
467
- bidh = params.seqlen_divmod.divmod(m_idx, idx);
468
- } else {
469
- bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
470
- }
471
- Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
472
- #pragma unroll
473
- for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
474
- int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
475
- // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
476
- if (si < num_splits) {
477
- cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
478
- } else {
479
- cute::fill(tLSEsLSE(_, s, m), -INFINITY);
480
- }
481
- }
482
- } else {
483
- // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
484
- // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
485
- }
486
- }
487
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
488
-
489
- // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
490
- // We want these async loads to be in flight as we compute the LSE.
491
- GmemTiledCopyAccum gmem_tiled_copy_O_partial;
492
- auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
493
- // Construct identity layout for gO
494
- Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
495
- // Repeat the partitioning with identity layouts
496
- Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
497
- Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
498
- params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
499
-
500
- // Precompute these values to avoid recomputing them in the loop
501
- Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
502
- Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
503
- Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
504
- #pragma unroll
505
- for (int m = 0; m < size<1>(tOcO); ++m) {
506
- int mi = get<0>(tOcO(_0{}, m, _0{}));
507
- int idx = m_block * kBlockM + mi;
508
- if constexpr (!Varlen) {
509
- tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
510
- } else {
511
- tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
512
- }
513
- tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
514
- if (idx >= max_idx) {
515
- tObidh[m] = -1;
516
- }
517
- }
518
-
519
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
520
- if constexpr (!(Is_even_K)) {
521
- #pragma unroll
522
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
523
- }
524
-
525
- Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
526
-
527
- auto load_O_partial = [&] (int split, int stage) {
528
- Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
529
- #pragma unroll
530
- for (int m = 0; m < size<1>(tOcO); ++m) {
531
- if (tObidh(m) >= 0) {
532
- Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
533
- Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
534
- #pragma unroll
535
- for (int k = 0; k < size<2>(tOcO); ++k) {
536
- int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
537
- if (Is_even_K || tOpO(k)) {
538
- cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
539
- }
540
- }
541
- }
542
- }
543
- };
544
-
545
- for (int s = 0; s < kStages - 1; ++s) {
546
- if (s < num_splits) { load_O_partial(s, s); }
547
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
548
- }
549
-
550
- // Step 3: load and transpose LSE_partial from smem -> rmem
551
- if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
552
- __syncthreads();
553
-
554
- S2RTiledCopyLSE s2r_tiled_copy_LSE;
555
- auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
556
- Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
557
- Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
558
- cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
559
-
560
- // Step 4: compute the final LSE along the split dimension
561
- Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
562
- Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
563
- // We compute the max valid split for each row to short-circuit the computation later
564
- Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
565
- static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
566
- #pragma unroll
567
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
568
- float lse_max = ts2rrLSE(_0{}, _0{}, m);
569
- #pragma unroll
570
- for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
571
- MaxOp<float> max_op;
572
- lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
573
- int max_valid_idx = -1;
574
- #pragma unroll
575
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
576
- if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
577
- }
578
- MaxOp<int> max_int_op;
579
- max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
580
- float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
581
- float lse_sum_cur = 0.f;
582
- #pragma unroll
583
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
584
- float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
585
- lse_sum_cur += scale;
586
- // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
587
- // ts2rsLSE(_0{}, m, s) = scale;
588
- ts2rrLSE(_0{}, s, m) = scale;
589
- }
590
- SumOp<float> sum_op;
591
- lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
592
- lse_sum(m) = logf(lse_sum_cur) + lse_max;
593
- float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
594
- #pragma unroll
595
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
596
- }
597
- // Store the scales exp(lse - lse_logsum) back to smem
598
- cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
599
-
600
- // Store max_valid_split to smem
601
- #pragma unroll
602
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
603
- if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem
604
- int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
605
- if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
606
- }
607
- }
608
-
609
- // Step 5: store final LSE back to gmem
610
- if (k_block == 0) {
611
- auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
612
- Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
613
- #pragma unroll
614
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
615
- if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem
616
- int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
617
- int idx = m_block * kBlockM + mi;
618
- if (idx < max_idx) {
619
- int m_idx, bidh;
620
- if constexpr (!Varlen) {
621
- bidh = params.seqlen_divmod.divmod(m_idx, idx);
622
- } else {
623
- bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
624
- }
625
- // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
626
- mLSE(m_idx, bidh) = lse_sum(m);
627
- }
628
- }
629
- }
630
- }
631
-
632
- // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
633
- __syncthreads();
634
- int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
635
- #pragma unroll
636
- for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
637
- Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
638
- Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
639
- Tensor tOrO = make_fragment_like<float>(tOrOpartial);
640
- clear(tOrO);
641
- int stage_load = kStages - 1, stage_compute = 0;
642
- #pragma unroll 4 // Already tuned for speed
643
- for (int s = 0; s <= thr_max_valid_split; ++s) {
644
- Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
645
- #pragma unroll
646
- for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
647
-
648
- if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
649
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
650
- stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
651
- if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
652
- // We don't need __syncthreads() because each thread is just reading its own data from smem
653
- cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
654
- tOsOpartial(_, _, _, stage_compute), tOrOpartial);
655
- stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
656
-
657
- #pragma unroll
658
- for (int m = 0; m < size<1>(tOrOpartial); ++m) {
659
- if (tObidh(m) >= 0 && scale(m) > 0.f) {
660
- #pragma unroll
661
- for (int k = 0; k < size<2>(tOrOpartial); ++k) {
662
- if (Is_even_K || tOpO(k)) {
663
- Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
664
- flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
665
- #pragma unroll
666
- for (int i = 0; i < size<0>(tOrOpartial); ++i) {
667
- tOrO(i, m, k) += scale(m) * rOpartial[i];
668
- }
669
- }
670
- }
671
- }
672
- }
673
- }
674
-
675
- // Step 7: Write the final O to gmem
676
- Tensor rO = make_tensor_like<Element>(tOrO);
677
- flash::convert_type_out(tOrO, rO);
678
- auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
679
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
680
- shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
681
- Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
682
- GmemTiledCopy gmem_tiled_copy_O;
683
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
684
-
685
- #pragma unroll
686
- for (int m = 0; m < size<1>(tOcO); ++m) {
687
- if (tObidh(m) >= 0) {
688
- #pragma unroll
689
- for (int k = 0; k < size<2>(tOcO); ++k) {
690
- int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
691
- if (Is_even_K || tOpO(k)) {
692
- cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
693
- }
694
- }
695
- }
696
- }
697
-
698
- }
699
-
700
- };
701
-
702
- } // namespace flash
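The arithmetic behind steps 4 and 6 of the kernel above is the standard split-softmax combine: the final LSE is a numerically stable log-sum-exp of the per-split LSEs, and the final output is a weighted sum of the per-split outputs with weights exp(lse_i - lse_final). A scalar sketch (plain C++, two splits, made-up values):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Per-split log-sum-exp values and per-split outputs (each split's output
    // is already normalized by that split's own softmax denominator).
    float lse[2] = {2.0f, 5.0f};
    float o_partial[2] = {0.3f, 0.8f};

    float lse_max = std::fmax(lse[0], lse[1]);
    float sum = std::exp(lse[0] - lse_max) + std::exp(lse[1] - lse_max);
    float lse_final = std::log(sum) + lse_max;

    // scale_i = exp(lse_i - lse_final) is split i's share of the global denominator.
    float o_final = 0.0f;
    for (int s = 0; s < 2; ++s) { o_final += std::exp(lse[s] - lse_final) * o_partial[s]; }

    std::printf("lse_final = %f, o_final = %f\n", lse_final, o_final);
}
```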
flash-attn/flash_fwd_combine_launch_template.h DELETED
@@ -1,88 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/cutlass.h"
10
- #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80
11
- #include "cutlass/device_kernel.h" // For device_kernel
12
- #include "cutlass/kernel_launch.h" // For kernel_launch
13
-
14
- #include "static_switch.h"
15
- #include "flash.h"
16
- #include "flash_fwd_combine_kernel.h"
17
-
18
- using namespace cute;
19
-
20
- template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
21
- void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
22
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
23
- using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
24
- using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
25
- IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
26
-
27
- typename CombineKernel::Arguments args {
28
- params.b,
29
- static_cast<ElementPartial const*>(params.oaccum_ptr),
30
- {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial
31
- {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial
32
- static_cast<float*>(params.softmax_lseaccum_ptr),
33
- {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial
34
- {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial
35
- static_cast<Element*>(params.o_ptr),
36
- {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
37
- static_cast<float*>(params.softmax_lse_ptr),
38
- {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
39
- params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
40
- };
41
-
42
- typename CombineKernel::SchedulerArguments scheduler_args {
43
- params.b, params.seqlen_q, params.total_q, params.h, params.dv,
44
- params.cu_seqlens_q, params.seqused_q
45
- };
46
-
47
- typename CombineKernel::Params kernel_params = {
48
- CombineKernel::to_underlying_arguments(args),
49
- CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
50
- };
51
-
52
- dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
53
- auto kernel = cutlass::device_kernel<CombineKernel>;
54
- int smem_size = CombineKernel::SharedStorageSize;
55
- if (smem_size >= 48 * 1024) {
56
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
57
- }
58
- // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
59
- cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
60
- CHECK_CUDA_KERNEL_LAUNCH();
61
- }
62
-
63
- template<typename T, typename Tpartial, int kBlockK>
64
- void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
65
- // We want kBlockM to be as small as possible to maximize parallelism.
66
- // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
67
- static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
68
- static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
69
- ARCH_SWITCH(params.arch, Arch, [&] {
70
- BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
71
- if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
72
- if (params.num_splits <= 16) {
73
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
74
- return;
75
- }
76
- }
77
- if (params.num_splits <= 32) {
78
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
79
- } else if (params.num_splits <= 64) {
80
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
81
- } else if (params.num_splits <= 128) {
82
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
83
- } else {
84
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
85
- }
86
- });
87
- });
88
- }
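The dispatch above boils down to two integer choices: kBlockM shrinks as the head-dimension tile kBlockK grows (keeping the per-CTA tile small enough for 256 threads, per the comment at the top of run_mha_fwd_combine_), and kLogMaxSplits is the smallest template bucket that covers params.num_splits. A host-side sketch with hypothetical inputs:

```cpp
#include <cstdio>

int main() {
    auto pick = [](int kBlockK, int num_splits) {
        int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
        int kLogMaxSplits;
        if (kBlockM >= 16 && num_splits <= 16) { kLogMaxSplits = 4; }
        else if (num_splits <= 32)             { kLogMaxSplits = 5; }
        else if (num_splits <= 64)             { kLogMaxSplits = 6; }
        else if (num_splits <= 128)            { kLogMaxSplits = 7; }
        else                                   { kLogMaxSplits = 8; }
        std::printf("kBlockK=%3d num_splits=%3d -> kBlockM=%2d, kMaxSplits=%d\n",
                    kBlockK, num_splits, kBlockM, 1 << kLogMaxSplits);
    };
    pick(128, 8);   // large head dim, few splits
    pick(64, 48);   // medium head dim, more splits
    pick(96, 200);  // head dim that is only a multiple of 32, many splits
}
```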
flash-attn/flash_fwd_kernel_sm80.h DELETED
@@ -1,215 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/kernel_hardware_info.h>
13
-
14
- #include "seqlen.h"
15
- #include "utils.h"
16
- #include "softmax.h"
17
-
18
- namespace flash {
19
-
20
- using namespace cute;
21
-
22
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
23
- class FlashAttnFwdSm80 {
24
-
25
- public:
26
-
27
- // Type Aliases
28
- using CollectiveMainloop = CollectiveMainloop_;
29
- using CollectiveEpilogue = CollectiveEpilogue_;
30
- static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
31
- static constexpr bool Is_local = CollectiveMainloop::Is_local;
32
- static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
33
- static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
34
- static constexpr bool Varlen = CollectiveMainloop::Varlen;
35
- static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
36
- static constexpr bool Split = CollectiveMainloop::Split;
37
- static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
38
- static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
39
- static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
40
- static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
41
- static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
42
- using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
43
-
44
- // Mainloop derived types
45
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
46
- using TiledMma = typename CollectiveMainloop::TiledMma;
47
- using ArchTag = typename CollectiveMainloop::ArchTag;
48
- using MainloopArguments = typename CollectiveMainloop::Arguments;
49
- using MainloopParams = typename CollectiveMainloop::Params;
50
-
51
- // Epilogue derived types
52
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
53
- using EpilogueParams = typename CollectiveEpilogue::Params;
54
-
55
- static_assert(ArchTag::kMinComputeCapability >= 80);
56
-
57
- using TileScheduler = TileScheduler_;
58
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
59
- using TileSchedulerParams = typename TileScheduler::Params;
60
-
61
- static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
62
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
63
- static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
64
-
65
- // Kernel level shared memory storage
66
- // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k, not with smem_q
67
- // or anything else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
68
- static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
69
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
70
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
71
- static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
72
- struct SharedStorage {
73
- struct TensorStorage : cute::aligned_struct<128> {
74
- union {
75
- struct {
76
- cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
77
- typename CollectiveMainloop::TensorStorage mainloop;
78
- };
79
- // We want smem_o to line up with the start of smem_v
80
- typename CollectiveEpilogue::TensorStorage epilogue;
81
- };
82
- } tensors;
83
-
84
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
-
86
- };
87
-
88
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
89
-
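The padding_ array above implements the overlap rule spelled out in the comment: smem_o may reuse the smem_v + smem_k region but must not clobber anything else, so extra space is reserved whenever smem_o is larger than those two buffers combined. A tiny sketch of the arithmetic (plain C++; the byte sizes are stand-ins, not the real TensorStorage sizes):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Hypothetical sizes in bytes.
    int epilogue_bytes = 64 * 128 * 2 * 2;   // stand-in for the epilogue's smem_o
    int smem_k_bytes = 128 * 128 * 2;
    int smem_v_bytes = 128 * 128 * 2;
    int padding = std::max(0, epilogue_bytes - (smem_v_bytes + smem_k_bytes));
    std::printf("padding before mainloop storage: %d bytes\n", padding);
    // padding > 0 only when smem_o exceeds smem_v + smem_k; with these stand-in
    // sizes the epilogue fits inside the region it is allowed to overlap.
}
```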
90
- // Device side arguments
91
- struct Arguments {
92
- MainloopArguments mainloop{};
93
- EpilogueArguments epilogue{};
94
- cutlass::KernelHardwareInfo hw_info{};
95
- TileSchedulerArguments scheduler{};
96
- };
97
-
98
- // Kernel entry point API
99
- struct Params {
100
- MainloopParams mainloop{};
101
- EpilogueParams epilogue{};
102
- cutlass::KernelHardwareInfo hw_info{};
103
- TileSchedulerParams scheduler{};
104
- };
105
-
106
- //
107
- // Methods
108
- //
109
-
110
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
111
- static
112
- Params
113
- to_underlying_arguments(Arguments const& args) {
114
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
115
-
116
- // Get SM count if needed, otherwise use user supplied SM count
117
- int sm_count = args.hw_info.sm_count;
118
- if (sm_count <= 0) {
119
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
120
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
121
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
122
- }
123
-
124
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
125
-
126
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
127
- return {
128
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
129
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
130
- hw_info,
131
- TileScheduler::to_underlying_arguments(args.scheduler)
132
- };
133
- }
134
-
135
- // Computes the kernel launch grid shape based on runtime parameters
136
- static dim3
137
- get_grid_shape(Params const& params) {
138
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
139
- }
140
-
141
- static dim3
142
- get_block_shape() {
143
- return dim3(MaxThreadsPerBlock, 1, 1);
144
- }
145
-
146
- CUTLASS_DEVICE
147
- void
148
- operator()(Params const& params, char* smem_buf) {
149
-
150
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
151
-
152
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
153
-
154
- CollectiveMainloop mainloop;
155
- CollectiveEpilogue epilogue;
156
-
157
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
158
- // Initialize matmul objects.
159
- TiledMma tiled_mma;
160
-
161
- scheduler.init_consumer();
162
-
163
- int warp_idx = cutlass::canonical_warp_idx_sync();
164
- CUTLASS_PRAGMA_NO_UNROLL
165
- for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
166
- work_tile_info.is_valid(params.scheduler);
167
- work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
168
- // Attention output (GEMM-II) accumulator.
169
- Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
170
- float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
171
- // If there's tanh softcap, the scaling will be done before tanh.
172
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
173
- int const bidb = get<2>(block_coord);
174
- if constexpr (Is_FP8 && !Has_softcap) {
175
- int const bidh = get<1>(block_coord);
176
- int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
177
- float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
178
- float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
179
- softmax_scale_log2 *= q_descale * k_descale;
180
- }
181
- flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
182
-
183
- SeqlenInfo_t seqlen_info{
184
- bidb,
185
- get<0>(params.mainloop.shape_Q),
186
- !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
187
- get<0>(params.mainloop.shape_K_new),
188
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
189
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
190
- params.mainloop.seqlens_rotary
191
- };
192
- if constexpr (AppendKV) {
193
- bool tile_new_valid = mainloop.store_kv_new(
194
- params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
195
- if (tile_new_valid) { __syncthreads(); }
196
- }
197
- bool tile_valid = mainloop.mma(
198
- params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
199
- shared_storage);
200
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
201
- if (tile_valid) {
202
- // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
203
- epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
204
- threadIdx.x, block_coord);
205
- } else {
206
- // Write 0 to gO and -inf to gLSE.
207
- epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
208
- }
209
- }
210
-
211
- }
212
-
213
- };
214
-
215
- } // namespace flash
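One detail worth calling out from the loop above: the kernel keeps the softmax scale in the log2 domain (so it can use exp2 instead of exp), and for FP8 without softcap it folds the per-tensor Q/K descale factors directly into that scale rather than descaling the GEMM-I output separately. A scalar sketch (plain C++; the descale values and the raw score are made up):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    int head_dim = 128;
    float softmax_scale = 1.0f / std::sqrt((float)head_dim);
    float softmax_scale_log2 = softmax_scale * 1.4426950408889634f;  // * log2(e)

    float q_descale = 0.5f, k_descale = 0.25f;   // hypothetical FP8 per-tensor descales
    softmax_scale_log2 *= q_descale * k_descale; // same folding as in the loop above

    float qk_raw = 12.0f;                        // a raw (still FP8-scaled) Q*K^T score
    float p = std::exp2(qk_raw * softmax_scale_log2);  // unnormalized probability
    std::printf("scaled log2-domain score = %f, exp2 = %f\n",
                qk_raw * softmax_scale_log2, p);
}
```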
flash-attn/flash_fwd_kernel_sm90.h DELETED
@@ -1,468 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/arch/reg_reconfig.h>
11
- #include <cutlass/array.h>
12
- #include <cutlass/numeric_types.h>
13
- #include <cutlass/numeric_conversion.h>
14
- #include <cutlass/kernel_hardware_info.h>
15
- #include "cutlass/pipeline/pipeline.hpp"
16
-
17
- #include "cutlass/arch/grid_dependency_control.h"
18
-
19
- #include "seqlen.h"
20
- #include "utils.h"
21
- #include "softmax.h"
22
-
23
- namespace flash {
24
-
25
- using namespace cute;
26
-
27
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
28
- class FlashAttnFwdSm90 {
29
-
30
- public:
31
-
32
- // Type Aliases
33
- using CollectiveMainloop = CollectiveMainloop_;
34
- using CollectiveEpilogue = CollectiveEpilogue_;
35
- static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
36
- static constexpr bool Is_local = CollectiveMainloop::Is_local;
37
- static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
38
- static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
39
- static constexpr bool Varlen = CollectiveMainloop::Varlen;
40
- static constexpr bool Split = CollectiveMainloop::Split;
41
- static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
42
- static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
43
- static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
44
- static constexpr bool HasQv = CollectiveMainloop::HasQv;
45
- static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
46
- static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
47
- static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
48
- static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
49
- static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
50
- static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
51
- static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
52
- static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
53
- using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
54
-
55
- using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux;
56
-
57
- // Mainloop derived types
58
- using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
59
- using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
60
- using ArchTag = typename CollectiveMainloop::ArchTag;
61
- using ClusterShape = typename CollectiveMainloop::ClusterShape;
62
- using MainloopArguments = typename CollectiveMainloop::Arguments;
63
- using MainloopParams = typename CollectiveMainloop::Params;
64
- using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
65
-
66
- // Epilogue derived types
67
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
68
- using EpilogueParams = typename CollectiveEpilogue::Params;
69
-
70
- static_assert(ArchTag::kMinComputeCapability >= 90);
71
-
72
- using TileScheduler = TileScheduler_;
73
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
74
- using TileSchedulerParams = typename TileScheduler::Params;
75
-
76
- static constexpr uint32_t NumLoadWarpGroups = 1;
77
- static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
78
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
79
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
80
- static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
81
-
82
- /// Register requirement for Load and Math WGs
83
- // If we use cp.async to load K and V, we need more registers for the producer WG.
84
- static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
85
- static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
86
- // If you want to print from the producer warp, you'd need to increase the number of registers.
87
- // Otherwise you'll get a CUDA error.
88
- // static constexpr uint32_t LoadRegisterRequirement = 40;
89
- // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
90
-
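The LoadRegisterRequirement / MmaRegisterRequirement pairs above line up with the SM's register file after warp-group register reconfiguration: one producer warp group keeps a small budget while the NumMmaWarpGroups consumer warp groups take the rest. A quick check (plain C++; assumes a 64K-register SM90 register file and 128 threads per warp group, and uses the Use_TMA_KV variants of the constants):

```cpp
#include <cstdio>

int main() {
    constexpr int kRegsPerSM = 64 * 1024;   // assumed SM90 register file size
    constexpr int kThreadsPerWG = 128;
    struct Cfg { int mma_wgs, load_regs, mma_regs; };
    Cfg cfgs[] = {{1, 56, 256}, {2, 24, 240}, {3, 32, 160}};
    for (auto c : cfgs) {
        int used = (c.load_regs + c.mma_wgs * c.mma_regs) * kThreadsPerWG;
        std::printf("%d MMA WG(s): (%d + %d*%d) regs/thread -> %6d of %d registers\n",
                    c.mma_wgs, c.load_regs, c.mma_wgs, c.mma_regs, used, kRegsPerSM);
    }
}
```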
91
- // Kernel level shared memory storage
92
- // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
93
- // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
94
- static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
95
- static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
96
- struct SharedStorage {
97
- struct TensorStorage : cute::aligned_struct<128, _1> {
98
- union {
99
- struct {
100
- cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
101
- typename CollectiveMainloop::TensorStorage mainloop;
102
- };
103
- // We want smem_o to line up with the start of smem_v
104
- typename CollectiveEpilogue::TensorStorage epilogue;
105
- };
106
- } tensors;
107
- struct PipelineStorage : cute::aligned_struct<16, _1> {
108
- alignas(16) BarrierQ barrier_Q;
109
- alignas(16) BarrierQ barrier_Qv;
110
- alignas(16) cutlass::arch::ClusterBarrier barrier_O;
111
- alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
112
- alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
113
- alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
114
- alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
115
- alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
116
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
117
- } pipelines;
118
-
119
- };
120
-
121
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
122
-
123
- // Device side arguments
124
- struct Arguments {
125
- MainloopArguments mainloop{};
126
- EpilogueArguments epilogue{};
127
- cutlass::KernelHardwareInfo hw_info{};
128
- TileSchedulerArguments scheduler{};
129
- };
130
-
131
- // Kernel entry point API
132
- struct Params {
133
- MainloopParams mainloop{};
134
- EpilogueParams epilogue{};
135
- cutlass::KernelHardwareInfo hw_info{};
136
- TileSchedulerParams scheduler{};
137
- };
138
-
139
- //
140
- // Methods
141
- //
142
-
143
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
144
- static
145
- Params
146
- to_underlying_arguments(Arguments const& args) {
147
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
148
-
149
- // Get SM count if needed, otherwise use user supplied SM count
150
- int sm_count = args.hw_info.sm_count;
151
- if (sm_count <= 0) {
152
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
153
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
154
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
155
- }
156
-
157
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
158
-
159
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
160
- return {
161
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
162
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
163
- hw_info,
164
- TileScheduler::to_underlying_arguments(args.scheduler)
165
- };
166
- }
167
-
168
- // Computes the kernel launch grid shape based on runtime parameters
169
- static dim3
170
- get_grid_shape(Params const& params) {
171
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
172
- }
173
-
174
- static dim3
175
- get_block_shape() {
176
- return dim3(MaxThreadsPerBlock, 1, 1);
177
- }
178
-
179
- CUTLASS_DEVICE
180
- void
181
- operator()(Params const& params, char* smem_buf) {
182
-
183
- static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
184
- static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
185
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
186
-
187
- using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
188
- using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
189
- using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
190
- using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
191
- using PipelineState = typename CollectiveMainloop::PipelineState;
192
- using PipelineParamsK = typename MainloopPipelineK::Params;
193
- using PipelineParamsV = typename MainloopPipelineV::Params;
194
- using PipelineParamsVt = typename MainloopPipelineVt::Params;
195
- using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
196
-
197
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
198
-
199
- int const lane_predicate = cute::elect_one_sync();
200
- int const warp_idx = cutlass::canonical_warp_idx_sync();
201
-
202
- // Issue Tma Descriptor Prefetch from a single thread
203
- if (warp_idx == 0 && lane_predicate) {
204
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
205
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
206
- }
207
-
208
- // Obtain warp index
209
- int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
210
- int warp_group_idx = cutlass::canonical_warp_group_idx();
211
-
212
- if (warp_idx == 0 && lane_predicate) {
213
- shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
214
- if constexpr (HasQv) {
215
- shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
216
- }
217
- shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
218
- }
219
-
220
- // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
221
- PipelineParamsK pipeline_params_k;
222
- pipeline_params_k.role = warp_group_idx == 0
223
- ? MainloopPipelineK::ThreadCategory::Producer
224
- : MainloopPipelineK::ThreadCategory::Consumer;
225
- if constexpr (Use_TMA_KV) {
226
- pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
227
- pipeline_params_k.is_leader = warp_group_thread_idx == 0;
228
- pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
229
- } else {
230
- pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
231
- pipeline_params_k.producer_arv_count = NumProducerThreads;
232
- }
233
-
234
- static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
235
- PipelineParamsVt pipeline_params_vt = pipeline_params_k;
236
- if constexpr (Use_TMA_KV && !SameHeadDim) {
237
- pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
238
- if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
239
- } else {
240
- if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
241
- }
242
-
243
- MainloopPipelineK pipeline_k = [&] {
244
- if constexpr (Use_TMA_KV) {
245
- return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
246
- } else {
247
- return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
248
- }
249
- }();
250
- // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
251
- MainloopPipelineV pipeline_v = [&] {
252
- if constexpr (!Transpose_V) {
253
- static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
254
- if constexpr (Use_TMA_KV) {
255
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
256
- } else {
257
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
258
- }
259
- } else {
260
- PipelineParamsV pipeline_params_v;
261
- pipeline_params_v.role = warp_group_idx == 0
262
- ? MainloopPipelineV::ThreadCategory::Producer
263
- : MainloopPipelineV::ThreadCategory::Consumer;
264
- pipeline_params_v.producer_arv_count = NumProducerThreads;
265
- pipeline_params_v.consumer_arv_count = NumMmaThreads;
266
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
267
- }
268
- }();
269
- // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
270
- // the producer WG will read from pipeline_vt and write to pipeline_v.
271
- // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
272
- // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
273
- // However, the thread role isn't used in the pipeline implementation.
274
- MainloopPipelineVt pipeline_vt = [&] {
275
- if constexpr (Use_TMA_KV) {
276
- pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
277
- return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
278
- } else {
279
- pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
280
- return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
281
- }
282
- }();
283
-
284
- PipelineParamsKVNew pipeline_params_kv_new;
285
- pipeline_params_kv_new.role = warp_group_idx == 0
286
- ? MainloopPipelineKVNew::ThreadCategory::Producer
287
- : MainloopPipelineKVNew::ThreadCategory::Consumer;
288
- pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
289
- pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
290
- pipeline_params_kv_new.num_consumers = NumMmaThreads;
291
- auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
292
- if constexpr (!SameHeadDim) {
293
- pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
294
- }
295
- auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
296
-
297
- CollectiveMainloop mainloop;
298
- CollectiveEpilogue epilogue;
299
-
300
- const int num_heads = get<2>(params.mainloop.shape_Q);
301
- Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads));
302
- Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});
303
-
304
- if(params.mainloop.ptr_S_aux && threadIdx.x < num_heads) {
305
- sS_aux(threadIdx.x) = gS_aux(threadIdx.x);
306
- }
307
-
308
- // We need this to guarantee that the Pipeline init is visible to all producer and consumer blocks in the Cluster
309
- if constexpr (size(ClusterShape{}) > 1) {
310
- cute::cluster_arrive_relaxed();
311
- cute::cluster_wait();
312
- } else {
313
- __syncthreads();
314
- }
315
-
316
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
317
-
318
- if (warp_group_idx == 0) { // Producer
319
- cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
320
-
321
- // The pipelines for AppendKV and main attention are different, since e.g. main attention
322
- // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
323
- // KV_new. Since the pipeline states are different, we have to manually sync to make
324
- // sure the two pipelines don't race when accessing smem_k and smem_v.
325
- PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
326
- PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
327
- int work_idx = 0;
328
- int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
329
- static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
330
- if constexpr (SingleProducerWarp) {
331
- if (warp_idx_in_warpgroup != 0) { return; }
332
- }
333
- if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
334
-
335
- cutlass::arch::wait_on_dependent_grids();
336
-
337
- // Load Q, K, V
338
- for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
339
- work_tile_info.is_valid(params.scheduler);
340
- work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
341
-
342
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
343
- SeqlenInfo_t seqlen_info{
344
- get<2>(block_coord) /*bidb*/,
345
- get<0>(params.mainloop.shape_Q),
346
- !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
347
- get<0>(params.mainloop.shape_K_new),
348
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
349
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
350
- params.mainloop.seqlens_rotary
351
- };
352
- if constexpr (AppendKV) {
353
- bool tile_new_valid = mainloop.load_kv_new(
354
- params.mainloop, pipeline_k_new, pipeline_v_new,
355
- smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
356
- if (tile_new_valid) {
357
- // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
358
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
359
- // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
360
- }
361
- }
362
- auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
363
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
364
- };
365
- // pipeline_vt won't be used if we don't need to transpose V.
366
- mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
367
- shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
368
- }
369
- mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
370
- } else { // Consumer
371
- cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
372
-
373
- // Initialize matmul objects.
374
- TiledMmaPV tiled_mma_pv;
375
-
376
- PipelineState smem_pipe_read;
377
- PipelineState smem_pipe_read_new;
378
- // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
379
- // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
380
-
381
- scheduler.init_consumer();
382
- mainloop.mma_init();
383
-
384
- int work_idx = 0;
385
- CUTLASS_PRAGMA_NO_UNROLL
386
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
387
- work_tile_info.is_valid(params.scheduler);
388
- // get_next_work will be called before the epilogue
389
- ) {
390
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
391
- int const bidb = get<2>(block_coord);
392
- SeqlenInfo_t seqlen_info{
393
- bidb,
394
- get<0>(params.mainloop.shape_Q),
395
- !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
396
- get<0>(params.mainloop.shape_K_new),
397
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
398
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
399
- params.mainloop.seqlens_rotary
400
- };
401
- if constexpr (AppendKV) {
402
- bool tile_new_valid = mainloop.store_kv_new(
403
- params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
404
- threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
405
- if (tile_new_valid) {
406
- // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
407
- // We need this sync so that the gmem write from the consumers is visible to the producer
408
- // that might do TMA read after that.
409
- asm volatile ("fence.proxy.async.global;");
410
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
411
- // arrive is enough, we don't need sync. The producer will sync, which means
412
- // after that sync we're guaranteed that the AppendKV pipeline have finished
413
- // loading and consumer smem_k and smem_v.
414
- // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
415
- }
416
- }
417
- // If there's tanh softcap, the scaling will be done before tanh.
418
- float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
419
- if constexpr (Is_FP8 && !Has_softcap) {
420
- int const bidh = get<1>(block_coord);
421
- int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
422
- float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
423
- float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
424
- softmax_scale_log2 *= q_descale * k_descale;
425
- }
426
- flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
427
- // Attention output (GEMM-II) accumulator.
428
- Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
429
- bool tile_valid;
430
- if constexpr (!LargeHeadDimV) {
431
- tile_valid = mainloop.mma(
432
- params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
433
- tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
434
- } else { // mma_pv might not compile if !LargeHeadDimV
435
- if (warp_group_idx == 1) {
436
- tile_valid = mainloop.mma(
437
- params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
438
- tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
439
- } else {
440
- tile_valid = mainloop.mma_pv(
441
- params.mainloop, pipeline_v, smem_pipe_read,
442
- tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
443
- }
444
- }
445
- // Do this here before the epilogue so that the next tile is ready to go.
446
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
447
- if constexpr (Split && Varlen) {
448
- if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
449
- cutlass::arch::launch_dependent_grids();
450
- }
451
- }
452
- if (tile_valid) {
453
- // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
454
- epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
455
- threadIdx.x - MmaThreadOffset, block_coord);
456
- } else {
457
- // Write 0 to gO and -inf to gLSE.
458
- epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
459
- }
460
- }
461
- epilogue.store_tail();
462
- }
463
-
464
- }
465
-
466
- };
467
-
468
- } // namespace flash
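The deleted kernel above is organized around CUTLASS warp-group pipelines: a producer warp group fills K/V stages in shared memory while the consumer (MMA) warp groups wait on, use, and release those stages. The sketch below is a plain C++ analogue of that acquire/commit/wait/release handshake over a fixed-depth ring buffer. It is an illustration of the pattern only, not the CUTLASS pipeline classes the kernel actually uses, and every name in it is invented for the example.

```cpp
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Toy stand-in for a K-stage mainloop pipeline: the producer may run at most
// kStages tiles ahead of the consumer, mirroring the shared-memory ring buffer
// the kernel cycles through with its pipeline state.
template <int kStages>
class ToyPipeline {
public:
    void producer_acquire() {                    // wait for a free stage
        std::unique_lock<std::mutex> lk(m_);
        cv_.wait(lk, [&] { return filled_ < kStages; });
    }
    void producer_commit() {                     // publish one filled stage
        { std::lock_guard<std::mutex> lk(m_); ++filled_; }
        cv_.notify_all();
    }
    void consumer_wait() {                       // wait for a filled stage
        std::unique_lock<std::mutex> lk(m_);
        cv_.wait(lk, [&] { return filled_ > 0; });
    }
    void consumer_release() {                    // free the stage for reuse
        { std::lock_guard<std::mutex> lk(m_); --filled_; }
        cv_.notify_all();
    }
private:
    std::mutex m_;
    std::condition_variable cv_;
    int filled_ = 0;
};

int main() {
    constexpr int kStages = 2;      // like a 2-stage smem pipeline
    constexpr int kNumTiles = 8;    // K/V tiles for one work tile
    ToyPipeline<kStages> pipe;
    std::vector<int> smem(kStages); // stand-in for the shared-memory stages

    std::thread producer([&] {      // "load" warp group
        for (int tile = 0; tile < kNumTiles; ++tile) {
            pipe.producer_acquire();
            smem[tile % kStages] = tile;   // pretend TMA wrote the tile
            pipe.producer_commit();
        }
    });
    std::thread consumer([&] {      // "MMA" warp groups
        for (int tile = 0; tile < kNumTiles; ++tile) {
            pipe.consumer_wait();
            std::printf("consumed tile %d\n", smem[tile % kStages]);
            pipe.consumer_release();
        }
    });
    producer.join();
    consumer.join();
}
```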
flash-attn/flash_fwd_launch_template.h DELETED
@@ -1,231 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/cutlass.h"
10
- #include "cutlass/device_kernel.h" // For device_kernel
11
- #include <cutlass/kernel_hardware_info.h>
12
- #include "cutlass/cluster_launch.hpp"
13
- #include "cutlass/kernel_launch.h"
14
-
15
- #include "static_switch.h"
16
- #include "flash.h"
17
- #include "tile_size.h"
18
- #include "tile_scheduler.hpp"
19
- #include "flash_fwd_kernel_sm90.h"
20
- #include "flash_fwd_kernel_sm80.h"
21
- #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
22
- #include "mainloop_fwd_sm80.hpp"
23
- #include "epilogue_fwd.hpp"
24
- #include "heuristics.h"
25
-
26
- using namespace cute;
27
-
28
- template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
29
- bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
30
- bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg>
31
- void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
32
- static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
33
- static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
34
- static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
35
- static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
36
- static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
37
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
38
- using ElementS = cutlass::bfloat16_t;
39
-
40
- // Can't use structured binding since it's not compatible with constexpr
41
- static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
42
- static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
43
- static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
44
- static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
45
- static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
46
- static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
47
- static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
48
- static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
49
- static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
50
-
51
- using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
52
- using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
53
- using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
54
- using CollectiveMainloop = std::conditional_t<
55
- Arch >= 90,
56
- flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor, ElementS>,
57
- flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split, ElementS>
58
- >;
59
- using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;
60
-
61
- static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
62
- using SchedulerPersistent = std::conditional_t<Varlen,
63
- flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
64
- std::conditional_t<!Is_causal && !Is_local,
65
- flash::StaticPersistentTileScheduler<Split>,
66
- flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
67
- >
68
- >;
69
- using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
70
- // If Split then we probably don't have enough work for PersistentScheduler to be useful.
71
- // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
72
- // since we'll avoid launching a bunch of thread blocks that immediately exit.
73
- // On Sm80, noncausal persistent seems a bit slower.
74
- static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
75
- using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
76
- using AttnKernel = std::conditional_t<
77
- Arch >= 90,
78
- flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
79
- flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
80
- >;
81
-
82
- bool const is_varlen_q = params.cu_seqlens_q;
83
- bool const is_varlen_k = params.cu_seqlens_k;
84
- bool const is_varlen_k_new = params.cu_seqlens_knew;
85
- int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
86
- int batch_q = !is_varlen_q ? params.b : 1;
87
- int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
88
- typename CollectiveMainloop::StrideV v_strides =
89
- cute::conditional_return<!V_colmajor>(
90
- make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
91
- make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
92
- typename CollectiveMainloop::Arguments mainloop_args {
93
- static_cast<Element const*>(params.q_ptr),
94
- {seqlen_q, params.d, params.h, batch_q}, // shape_Q
95
- {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
96
- static_cast<Element*>(params.k_ptr),
97
- {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
98
- params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K
99
- {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
100
- static_cast<Element*>(params.v_ptr),
101
- params.dv, // headdim_v
102
- v_strides, // stride_V
103
- static_cast<Element const*>(params.knew_ptr),
104
- {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new
105
- {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new
106
- static_cast<Element const*>(params.vnew_ptr),
107
- {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
108
- static_cast<Element const*>(params.qv_ptr),
109
- {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv
110
- static_cast<Element const*>(params.rotary_cos_ptr),
111
- {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter
112
- {params.rotary_dim / 2, _1{}}, // stride_rotary_cos
113
- static_cast<Element const*>(params.rotary_sin_ptr),
114
- {params.rotary_dim / 2, _1{}}, // stride_rotary_sin
115
- params.is_rotary_interleaved,
116
- params.page_table,
117
- // if page_size is not set, avoid dividing by zero
118
- {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
119
- {params.page_table_batch_stride, _1{}}, // stride_page_table
120
- params.scale_softmax,
121
- params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
122
- {params.q_descale_batch_stride, params.q_descale_head_stride},
123
- {params.k_descale_batch_stride, params.k_descale_head_stride},
124
- {params.v_descale_batch_stride, params.v_descale_head_stride},
125
- params.window_size_left, params.window_size_right,
126
- params.softcap,
127
- params.num_splits,
128
- params.kv_batch_idx,
129
- params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
130
- params.seqused_q, params.seqused_k,
131
- params.leftpad_k, params.seqlens_rotary,
132
- static_cast<ElementS const*>(params.s_aux_ptr)
133
- };
134
- typename CollectiveEpilogue::Arguments epilogue_args {
135
- static_cast<ElementOut*>(params.o_ptr),
136
- {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O
137
- {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
138
- static_cast<float*>(params.oaccum_ptr),
139
- {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
140
- static_cast<float*>(params.softmax_lse_ptr),
141
- {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE
142
- static_cast<float*>(params.softmax_lseaccum_ptr),
143
- {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial
144
- params.h_k,
145
- params.cu_seqlens_q, params.seqused_q
146
- };
147
-
148
- int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
149
- int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
150
- num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
151
- typename flash::TileSchedulerArguments scheduler_args {
152
- num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
153
- params.h / params.h_k,
154
- params.seqlen_q,
155
- params.seqlen_k, params.d, params.dv, sizeof(Element),
156
- params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
157
- // params.num_m_blocks_ptr,
158
- params.num_splits_dynamic_ptr,
159
- };
160
-
161
- if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
162
- prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
163
- CHECK_CUDA_KERNEL_LAUNCH();
164
- }
165
-
166
- int device;
167
- CHECK_CUDA(cudaGetDevice(&device));
168
- typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
169
- mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
170
- });
171
-
172
- dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
173
- dim3 block_dims = AttnKernel::get_block_shape();
174
- int smem_size = AttnKernel::SharedStorageSize;
175
- // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
176
- // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
177
- // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
178
- // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
179
- // Get the ptr to kernel function.
180
- if constexpr (size(ClusterShape{}) > 1) {
181
- void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
182
- if (smem_size >= 48 * 1024) {
183
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
184
- }
185
- dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
186
- cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
187
- cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
188
- } else {
189
- auto kernel = cutlass::device_kernel<AttnKernel>;
190
- if (smem_size >= 48 * 1024) {
191
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
192
- }
193
- // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
194
- cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
195
- Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
196
- }
197
- CHECK_CUDA_KERNEL_LAUNCH();
198
- }
199
-
200
- template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
201
- void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
202
- static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
203
- static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
204
- using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
205
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
206
- VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
207
- static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
208
- VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
209
- BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] {
210
- // Avoid over-compilation by making sure this only gets set if it is actually used, i.e. we currently only support one MMA warp group for 128 head dim on Hopper
211
- static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;
212
-
213
- // Only needed here to decide if we should use cluster
214
- static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128;
215
-
216
- static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
217
- BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
218
- static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
219
- APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
220
- // Only use Cluster if number of tiles along seqlen_q is even and not varlen
221
- CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
222
- static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
223
- run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor, Use_one_mma_wg>(params, stream);
224
- });
225
- });
226
- });
227
- });
228
- });
229
- });
230
- });
231
- }
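run_mha_fwd_ above turns runtime flags (causal, varlen, softcap, and so on) into template parameters through the *_SWITCH macros from static_switch.h, which is not part of this diff. The snippet below is a minimal, stand-alone illustration of that runtime-to-compile-time dispatch pattern; the macro and function names are invented for the example and are not the actual flash-attention macros.

```cpp
#include <cstdio>

// Minimal stand-in for the *_SWITCH macros: turn a runtime bool into a
// compile-time constant so the lambda body can use it as a template parameter.
// Illustrative only; not the real static_switch.h.
#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)        \
    [&] {                                                 \
        if (COND) {                                       \
            constexpr static bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            constexpr static bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()

template <bool Is_causal>
void run_kernel_variant() {
    // In the real launch template this would pick a different AttnKernel type.
    std::printf("compiled variant: Is_causal = %s\n", Is_causal ? "true" : "false");
}

int main() {
    bool is_causal_runtime = true;  // e.g. params.is_causal
    EXAMPLE_BOOL_SWITCH(is_causal_runtime, Is_causal, [&] {
        run_kernel_variant<Is_causal>();
    });
}
```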
flash-attn/flash_prepare_scheduler.cu DELETED
@@ -1,124 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include "cutlass/fast_math.h"
6
- #include "cutlass/barrier.h"
7
- #include "cutlass/arch/barrier.h"
8
-
9
- #include "cutlass/arch/grid_dependency_control.h"
10
-
11
- #include "flash.h"
12
-
13
- namespace flash {
14
-
15
- __global__ void prepare_varlen_num_blocks_kernel(
16
- int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
17
- int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
18
- int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
19
- int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
20
- cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
21
- int* const tile_count_semaphore,
22
- // int* const num_m_blocks_ptr,
23
- int* const num_splits_dynamic_ptr,
24
- bool enable_pdl) {
25
-
26
- static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
27
- static constexpr int kSmemSize = 1;
28
- // Assume that there's only one block in the grid
29
- __shared__ int total_blocks_smem[kSmemSize];
30
-
31
- // There's only 1 block in the grid, so might as well start launching the main attn kernel
32
- if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
33
-
34
- if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
35
- __syncthreads();
36
-
37
- if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
38
-
39
- int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
40
-
41
- auto get_num_m_blocks = [&](int bidb_start) {
42
- int batch_idx = lane + bidb_start;
43
- int seqlen;
44
- if (seqused_q) {
45
- seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
46
- } else if (cu_seqlens_q) {
47
- int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
48
- int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
49
- seqlen = next_cu_seqlen - cur_cu_seqlen;
50
- } else {
51
- seqlen = seqlen_q_static;
52
- }
53
- seqlen *= qhead_per_khead;
54
- return batch_idx < num_batch && lane < kNumBatchPerWarp
55
- ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
56
- };
57
-
58
- auto get_num_n_blocks = [&](int bidb_start) {
59
- int batch_idx = lane + bidb_start;
60
- int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
61
- int seqlen;
62
- if (seqused_k) {
63
- seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
64
- } else if (cu_seqlens_k) {
65
- int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
66
- int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
67
- seqlen = next_cu_seqlen - cur_cu_seqlen;
68
- } else {
69
- seqlen = seqlen_k_static;
70
- }
71
- int seqlen_new;
72
- if (cu_seqlens_k_new) {
73
- int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
74
- int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
75
- seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
76
- } else {
77
- seqlen_new = seqlen_k_new_static;
78
- }
79
- // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
80
- seqlen = seqlen - leftpad_k + seqlen_new;
81
- return batch_idx < num_batch && lane < kNumBatchPerWarp
82
- ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
83
- };
84
-
85
- int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
86
- int bidb_start = kNumBatchPerWarp * warp_idx;
87
- int num_m_blocks = get_num_m_blocks(bidb_start);
88
- int num_n_blocks = get_num_n_blocks(bidb_start);
89
-
90
- int total_blocks = num_m_blocks * num_n_blocks;
91
- // Warp sum
92
- #pragma unroll
93
- for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
94
- total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
95
- }
96
- if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
97
- __syncthreads();
98
- total_blocks = total_blocks_smem[0];
99
- // 10% margin
100
- int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
101
- // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM
102
- int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
103
- if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
104
- num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
105
- // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
106
- }
107
- }
108
-
109
- } // flash
110
-
111
- void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
112
- int blockM, int blockN, bool enable_pdl) {
113
- // Only support batch <= 992 (32 warps, each with 31 batches)
114
- int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
115
- flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
116
- params.seqlen_q, params.seqlen_k, params.seqlen_knew,
117
- params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
118
- params.seqused_q, params.seqused_k, params.leftpad_k,
119
- params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
120
- cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
121
- params.tile_count_semaphore,
122
- // params.num_m_blocks_ptr,
123
- params.num_splits_dynamic_ptr, enable_pdl);
124
- }
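prepare_varlen_num_blocks_kernel above amounts to a per-batch occupancy calculation: count m-blocks and n-blocks per sequence, sum the work, and cap the number of splits so that roughly one wave (plus a 10% margin) covers the SMs. The host-side sketch below restates only that arithmetic, with no CUDA, warp reduction, or PDL machinery; the tile sizes, SM count, and sequence lengths are made-up example numbers.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Host-side restatement of the per-batch split computation. Example values only.
int main() {
    const int kBlockM = 128, kBlockN = 176;          // illustrative tile sizes
    const int num_head = 8, num_sm = 132;            // e.g. one H100 SXM
    const int num_splits_static = 8;                 // upper bound from the host
    const int qhead_per_khead = 4;                   // PackGQA factor
    // Per-batch (seqlen_q, seqlen_k) pairs, e.g. a varlen decode batch.
    const std::vector<std::pair<int, int>> seqlens = {
        {1, 4096}, {1, 8192}, {1, 1024}, {1, 16384}};

    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };

    // Total number of (m-block, n-block) tiles across the whole batch.
    long long total_blocks = 0;
    std::vector<int> num_n_blocks(seqlens.size());
    for (size_t b = 0; b < seqlens.size(); ++b) {
        int m_blocks = ceil_div(seqlens[b].first * qhead_per_khead, kBlockM);
        num_n_blocks[b] = ceil_div(seqlens[b].second, kBlockN);
        total_blocks += static_cast<long long>(m_blocks) * num_n_blocks[b];
    }

    // Same 10% margin as the kernel: n-blocks one SM should take per split.
    int blocks_per_sm = static_cast<int>(
        std::ceil(static_cast<float>(total_blocks) * 1.1f * num_head / num_sm));
    blocks_per_sm = std::max(blocks_per_sm, 1);  // guard added for the sketch

    for (size_t b = 0; b < seqlens.size(); ++b) {
        int num_splits_dynamic = std::max(
            std::min(ceil_div(num_n_blocks[b], blocks_per_sm), num_splits_static), 1);
        std::printf("batch %zu: n_blocks = %d, splits = %d\n",
                    b, num_n_blocks[b], num_splits_dynamic);
    }
}
```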
flash-attn/heuristics.h DELETED
@@ -1,65 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <vector>
8
- #include "flash.h"
9
-
10
- inline bool use_one_mma_wg(Flash_fwd_params const& params) {
11
- return params.arch >= 90 && params.d == 128 &&
12
- params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64;
13
- };
14
-
15
- inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
16
- // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
17
- if (varlen_q) return true;
18
- // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
19
- auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
20
- float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
21
- float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
22
- return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
23
- };
24
-
25
- // Find the number of splits that maximizes the occupancy. For example, if we have
26
- // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
27
- // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
28
- // splits as that would incur more HBM reads/writes.
29
- // So we find the best efficiency, then find the smallest number of splits that gets 85%
30
- // of the best efficiency.
31
- inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
32
- // If we have enough to almost fill the SMs, then just use 1 split
33
- // However, in the case of super long seqlen where each head of KV doesn't even fit into
34
- // L2 (we assume that L2 size is 50MB), we want to split.
35
- if (total_mblocks >= 0.8f * num_SMs) {
36
- int const size_l2 = 50 * 1024 * 1024;
37
- // Only split if there are enough queries to go over the KV at least twice
38
- // Don't split if causal
39
- if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
40
- return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
41
- } else {
42
- return 1;
43
- }
44
- }
45
- // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
46
- if (num_n_blocks <= 4) { return 1; }
47
- max_splits = std::min({max_splits, num_SMs, num_n_blocks});
48
- float max_efficiency = 0.f;
49
- std::vector<float> efficiency;
50
- efficiency.reserve(max_splits);
51
- for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
52
- float n_waves = float(total_mblocks * num_splits) / num_SMs;
53
- float eff = n_waves / ceil(n_waves);
54
- // printf("num_splits = %d, eff = %f\n", num_splits, eff);
55
- if (eff > max_efficiency) { max_efficiency = eff; }
56
- efficiency.push_back(eff);
57
- }
58
- for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
59
- if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
60
- // printf("num_splits chosen = %d\n", num_splits);
61
- return num_splits;
62
- }
63
- }
64
- return 1;
65
- }
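The 48-m-block / 108-SM example in the comment above is easy to verify by hand. The short program below re-implements just the wave-efficiency loop of num_splits_heuristic as a stand-alone restatement (it is not the deleted header itself) and prints the efficiency per split count; it selects 2 splits (efficiency 0.89) rather than 3 (0.67), matching the comment.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Stand-alone restatement of the wave-efficiency loop, for the
// 48-m-block / 108-SM example from the header's comment.
int main() {
    const int total_mblocks = 48, num_SMs = 108, max_splits = 8;

    float max_efficiency = 0.f;
    float efficiency[max_splits];
    for (int num_splits = 1; num_splits <= max_splits; ++num_splits) {
        float n_waves = float(total_mblocks * num_splits) / num_SMs;
        float eff = n_waves / std::ceil(n_waves);
        efficiency[num_splits - 1] = eff;
        max_efficiency = std::max(max_efficiency, eff);
        std::printf("num_splits = %d, efficiency = %.2f\n", num_splits, eff);
    }
    // Pick the smallest split count within 85% of the best efficiency.
    for (int num_splits = 1; num_splits <= max_splits; ++num_splits) {
        if (efficiency[num_splits - 1] >= 0.85f * max_efficiency) {
            std::printf("chosen num_splits = %d\n", num_splits);
            break;
        }
    }
}
```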
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
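Each of the auto-generated instantiation files in this diff exists only to hold one or two explicit specializations of run_mha_bwd_ (or run_mha_fwd_), so the heavy template instantiations compile in parallel across translation units and unsupported configurations are never built. The single-file sketch below, with invented names, illustrates that layout; in the real repo each specialization sits in its own .cu file.

```cpp
#include <cstdio>

// Primary template is only declared; there is no generic definition, so only
// explicitly specialized configurations exist at link time.
template <int Arch, typename T, int kHeadDim, bool Has_softcap>
void run_bwd_example();

// Contents of one "instantiation file":
template <>
void run_bwd_example<90, float, 128, false>() {
    std::printf("sm90, hdim 128, no softcap\n");
}

// Contents of another:
template <>
void run_bwd_example<80, float, 128, true>() {
    std::printf("sm80, hdim 128, softcap\n");
}

int main() {
    run_bwd_example<90, float, 128, false>();
    run_bwd_example<80, float, 128, true>();
    // run_bwd_example<90, float, 512, false>();  // would fail to link:
    // only the explicitly specialized configurations are available.
}
```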
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim128_bf16_sm90.cu"
6
- #include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim128_fp16_sm90.cu"
6
- #include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim192_bf16_sm90.cu"
6
- #include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim192_fp16_sm90.cu"
6
- #include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM256
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM256
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim256_bf16_sm90.cu"
6
- #include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif