danieldk HF Staff commited on
Commit
4d1b54e
·
0 Parent(s):

Convert FA3 to Kernel Hub format

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build.toml +591 -0
  2. flake.lock +168 -0
  3. flake.nix +17 -0
  4. flash-attn/block.h +94 -0
  5. flash-attn/copy_sm90_bulk_reduce.hpp +49 -0
  6. flash-attn/cuda_check.h +19 -0
  7. flash-attn/epilogue_bwd.hpp +523 -0
  8. flash-attn/epilogue_fwd.hpp +484 -0
  9. flash-attn/flash.h +220 -0
  10. flash-attn/flash_api.cpp +1623 -0
  11. flash-attn/flash_bwd_kernel_sm80.h +173 -0
  12. flash-attn/flash_bwd_kernel_sm90.h +282 -0
  13. flash-attn/flash_bwd_launch_template.h +377 -0
  14. flash-attn/flash_bwd_postprocess_kernel.h +256 -0
  15. flash-attn/flash_bwd_preprocess_kernel.h +252 -0
  16. flash-attn/flash_fwd_combine.cu +13 -0
  17. flash-attn/flash_fwd_combine_kernel.h +482 -0
  18. flash-attn/flash_fwd_combine_launch_template.h +80 -0
  19. flash-attn/flash_fwd_kernel_sm80.h +215 -0
  20. flash-attn/flash_fwd_kernel_sm90.h +468 -0
  21. flash-attn/flash_fwd_launch_template.h +231 -0
  22. flash-attn/flash_prepare_scheduler.cu +124 -0
  23. flash-attn/heuristics.h +65 -0
  24. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +18 -0
  25. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +12 -0
  26. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +18 -0
  27. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +12 -0
  28. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +6 -0
  29. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +18 -0
  30. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +12 -0
  31. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +18 -0
  32. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +12 -0
  33. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +6 -0
  34. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +18 -0
  35. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +12 -0
  36. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +18 -0
  37. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +12 -0
  38. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +6 -0
  39. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +18 -0
  40. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +12 -0
  41. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +18 -0
  42. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +12 -0
  43. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +6 -0
  44. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +18 -0
  45. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +12 -0
  46. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +18 -0
  47. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +12 -0
  48. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +6 -0
  49. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +18 -0
  50. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu +12 -0
build.toml ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "flash_attn"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/pytorch_shim.h",
8
+ "torch-ext/torch_binding.cpp",
9
+ "torch-ext/torch_binding.h",
10
+ ]
11
+
12
+ [kernel.flash_attn]
13
+ backend = "cuda"
14
+ cuda-capabilities = ["8.0", "9.0a"]
15
+ cuda-flags = [
16
+ "-O3",
17
+ "-std=c++17",
18
+ "--ftemplate-backtrace-limit=0", # To debug template code
19
+ "--use_fast_math",
20
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
21
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
22
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
23
+ "--expt-relaxed-constexpr",
24
+ "--expt-extended-lambda",
25
+ "--use_fast_math",
26
+ "-DNDEBUG",
27
+ ]
28
+
29
+ src = [
30
+ "flash-attn/cuda_check.h",
31
+ "flash-attn/flash_api.cpp",
32
+ "flash-attn/flash_fwd_combine.cu",
33
+ "flash-attn/flash_fwd_combine_kernel.h",
34
+ "flash-attn/flash_fwd_combine_launch_template.h",
35
+ "flash-attn/flash.h",
36
+ "flash-attn/flash_prepare_scheduler.cu",
37
+ "flash-attn/heuristics.h",
38
+ "flash-attn/seqlen.h",
39
+ "flash-attn/static_switch.h",
40
+ "flash-attn/tile_size.h",
41
+ "flash-attn/utils.h",
42
+ ]
43
+ depends = ["torch", "cutlass_3_9"]
44
+
45
+ [kernel.flash_attn_sm80]
46
+ backend = "cuda"
47
+ cuda-capabilities = ["8.0", "9.0a"]
48
+ cuda-flags = [
49
+ "-O3",
50
+ "-std=c++17",
51
+ "--ftemplate-backtrace-limit=0", # To debug template code
52
+ "--use_fast_math",
53
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
54
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
55
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
56
+ "--expt-relaxed-constexpr",
57
+ "--expt-extended-lambda",
58
+ "--use_fast_math",
59
+ "-DNDEBUG",
60
+ "-DFLASHATTENTION_DISABLE_PYBIND",
61
+ ]
62
+ src = [
63
+ "flash-attn/block.h",
64
+ "flash-attn/copy_sm90_bulk_reduce.hpp",
65
+ "flash-attn/epilogue_bwd.hpp",
66
+ "flash-attn/epilogue_fwd.hpp",
67
+ "flash-attn/flash.h",
68
+ "flash-attn/flash_bwd_kernel_sm80.h",
69
+ "flash-attn/flash_bwd_kernel_sm90.h",
70
+ "flash-attn/flash_bwd_launch_template.h",
71
+ "flash-attn/flash_bwd_postprocess_kernel.h",
72
+ "flash-attn/flash_bwd_preprocess_kernel.h",
73
+ "flash-attn/flash_fwd_launch_template.h",
74
+ "flash-attn/flash_fwd_kernel_sm80.h",
75
+ "flash-attn/flash_fwd_kernel_sm90.h",
76
+ "flash-attn/heuristics.h",
77
+ "flash-attn/mainloop_bwd_sm80.hpp",
78
+ "flash-attn/mainloop_fwd_sm80.hpp",
79
+ "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
80
+ "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
81
+ "flash-attn/mask.h",
82
+ "flash-attn/named_barrier.hpp",
83
+ "flash-attn/pack_gqa.h",
84
+ "flash-attn/paged_kv.h",
85
+ "flash-attn/rotary.h",
86
+ "flash-attn/sm90_pipeline_no_cluster.hpp",
87
+ "flash-attn/softmax.h",
88
+ "flash-attn/tile_size.h",
89
+ "flash-attn/tile_scheduler.hpp",
90
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
91
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
92
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
93
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
94
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
95
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
96
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
97
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
98
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
99
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
100
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
101
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
102
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
103
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
104
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
105
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
106
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
107
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
108
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
109
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
110
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
111
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu",
112
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
113
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu",
114
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
115
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu",
116
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
117
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu",
118
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
119
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu",
120
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
121
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu",
122
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
123
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu",
124
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
125
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu",
126
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
127
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu",
128
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
129
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu",
130
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
131
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu",
132
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
133
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu",
134
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
135
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu",
136
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
137
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu",
138
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
139
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu",
140
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
141
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu",
142
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
143
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu",
144
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
145
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu",
146
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
147
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu",
148
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
149
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu",
150
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
151
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu",
152
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
153
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu",
154
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
155
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu",
156
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
157
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu",
158
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
159
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu",
160
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
161
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu",
162
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
163
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu",
164
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
165
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu",
166
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
167
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu",
168
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
169
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu",
170
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
171
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu",
172
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
173
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu",
174
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
175
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu",
176
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
177
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu",
178
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
179
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu",
180
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
181
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu",
182
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
183
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu",
184
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
185
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu",
186
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
187
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu",
188
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
189
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu",
190
+ ]
191
+ include = ["flash-attn"]
192
+ depends = ["torch", "cutlass_3_9"]
193
+
194
+ [kernel.flash_attn_sm90]
195
+ backend = "cuda"
196
+ cuda-capabilities = ["8.0", "9.0a"]
197
+ cuda-flags = [
198
+ "-O3",
199
+ "-std=c++17",
200
+ "--ftemplate-backtrace-limit=0", # To debug template code
201
+ "--use_fast_math",
202
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
203
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
204
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
205
+ "--expt-relaxed-constexpr",
206
+ "--expt-extended-lambda",
207
+ "--use_fast_math",
208
+ "-DNDEBUG",
209
+ ]
210
+ src = [
211
+ "flash-attn/block.h",
212
+ "flash-attn/copy_sm90_bulk_reduce.hpp",
213
+ "flash-attn/epilogue_bwd.hpp",
214
+ "flash-attn/epilogue_fwd.hpp",
215
+ "flash-attn/flash.h",
216
+ "flash-attn/flash_bwd_kernel_sm80.h",
217
+ "flash-attn/flash_bwd_kernel_sm90.h",
218
+ "flash-attn/flash_bwd_launch_template.h",
219
+ "flash-attn/flash_bwd_postprocess_kernel.h",
220
+ "flash-attn/flash_bwd_preprocess_kernel.h",
221
+ "flash-attn/flash_fwd_launch_template.h",
222
+ "flash-attn/flash_fwd_kernel_sm80.h",
223
+ "flash-attn/flash_fwd_kernel_sm90.h",
224
+ "flash-attn/heuristics.h",
225
+ "flash-attn/mainloop_bwd_sm80.hpp",
226
+ "flash-attn/mainloop_fwd_sm80.hpp",
227
+ "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
228
+ "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
229
+ "flash-attn/mask.h",
230
+ "flash-attn/named_barrier.hpp",
231
+ "flash-attn/pack_gqa.h",
232
+ "flash-attn/paged_kv.h",
233
+ "flash-attn/rotary.h",
234
+ "flash-attn/sm90_pipeline_no_cluster.hpp",
235
+ "flash-attn/softmax.h",
236
+ "flash-attn/tile_size.h",
237
+ "flash-attn/tile_scheduler.hpp",
238
+
239
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
240
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
241
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
242
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
243
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
244
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
245
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
246
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
247
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
248
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
249
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
250
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
251
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
252
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
253
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
254
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
255
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
256
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
257
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
258
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
259
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
260
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
261
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
262
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
263
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
264
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
265
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
266
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
267
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
268
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
269
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
270
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
271
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
272
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
273
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
274
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
275
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
276
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
277
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
278
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
279
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
280
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
281
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
282
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
283
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
284
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
285
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
286
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
287
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
288
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
289
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
290
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
291
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
292
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
293
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
294
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
295
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
296
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
297
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
298
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
299
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
300
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
301
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
302
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
303
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
304
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
305
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
306
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
307
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
308
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
309
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
310
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
311
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
312
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
313
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
314
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
315
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
316
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
317
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
318
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
319
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
320
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
321
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
322
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
323
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
324
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
325
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
326
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
327
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
328
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
329
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
330
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
331
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
332
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
333
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
334
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
335
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
336
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
337
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
338
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
339
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
340
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
341
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
342
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
343
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
344
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
345
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
346
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
347
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
348
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
349
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
350
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
351
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
352
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
353
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
354
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
355
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
356
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
357
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
358
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
359
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
360
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
361
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
362
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
363
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
364
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
365
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
366
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
367
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
368
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
369
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
370
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
371
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
372
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
373
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
374
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
375
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
376
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
377
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
378
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
379
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
380
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
381
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
382
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
383
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
384
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
385
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
386
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
387
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
388
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
389
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
390
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
391
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
392
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
393
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
394
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
395
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
396
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
397
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
398
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
399
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
400
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
401
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
402
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
403
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
404
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
405
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
406
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
407
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
408
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
409
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
410
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
411
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
412
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
413
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
414
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
415
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
416
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
417
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
418
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
419
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
420
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
421
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
422
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
423
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
424
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
425
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
426
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
427
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
428
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
429
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
430
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
431
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
432
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
433
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
434
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
435
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
436
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
437
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
438
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
439
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
440
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
441
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
442
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
443
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
444
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
445
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
446
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
447
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
448
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
449
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
450
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
451
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
452
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
453
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
454
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
455
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
456
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
457
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
458
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
459
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
460
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
461
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
462
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
463
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
464
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
465
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
466
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
467
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
468
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
469
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
470
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
471
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
472
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
473
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
474
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
475
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
476
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
477
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
478
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
479
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
480
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
481
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
482
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
483
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
484
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
485
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
486
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
487
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
488
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
489
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
490
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
491
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
492
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
493
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
494
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
495
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
496
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
497
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
498
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
499
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
500
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
501
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
502
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
503
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
504
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
505
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
506
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
507
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
508
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
509
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
510
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
511
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
512
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
513
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
514
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
515
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
516
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
517
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
518
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
519
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
520
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
521
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
522
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
523
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
524
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
525
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
526
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
527
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
528
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
529
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
530
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
531
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
532
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
533
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
534
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
535
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
536
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
537
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
538
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
539
+ ]
540
+ include = ["flash-attn"]
541
+ depends = ["torch", "cutlass_3_9"]
542
+
543
+ # [kernel.flash_attn_sm100]
544
+ # backend = "cuda"
545
+ # cuda-capabilities = ["8.0", "9.0a", "10.0"]
546
+ # cuda-flags = [
547
+ # "-O3",
548
+ # "-std=c++17",
549
+ # "--ftemplate-backtrace-limit=0", # To debug template code
550
+ # "--use_fast_math",
551
+ # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
552
+ # "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
553
+ # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
554
+ # "--expt-relaxed-constexpr",
555
+ # "--expt-extended-lambda",
556
+ # "--use_fast_math",
557
+ # "-DNDEBUG",
558
+ # ]
559
+ # src = [
560
+ # "flash-attn/block.h",
561
+ # "flash-attn/copy_sm90_bulk_reduce.hpp",
562
+ # "flash-attn/epilogue_bwd.hpp",
563
+ # "flash-attn/epilogue_fwd.hpp",
564
+ # "flash-attn/flash.h",
565
+ # "flash-attn/flash_bwd_kernel_sm80.h",
566
+ # "flash-attn/flash_bwd_kernel_sm90.h",
567
+ # "flash-attn/flash_bwd_launch_template.h",
568
+ # "flash-attn/flash_bwd_postprocess_kernel.h",
569
+ # "flash-attn/flash_bwd_preprocess_kernel.h",
570
+ # "flash-attn/flash_fwd_launch_template.h",
571
+ # "flash-attn/flash_fwd_kernel_sm80.h",
572
+ # "flash-attn/flash_fwd_kernel_sm90.h",
573
+ # "flash-attn/heuristics.h",
574
+ # "flash-attn/mainloop_bwd_sm80.hpp",
575
+ # "flash-attn/mainloop_fwd_sm80.hpp",
576
+ # "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
577
+ # "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
578
+ # "flash-attn/mask.h",
579
+ # "flash-attn/named_barrier.hpp",
580
+ # "flash-attn/pack_gqa.h",
581
+ # "flash-attn/paged_kv.h",
582
+ # "flash-attn/rotary.h",
583
+ # "flash-attn/sm90_pipeline_no_cluster.hpp",
584
+ # "flash-attn/softmax.h",
585
+ # "flash-attn/tile_size.h",
586
+ # "flash-attn/tile_scheduler.hpp",
587
+ #
588
+ # "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
589
+ # ]
590
+ # include = ["flash-attn"]
591
+ # depends = ["torch", "cutlass_3_9"]
flake.lock ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1750234878,
77
+ "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1750275112,
102
+ "narHash": "sha256-gqAxmLLt0tYvuRYumOZHQgryMeEFdt6j3nEC8B5rT14=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "1b63210b2a1fc3cda2e3a579e7aa8f8c8532626f",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for Hopper Flash Attention kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
flash-attn/block.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ namespace flash {
8
+
9
+ template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
10
+ struct BlockMN {
11
+
12
+ static
13
+ CUTLASS_DEVICE
14
+ cute::tuple<int, int> get_n_block_min_max(
15
+ SeqlenInfo_t const& seqlen_info,
16
+ int const m_block, int const bidb, int const split_idx, int const num_splits,
17
+ int const window_size_left, int const window_size_right,
18
+ cutlass::FastDivmod const& qhead_per_khead_divmod) {
19
+
20
+ int const seqlen_k = seqlen_info.seqlen_k;
21
+ int const seqlen_q = seqlen_info.seqlen_q;
22
+ int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
23
+ if constexpr (Is_causal || Is_local) {
24
+ int m_idx_max = (m_block + 1) * kBlockM;
25
+ // TODO: check off-by-1 error
26
+ if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
27
+ n_block_max = std::min(n_block_max,
28
+ cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
29
+ }
30
+ int n_block_min = 0;
31
+ if constexpr (Is_local) {
32
+ int m_idx_min = m_block * kBlockM;
33
+ if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
34
+ n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
35
+ }
36
+ // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
37
+ if constexpr (Split) {
38
+ uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
39
+ int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
40
+ int split_idx_actual = split_idx & 0x0000FFFF;
41
+ int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
42
+ int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
43
+ n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
44
+ n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
45
+ // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
46
+ }
47
+ // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
48
+ return {n_block_min, n_block_max};
49
+ }
50
+
51
+ static
52
+ CUTLASS_DEVICE
53
+ cute::tuple<int, int> get_n_block_k_new_min_max(
54
+ SeqlenInfo_t const& seqlen_info,
55
+ int const m_block, int const bidb, int const split_idx, int const num_splits,
56
+ int const window_size_left, int const window_size_right,
57
+ cutlass::FastDivmod const& qhead_per_khead_divmod) {
58
+
59
+ auto [n_block_min, n_block_max] = get_n_block_min_max(
60
+ seqlen_info, m_block, bidb, split_idx, num_splits,
61
+ window_size_left, window_size_right, qhead_per_khead_divmod);
62
+ int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
63
+ int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
64
+ int const n_block_new_min = idx_k_new_min / kBlockN;
65
+ int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
66
+ // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
67
+ return {n_block_new_min, n_block_new_max};
68
+ }
69
+
70
+ static
71
+ CUTLASS_DEVICE
72
+ cute::tuple<int, int> get_m_block_min_max(
73
+ SeqlenInfo_t const& seqlen_info,
74
+ int const n_block, int const bidb,
75
+ int const window_size_left, int const window_size_right, int const sink_token_length) {
76
+
77
+ int const seqlen_q = seqlen_info.seqlen_q;
78
+ int const seqlen_k = seqlen_info.seqlen_k;
79
+ int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
80
+ if constexpr (Is_local) {
81
+ if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
82
+ m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
83
+ }
84
+ }
85
+ int m_block_min = 0;
86
+ if constexpr (Is_causal || Is_local) {
87
+ m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
88
+ }
89
+ return {m_block_min, m_block_max};
90
+ }
91
+
92
+ };
93
+
94
+ } // namespace flash
flash-attn/copy_sm90_bulk_reduce.hpp ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include<cute/arch/copy_sm90_tma.hpp>
8
+
9
+ namespace cute
10
+ {
11
+
12
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
13
+
14
+ struct SM90_BULK_REDUCE_ADD
15
+ {
16
+ CUTE_HOST_DEVICE static void
17
+ copy(float const* smem_ptr,
18
+ float * gmem_ptr, int32_t store_bytes)
19
+ {
20
+ #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
21
+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
22
+ asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
23
+ :
24
+ : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
25
+ : "memory");
26
+ #else
27
+ CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
28
+ #endif
29
+ }
30
+
31
+ CUTE_HOST_DEVICE static void
32
+ copy(float const* smem_ptr,
33
+ float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
34
+ {
35
+ #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
36
+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
37
+ asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
38
+ :
39
+ : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
40
+ : "memory");
41
+ #else
42
+ CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
43
+ #endif
44
+ }
45
+ };
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ } // end namespace cute
flash-attn/cuda_check.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <assert.h>
8
+ #include <stdlib.h>
9
+
10
+ #define CHECK_CUDA(call) \
11
+ do { \
12
+ cudaError_t status_ = call; \
13
+ if (status_ != cudaSuccess) { \
14
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
15
+ exit(1); \
16
+ } \
17
+ } while(0)
18
+
19
+ #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
flash-attn/epilogue_bwd.hpp ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cutlass/cutlass.h"
8
+ #include "cutlass/barrier.h"
9
+ #include "cute/tensor.hpp"
10
+
11
+ #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
+
13
+ #include "seqlen.h"
14
+ #include "named_barrier.hpp"
15
+ #include "utils.h"
16
+
17
+ namespace flash {
18
+
19
+ using namespace cute;
20
+
21
+ template <class TileShape_MNK_, class Element_, class ArchTag_,
22
+ int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
23
+ struct CollectiveEpilogueBwd {
24
+
25
+ using TileShape_MNK = TileShape_MNK_;
26
+ using Element = Element_;
27
+ using ArchTag = ArchTag_;
28
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
29
+ static constexpr bool Varlen = Varlen_;
30
+ static constexpr bool dKV_swapAB = dKV_swapAB_;
31
+ static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
32
+
33
+ static_assert(ArchTag::kMinComputeCapability >= 80);
34
+
35
+ using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
36
+
37
+ // These are for storing the output tensor without TMA (e.g., for setting output to zero)
38
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
+ static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
+ static constexpr int kHeadDim = get<2>(TileShape_MNK{});
41
+ static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
42
+ static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
43
+ using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
44
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
45
+ using GmemTiledCopydKV = decltype(
46
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
47
+ GmemLayoutAtom{},
48
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
49
+
50
+ using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
51
+ // TODO: do we have to change this if dKV_swapAB is true?
52
+ decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
53
+ using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
54
+ using SmemLayoutdKVtTMA =
55
+ decltype(cute::composition(SmemLayoutdKVTMA{},
56
+ make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
57
+ make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
58
+
59
+ // If we don't use TMA
60
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
61
+ static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
62
+ using SmemLayoutAtomdKVSTG =
63
+ decltype(composition(Swizzle<kSwizzle, 3, 3>{},
64
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
65
+ Stride<Int<kBlockKSmem>, _1>>{}));
66
+
67
+ using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
68
+ using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
69
+ using SmemLayoutdKVt =
70
+ decltype(cute::composition(SmemLayoutdKV{},
71
+ make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
72
+ make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
73
+
74
+ using SmemCopyAtomdKV = Copy_Atom<
75
+ std::conditional_t<
76
+ ArchTag::kMinComputeCapability >= 90,
77
+ std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
78
+ AutoVectorizingCopyWithAssumedAlignment<128>
79
+ >,
80
+ Element>;
81
+
82
+ static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
83
+ static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
84
+
85
+ struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
86
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
87
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
88
+ };
89
+
90
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
91
+ using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
92
+
93
+ using TMA_dKV = std::conditional_t<
94
+ Use_TMA,
95
+ decltype(make_tma_copy(
96
+ GmemTiledCopydKVTMA{},
97
+ make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
98
+ SmemLayoutdKVTMA{},
99
+ select<1, 2>(TileShape_MNK{}),
100
+ _1{})), // no mcast for dKV
101
+ std::nullptr_t
102
+ >;
103
+
104
+ // Host side kernel arguments
105
+ struct Arguments {
106
+ Element* ptr_dK;
107
+ ShapedKV const shape_dK;
108
+ StridedKV const stride_dK;
109
+ Element* ptr_dV;
110
+ StridedKV const stride_dV;
111
+ int const num_heads_q;
112
+ int* dk_semaphore;
113
+ int* dv_semaphore;
114
+ int const* cu_seqlens;
115
+ int const* seqused;
116
+ };
117
+
118
+ // Device side kernel params
119
+ struct Params {
120
+ Element* ptr_dK;
121
+ ShapedKV const shape_dK;
122
+ StridedKV const stride_dK;
123
+ Element* ptr_dV;
124
+ StridedKV const stride_dV;
125
+ TMA_dKV tma_store_dK, tma_store_dV;
126
+ int const* cu_seqlens = nullptr;
127
+ int const* seqused = nullptr;
128
+ };
129
+
130
+ static Params
131
+ to_underlying_arguments(Arguments const& args) {
132
+ Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
133
+ Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
134
+ TMA_dKV tma_store_dK = [&] {
135
+ if constexpr (Use_TMA) {
136
+ return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
137
+ } else {
138
+ return nullptr;
139
+ }
140
+ }();
141
+ TMA_dKV tma_store_dV = [&] {
142
+ if constexpr (Use_TMA) {
143
+ return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
144
+ } else {
145
+ return nullptr;
146
+ }
147
+ }();
148
+ return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
149
+ tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
150
+ }
151
+
152
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
153
+ CUTLASS_DEVICE
154
+ static void prefetch_tma_descriptors(Params const& params) {
155
+ if constexpr (Use_TMA) {
156
+ cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
157
+ cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
158
+ }
159
+ }
160
+
161
+ template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
162
+ CUTLASS_DEVICE void
163
+ store(Params const& params,
164
+ FrgTensorO const& tdKrdK,
165
+ FrgTensorO const& tdVrdV,
166
+ SharedStorage& shared_storage,
167
+ TiledMma tiled_mma,
168
+ int thread_idx,
169
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
170
+ ) {
171
+
172
+ auto [n_block, bidh, bidb] = block_coord;
173
+ Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
174
+ Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
175
+ Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
176
+ Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
177
+ auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
178
+ auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
179
+
180
+ Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
181
+ flash::convert_type_out(tdVrdV, tdVrdV_out);
182
+ Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
183
+ flash::convert_type_out(tdKrdK, tdKrdK_out);
184
+ Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
185
+ Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
186
+ // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
187
+ Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
188
+ Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
189
+
190
+ // Make sure all WGs have finished reading K and V
191
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
192
+ cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
193
+ cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
194
+ if constexpr (Use_TMA) {
195
+ cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
196
+ cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
197
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
198
+
199
+ Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
200
+ Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
201
+ Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
202
+ Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
203
+ auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
204
+ auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
205
+ Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
206
+ Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
207
+ Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
208
+ Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
209
+ int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
210
+ if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
211
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
212
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
213
+ if (cute::elect_one_sync()) {
214
+ cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
215
+ cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
216
+ tma_store_arrive();
217
+ }
218
+ }
219
+ tma_store_wait<0>();
220
+ // // Tell warp 0 that smem_k and smem_v are ready
221
+ // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
222
+
223
+ } else {
224
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
225
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
226
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
227
+ bool const is_varlen = Varlen && params.cu_seqlens;
228
+ Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
229
+ Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
230
+ Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
231
+ Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
232
+
233
+ GmemTiledCopydKV gmem_tiled_copy_dKV;
234
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
235
+ Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
236
+ Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
237
+ Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
238
+ Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
239
+ Tensor tdKVrdV = make_fragment_like(tdKVgdV);
240
+ Tensor tdKVrdK = make_fragment_like(tdKVgdK);
241
+ Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
242
+ // Repeat the partitioning with identity layouts
243
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
244
+ Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
245
+ #pragma unroll
246
+ for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
247
+ // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
248
+ static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
249
+ flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
250
+ gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN);
251
+ flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
252
+ gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN);
253
+ // // Tell warp 0 that smem_k and smem_v are ready
254
+ // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
255
+ // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
256
+ // Construct identity layout for gdKV
257
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
258
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
259
+ gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
260
+ );
261
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
262
+ gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
263
+ );
264
+ }
265
+ }
266
+
267
+ CUTLASS_DEVICE void
268
+ store_tail() {
269
+ // if constexpr (Use_TMA) { tma_store_wait<0>(); }
270
+ }
271
+
272
+ // Write 0 to dK and dV
273
+ CUTLASS_DEVICE void
274
+ store_zero(
275
+ Params const& params,
276
+ int thread_idx,
277
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
278
+ ) {
279
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
280
+ auto [n_block, bidh, bidb] = block_coord;
281
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
282
+ bool const is_varlen = Varlen && params.cu_seqlens;
283
+ Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
284
+ Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
285
+ Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
286
+ Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
287
+
288
+ GmemTiledCopydKV gmem_tiled_copy_dKV;
289
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
290
+ Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
291
+ Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
292
+ Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
293
+ clear(tdKVrdKV);
294
+ // Construct identity layout for gdKV
295
+ Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
296
+ // Repeat the partitioning with identity layouts
297
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
298
+ Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
299
+ #pragma unroll
300
+ for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
301
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
302
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
303
+ gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
304
+ );
305
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
306
+ gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
307
+ );
308
+ }
309
+
310
+ };
311
+
312
+ template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
313
+ int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
314
+ struct CollectiveEpilogueBwdGQA {
315
+
316
+ using TileShape_MNK = TileShape_MNK_;
317
+ using Element = ElementAccum;
318
+ using ArchTag = ArchTag_;
319
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
320
+ static constexpr bool Varlen = Varlen_;
321
+ static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
322
+
323
+ static_assert(ArchTag::kMinComputeCapability >= 80);
324
+
325
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
326
+ static constexpr int kHeadDim = get<2>(TileShape_MNK{});
327
+ static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
328
+ static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
329
+ // Thread layout, 256 or 384 threads per row
330
+ // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
331
+ using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
332
+ using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
333
+ Layout<Shape < _4>>{})); // Val layout, 4 vals per store
334
+ // For Sm80
335
+ using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
336
+ using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
337
+ Layout<Shape < _1>>{})); // Val layout, 1 val per store
338
+
339
+ using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
340
+ using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
341
+
342
+ // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
343
+ // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
344
+ static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
345
+ struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
346
+ cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
347
+ };
348
+ struct TensorStorageSTG {
349
+ cute::array<ElementAccum, 0> smem_dkv;
350
+ };
351
+ using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
352
+
353
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
354
+ using StridedKV = cute::Stride<_1, int64_t, int64_t>;
355
+
356
+ // Host side kernel arguments
357
+ struct Arguments {
358
+ ElementAccum* ptr_dKaccum;
359
+ ShapedKV const shape_dKaccum;
360
+ StridedKV const stride_dKaccum;
361
+ ElementAccum* ptr_dVaccum;
362
+ StridedKV const stride_dVaccum;
363
+ int num_heads_q;
364
+ int* dk_semaphore;
365
+ int* dv_semaphore;
366
+ int const* cu_seqlens;
367
+ int const* seqused;
368
+ };
369
+
370
+ // Device side kernel params
371
+ struct Params {
372
+ ElementAccum* ptr_dKaccum;
373
+ ShapedKV const shape_dKaccum;
374
+ StridedKV const stride_dKaccum;
375
+ ElementAccum* ptr_dVaccum;
376
+ StridedKV const stride_dVaccum;
377
+ cutlass::FastDivmod qhead_per_khead_divmod;
378
+ int* dk_semaphore;
379
+ int* dv_semaphore;
380
+ int const* cu_seqlens = nullptr;
381
+ int const* seqused = nullptr;
382
+ };
383
+
384
+ static Params
385
+ to_underlying_arguments(Arguments const& args) {
386
+ if constexpr (Deterministic) {
387
+ assert(args.dk_semaphore != nullptr);
388
+ assert(args.dv_semaphore != nullptr);
389
+ }
390
+ return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum,
391
+ cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
392
+ args.dk_semaphore, args.dv_semaphore,
393
+ args.cu_seqlens, args.seqused};
394
+ }
395
+
396
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
397
+ CUTLASS_DEVICE
398
+ static void prefetch_tma_descriptors(Params const& params) {
399
+ }
400
+
401
+ template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
402
+ CUTLASS_DEVICE void
403
+ store(Params const& params,
404
+ FrgTensorO const& tdKrdK,
405
+ FrgTensorO const& tdVrdV,
406
+ SharedStorage& shared_storage,
407
+ TiledMma tiled_mma,
408
+ int thread_idx,
409
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
410
+ ) {
411
+
412
+ auto [n_block, bidh, bidb] = block_coord;
413
+ int bidh_idx_in_group;
414
+ int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
415
+ Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
416
+ Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
417
+ static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
418
+
419
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
420
+ bool const is_varlen = Varlen && params.cu_seqlens;
421
+ Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
422
+ Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
423
+ Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
424
+ Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
425
+
426
+ R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
427
+ auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
428
+ Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
429
+
430
+ // Only used if !Use_TMA
431
+ R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
432
+ auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
433
+
434
+ // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
435
+ // because smem_q could be changed.
436
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
437
+ if constexpr (Use_TMA) {
438
+ Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
439
+ cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
440
+ }
441
+
442
+ // int const num_batch = params.num_batch;
443
+ int const num_batch = get<2>(params.shape_dKaccum);
444
+ int const num_head_kv = get<1>(params.shape_dKaccum);
445
+ int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
446
+ using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
447
+
448
+ // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
449
+
450
+ if constexpr (Deterministic) {
451
+ Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
452
+ }
453
+ // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
454
+ if constexpr (Use_TMA) {
455
+ cutlass::arch::fence_view_async_shared();
456
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
457
+ if (thread_idx == 0) {
458
+ SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
459
+ tma_store_arrive();
460
+ tma_store_wait<0>();
461
+ }
462
+ } else {
463
+ Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
464
+ Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
465
+ static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
466
+ #pragma unroll
467
+ for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
468
+ }
469
+ if constexpr (Deterministic) {
470
+ Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
471
+ }
472
+
473
+ if constexpr (Use_TMA) {
474
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
475
+ Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
476
+ cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
477
+ }
478
+ lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
479
+ // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
480
+
481
+ if constexpr (Deterministic) {
482
+ Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
483
+ }
484
+ // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
485
+ if constexpr (Use_TMA) {
486
+ cutlass::arch::fence_view_async_shared();
487
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
488
+ if (thread_idx == 0) {
489
+ SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
490
+ tma_store_arrive();
491
+ tma_store_wait<0>();
492
+ }
493
+ } else {
494
+ Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
495
+ Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
496
+ static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
497
+ #pragma unroll
498
+ for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
499
+ }
500
+ if constexpr (Deterministic) {
501
+ Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
502
+ }
503
+ // // Tell warp 0 that smem_k and smem_v are ready
504
+ // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
505
+ }
506
+
507
+ CUTLASS_DEVICE void
508
+ store_tail() {
509
+ }
510
+
511
+ // Write 0 to dK and dV
512
+ CUTLASS_DEVICE void
513
+ store_zero(
514
+ Params const& params,
515
+ int thread_idx,
516
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
517
+ ) {
518
+ // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
519
+ }
520
+
521
+ };
522
+
523
+ } // namespace flash
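A simplified, self-contained sketch of the deterministic-accumulation idea used by CollectiveEpilogueBwdGQA above: query-head groups take turns adding their partial dK/dV into a shared fp32 accumulator, gated by a counter semaphore (plain CUDA atomics here instead of cutlass::GenericBarrier; the kernel name and launch shape are illustrative).

#include <cuda_runtime.h>

// One block per query-head group. Groups accumulate in a fixed order, so the
// floating-point sum is bit-reproducible across runs. Launch with a small grid
// (one block per group) so all groups are co-resident; otherwise the spin wait
// below could deadlock.
__global__ void ordered_accumulate(float* accum, const float* partials,
                                   int* semaphore, int n) {
    int group = blockIdx.x;
    if (threadIdx.x == 0) {
        while (atomicAdd(semaphore, 0) != group) { }   // wait_eq: is it our turn?
    }
    __syncthreads();
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        accum[i] += partials[group * n + i];           // this group's contribution
    }
    __threadfence();                                   // publish the adds
    __syncthreads();
    if (threadIdx.x == 0) { atomicAdd(semaphore, 1); } // arrive_inc: release the next group
}

In the epilogue above, the same ordering is applied per (n_block, batch, kv-head) through the dk_semaphore / dv_semaphore arrays when Deterministic is set.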
flash-attn/epilogue_fwd.hpp ADDED
@@ -0,0 +1,484 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cutlass/cutlass.h>
8
+ #include <cutlass/fast_math.h> // For FastDivMod
9
+ #include "cute/tensor.hpp"
10
+
11
+ #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
+ #include "cutlass/epilogue/collective/builders/sm90_common.inl"
13
+
14
+ #include "seqlen.h"
15
+ #include "named_barrier.hpp"
16
+ #include "pack_gqa.h"
17
+ #include "utils.h"
18
+
19
+ namespace flash {
20
+
21
+ using namespace cute;
22
+
23
+ template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
24
+ int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
25
+ struct CollectiveEpilogueFwd {
26
+
27
+ using TileShape_MNK_PV = TileShape_MNK_PV_;
28
+ using ClusterShape = ClusterShape_;
29
+ using Element = Element_;
30
+ using ElementPartial = float;
31
+ using ArchTag = ArchTag_;
32
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
33
+ static constexpr bool Varlen = Varlen_;
34
+ static constexpr bool PackGQA = PackGQA_;
35
+ static constexpr bool Split = Split_;
36
+ static constexpr bool Use_smem = !(Split && !Varlen);
37
+ static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
38
+
39
+ static_assert(ArchTag::kMinComputeCapability >= 80);
40
+ static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
41
+ static_assert(sizeof(Element) <= 2);
42
+
43
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
44
+ static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
45
+
46
+ static constexpr bool LargeHeadDimV = kHeadDimV > 256;
47
+
48
+ using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
49
+
50
+ // These are for storing the output tensor without TMA (e.g., for setting output to zero)
51
+ static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
52
+ static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
53
+ // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
54
+ // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
55
+ // we need to call divmod.
56
+ static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
57
+ static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
58
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
59
+ // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
60
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
61
+ static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
62
+ using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
63
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
64
+ static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
65
+ using GmemTiledCopyO = decltype(
66
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
67
+ GmemLayoutAtom{},
68
+ Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
69
+
70
+ using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
71
+ decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
72
+ using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
73
+ static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
74
+ static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
75
+ using SmemLayoutAtomO = decltype(
76
+ composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
77
+ Layout<Shape<_8, Int<kBlockKGmem>>,
78
+ Stride<Int<kBlockKGmem>, _1>>{}));
79
+ using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
80
+ using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
81
+
82
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
83
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
84
+ using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
85
+ // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
86
+ using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
87
+ using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
88
+ // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
89
+ using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
90
+ using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
91
+
92
+ using CopyOpR2S = std::conditional_t<
93
+ ArchTag::kMinComputeCapability >= 90,
94
+ // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
95
+ decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
96
+ AutoVectorizingCopyWithAssumedAlignment<128>
97
+ >;
98
+ using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
99
+
100
+ // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
101
+ // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
102
+ // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
103
+ // cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
104
+ // };
105
+ struct TensorStorage : cute::aligned_struct<128> {
106
+ cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
107
+ };
108
+
109
+ using TMA_O = std::conditional_t<
110
+ Use_TMA_O,
111
+ decltype(make_tma_copy(
112
+ GmemTiledCopyOTMA{},
113
+ make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
114
+ SmemLayoutOTMA{},
115
+ select<0, 1>(TileShape_MNK_PV{}),
116
+ _1{})), // no mcast for O
117
+ std::nullptr_t
118
+ >;
119
+
120
+ // Host side kernel arguments
121
+ struct Arguments {
122
+ Element* ptr_O;
123
+ ShapeO const shape_O;
124
+ StrideO const stride_O;
125
+ ElementPartial* ptr_O_partial;
126
+ StrideO const stride_O_partial;
127
+ float* ptr_LSE;
128
+ StrideLSE const stride_LSE;
129
+ float* ptr_LSE_partial;
130
+ StrideLSE const stride_LSE_partial;
131
+ int32_t const nheads_kv;
132
+ int const* cu_seqlens = nullptr;
133
+ int const* seqused = nullptr;
134
+ };
135
+
136
+ // Device side kernel params
137
+ struct Params {
138
+ Element* ptr_O;
139
+ ShapeO const shape_O;
140
+ StrideO const stride_O;
141
+ ShapeOPacked const shape_O_packed;
142
+ StrideOPacked const stride_O_packed;
143
+ ElementPartial* ptr_O_partial;
144
+ StrideO const stride_O_partial;
145
+ StrideOPacked const stride_O_partial_packed;
146
+ float* ptr_LSE;
147
+ StrideLSE const stride_LSE;
148
+ ShapeLSEPacked const shape_LSE_packed;
149
+ StrideLSEPacked const stride_LSE_packed;
150
+ float* ptr_LSE_partial;
151
+ StrideLSE const stride_LSE_partial;
152
+ StrideLSEPacked const stride_LSE_partial_packed;
153
+ cutlass::FastDivmod qhead_per_khead_divmod;
154
+ TMA_O tma_store_O;
155
+ int const* cu_seqlens = nullptr;
156
+ int const* seqused = nullptr;
157
+ };
158
+
159
+ static Params
160
+ to_underlying_arguments(Arguments const& args) {
161
+ Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
162
+ TMA_O tma_store_O = [&]{
163
+ if constexpr (Use_TMA_O) {
164
+ return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
165
+ } else {
166
+ return nullptr;
167
+ }
168
+ }();
169
+ // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
170
+ int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
171
+ auto const shape_O_packed = cute::conditional_return<!PackGQA>(
172
+ args.shape_O,
173
+ make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
174
+ );
175
+ auto const stride_O_packed = cute::conditional_return<!PackGQA>(
176
+ args.stride_O,
177
+ make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
178
+ );
179
+ auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
180
+ args.stride_O_partial,
181
+ make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
182
+ );
183
+ // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
184
+ auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
185
+ select<0, 2, 3, 4>(args.shape_O),
186
+ make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
187
+ );
188
+ auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
189
+ args.stride_LSE,
190
+ make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
191
+ );
192
+ auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
193
+ args.stride_LSE_partial,
194
+ make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
195
+ );
196
+ return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
197
+ args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
198
+ args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
199
+ args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
200
+ cutlass::FastDivmod(qhead_per_khead),
201
+ tma_store_O, args.cu_seqlens, args.seqused};
202
+ }
203
+
204
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
205
+ CUTLASS_DEVICE
206
+ static void prefetch_tma_descriptors(Params const& params) {
207
+ if constexpr (Use_TMA_O) {
208
+ cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
209
+ }
210
+ }
211
+
212
+ template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
213
+ CUTLASS_DEVICE void
214
+ store(Params const& params,
215
+ FrgTensorO& tOrO,
216
+ FrgTensorLSE const& lse,
217
+ SharedStorage& shared_storage,
218
+ TiledMma tiled_mma,
219
+ int thread_idx,
220
+ cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
221
+ ) {
222
+
223
+ auto [m_block, bidh, bidb, split_idx] = block_coord;
224
+ int num_splits = get<4>(params.shape_O_packed);
225
+ if constexpr (Split && Varlen) {
226
+ uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // upper 16 bits of split_idx encode num_splits
227
+ int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
228
+ num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
229
+ split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
230
+ }
231
+ bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
232
+
233
+ Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
234
+ // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
235
+
236
+ static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
237
+ // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
238
+ // Otherwise we can permute after conversion.
239
+ if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
240
+ Tensor tOrO_out = make_tensor_like<Element>(tOrO);
241
+ flash::convert_type_out(tOrO, tOrO_out);
242
+ if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
243
+
244
+ // Make sure all WGs have finished reading V
245
+ // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
246
+ // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
247
+ // cp.async if we need).
248
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
249
+
250
+ // Step 1: Write O from rmem -> smem
251
+ if constexpr (Use_smem) {
252
+ auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
253
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
254
+ Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
255
+ Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
256
+ // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
257
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
258
+ if constexpr (Use_TMA_O) {
259
+ cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
260
+ cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
261
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
262
+ } else {
263
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
264
+ }
265
+ } else {
266
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
267
+ #pragma unroll
268
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
269
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
270
+ }
271
+ }
272
+ }
273
+
274
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
275
+ bool is_varlen = Varlen && params.cu_seqlens;
276
+ int offset_o = seqlen_info.offset;
277
+ int seqlen_o = seqlen_info.seqlen;
278
+ int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
279
+
280
+ // Step 2: Write LSE from rmem -> gmem
281
+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
282
+ // (MMA,MMA_M,MMA_K)
283
+ Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
284
+ static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
285
+ static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
286
+ Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
287
+ Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
288
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
289
+
290
+ using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
291
+ using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
292
+
293
+ Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
294
+ params.shape_LSE_packed,
295
+ !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
296
+ // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
297
+ if (!LargeHeadDimV || warp_group_idx == 0) {
298
+ if constexpr (!PackGQA) {
299
+ #pragma unroll
300
+ for (int mi = 0; mi < size(lse); ++mi) {
301
+ int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
302
+ if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
303
+ }
304
+ } else {
305
+ PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
306
+ }
307
+ }
308
+
309
+ // Step 3: Write O from smem -> gmem
310
+ if constexpr (Use_TMA_O) {
311
+ Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
312
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
313
+ auto block_tma_O = params.tma_store_O.get_slice(_0{});
314
+ Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
315
+ Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
316
+ int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
317
+ if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
318
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
319
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
320
+ if (cute::elect_one_sync()) {
321
+ cute::copy(params.tma_store_O, tOsO, tOgO);
322
+ tma_store_arrive();
323
+ tma_store_wait<0>();
324
+ #pragma unroll
325
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
326
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
327
+ }
328
+ }
329
+ }
330
+ } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
331
+ if (!is_split) {
332
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
333
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
334
+ // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
335
+ GmemTiledCopyO gmem_tiled_copy_O;
336
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
337
+ Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
338
+ // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
339
+ Tensor tOrO = make_fragment_like(tOsO);
340
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
341
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
342
+ cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
343
+ #pragma unroll
344
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
345
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
346
+ }
347
+ }
348
+ if constexpr (!PackGQA) {
349
+ // (BLK_M,BLK_K) -> (blk_m,blk_k)
350
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
351
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
352
+ #pragma unroll
353
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
354
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
355
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
356
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
357
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
358
+ );
359
+ } else {
360
+ // If PackGQA, we split the work of compute O_ptr among threads in the same row
361
+ PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
362
+ }
363
+ } else {
364
+ Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
365
+ Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
366
+ // We already arrived on barrier_O earlier if !Use_smem
367
+ if constexpr (Use_smem) {
368
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
369
+ #pragma unroll
370
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
371
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
372
+ }
373
+ }
374
+ }
375
+ if constexpr (!PackGQA) {
376
+ static constexpr int kGmemElemsPerStoreDirect = 2;
377
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
378
+ // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
379
+ Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
380
+ Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
381
+ Tensor tOgO = thread_mma.partition_C(gOpartial);
382
+ Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
383
+ Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
384
+ Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
385
+ #pragma unroll
386
+ for (int m = 0; m < size(taccOcO_row); ++m) {
387
+ if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
388
+ #pragma unroll
389
+ for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
390
+ if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
391
+ cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
392
+ }
393
+ }
394
+ }
395
+ }
396
+ } else {
397
+ PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
398
+ }
399
+ }
400
+ }
401
+ }
402
+
403
+ CUTLASS_DEVICE void
404
+ store_tail() {
405
+ // Don't need to do tma_store_wait<0>() here since we already did in @store
406
+ }
407
+
408
+ // Write 0 to output and -inf to LSE
409
+ CUTLASS_DEVICE void
410
+ store_zero(
411
+ Params const& params,
412
+ int thread_idx,
413
+ cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
414
+ ) {
415
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
416
+ auto [m_block, bidh, bidb, split_idx] = block_coord;
417
+ int num_splits = get<4>(params.shape_O_packed);
418
+ if constexpr (Split && Varlen) {
419
+ uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // upper 16 bits of split_idx encode num_splits
420
+ int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
421
+ num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
422
+ split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
423
+ }
424
+ bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
425
+
426
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
427
+ bool const is_varlen = Varlen && params.cu_seqlens;
428
+ int offset_o = seqlen_info.offset;
429
+ int seqlen_o = seqlen_info.seqlen;
430
+ int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
431
+ Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
432
+ params.shape_LSE_packed,
433
+ !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
434
+ Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
435
+
436
+ static_assert(kBlockM <= NumEpilogueThreads);
437
+ if (thread_idx < kBlockM) {
438
+ const int row = m_block * kBlockM + thread_idx;
439
+ if constexpr (!PackGQA) {
440
+ if (row < seqlen_o) { mLSE(row) = -INFINITY; }
441
+ } else {
442
+ if (row < seqlen_o * qhead_per_khead) {
443
+ int m_idx, h_idx;
444
+ m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
445
+ // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
446
+ mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
447
+ }
448
+ }
449
+ }
450
+
451
+ // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
452
+ // since it will not use the value of O if LSE is -inf.
453
+ if (!is_split) {
454
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
455
+
456
+ GmemTiledCopyO gmem_tiled_copy_O;
457
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
458
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
459
+ if constexpr (!PackGQA) {
460
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
461
+ #pragma unroll
462
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
463
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
464
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
465
+ Tensor tOrO = make_fragment_like(tOgO);
466
+ cute::clear(tOrO);
467
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
468
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
469
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
470
+ );
471
+ } else {
472
+ // If PackGQA, we split the work of compute O_ptr among threads in the same row
473
+ using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
474
+ Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
475
+ cute::clear(tOrO);
476
+ PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
477
+ }
478
+ }
479
+
480
+ }
481
+
482
+ };
483
+
484
+ } // namespace flash
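A small host-side sketch of the split_idx packing that store() and store_zero() decode above for the Split && Varlen case; the packing helper is an assumed inverse of that decode (in the actual kernels the scheduler produces the packed value).

#include <cassert>
#include <cstdint>

// Pack a dynamic num_splits into the upper 16 bits and the split index into
// the lower 16 bits, then decode them the same way the epilogue does.
static int pack_split_idx(int split_idx, int num_splits_dynamic) {
    uint32_t packed = (static_cast<uint32_t>(num_splits_dynamic) << 16)
                    | (static_cast<uint32_t>(split_idx) & 0xFFFFu);
    return static_cast<int>(packed);
}

int main() {
    int packed = pack_split_idx(/*split_idx=*/3, /*num_splits_dynamic=*/5);
    uint32_t u = static_cast<uint32_t>(packed);
    int num_splits_dynamic = static_cast<int>(u >> 16);   // upper 16 bits
    int split_idx = static_cast<int>(u & 0x0000FFFF);     // lower 16 bits
    assert(num_splits_dynamic == 5 && split_idx == 3);
    return 0;
}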
flash-attn/flash.h ADDED
@@ -0,0 +1,220 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda.h>
8
+ #include <vector>
9
+
10
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
11
+
12
+ struct Qkv_params {
13
+ using index_t = int64_t;
14
+ // The QKV matrices.
15
+ void *__restrict__ q_ptr;
16
+ void *__restrict__ k_ptr;
17
+ void *__restrict__ v_ptr;
18
+
19
+ // The stride between rows of the Q, K and V matrices.
20
+ index_t q_batch_stride;
21
+ index_t k_batch_stride;
22
+ index_t v_batch_stride;
23
+ index_t q_row_stride;
24
+ index_t k_row_stride;
25
+ index_t v_row_stride;
26
+ index_t q_head_stride;
27
+ index_t k_head_stride;
28
+ index_t v_head_stride;
29
+ index_t v_dim_stride;
30
+
31
+ // The number of heads.
32
+ int h, h_k;
33
+ };
34
+
35
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
36
+
37
+ struct Flash_fwd_params : public Qkv_params {
38
+ using index_t = int64_t;
39
+
40
+ // The O matrix (output).
41
+ void * __restrict__ o_ptr;
42
+ void * __restrict__ oaccum_ptr;
43
+
44
+ // The stride between rows of O.
45
+ index_t o_batch_stride;
46
+ index_t o_row_stride;
47
+ index_t o_head_stride;
48
+
49
+ // The pointer to the softmax sum.
50
+ void * __restrict__ softmax_lse_ptr;
51
+ void * __restrict__ softmax_lseaccum_ptr;
52
+
53
+ // For FP8 scaling
54
+ float * __restrict__ q_descale_ptr;
55
+ float * __restrict__ k_descale_ptr;
56
+ float * __restrict__ v_descale_ptr;
57
+ index_t q_descale_batch_stride;
58
+ index_t q_descale_head_stride;
59
+ index_t k_descale_batch_stride;
60
+ index_t k_descale_head_stride;
61
+ index_t v_descale_batch_stride;
62
+ index_t v_descale_head_stride;
63
+
64
+ // The dimensions.
65
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
66
+ int total_q, total_k, total_knew;
67
+ int b_k; // When using a KV cache with cache_batch_idx, K & V may have a larger batch size than Q
68
+ int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
69
+
70
+ // The scaling factors for the kernel.
71
+ float scale_softmax;
72
+ float softcap;
73
+
74
+ // array of length b+1 holding starting offset of each sequence.
75
+ int * __restrict__ cu_seqlens_q;
76
+ int * __restrict__ cu_seqlens_k;
77
+ int * __restrict__ cu_seqlens_knew;
78
+ int * __restrict__ leftpad_k;
79
+
80
+ // If provided, the actual length of each q/k sequence.
81
+ int *__restrict__ seqused_q;
82
+ int *__restrict__ seqused_k;
83
+
84
+ // The stride between rows of Oaccum.
85
+ index_t oaccum_split_stride;
86
+ index_t oaccum_batch_stride;
87
+ index_t oaccum_row_stride;
88
+ index_t oaccum_head_stride;
89
+
90
+ // The stride between rows of LSEaccum.
91
+ index_t lseaccum_split_stride;
92
+ index_t lseaccum_batch_stride;
93
+ index_t lseaccum_head_stride;
94
+
95
+ // The K_new and V_new matrices.
96
+ void * __restrict__ knew_ptr;
97
+ void * __restrict__ vnew_ptr;
98
+
99
+ // The stride between rows of the Q, K and V matrices.
100
+ index_t knew_batch_stride;
101
+ index_t vnew_batch_stride;
102
+ index_t knew_row_stride;
103
+ index_t vnew_row_stride;
104
+ index_t knew_head_stride;
105
+ index_t vnew_head_stride;
106
+
107
+ void *__restrict__ qv_ptr;
108
+ index_t qv_batch_stride;
109
+ index_t qv_row_stride;
110
+ index_t qv_head_stride;
111
+
112
+ // The cos and sin matrices for rotary embedding.
113
+ void * __restrict__ rotary_cos_ptr;
114
+ void * __restrict__ rotary_sin_ptr;
115
+ int *__restrict__ seqlens_rotary;
116
+
117
+ // The indices to index into the KV cache.
118
+ int * __restrict__ kv_batch_idx;
119
+
120
+ // Paged KV cache
121
+ int * __restrict__ page_table;
122
+ index_t page_table_batch_stride;
123
+ int page_size;
124
+ int num_pages;
125
+ bool pagedkv_tma;
126
+
127
+ // The dropout probability (probability of keeping an activation).
128
+ float p_dropout;
129
+ // uint32_t p_dropout_in_uint;
130
+ // uint16_t p_dropout_in_uint16_t;
131
+ uint8_t p_dropout_in_uint8_t;
132
+
133
+ // Scale factor of 1 / (1 - p_dropout).
134
+ float rp_dropout;
135
+
136
+ // Local window size
137
+ int window_size_left, window_size_right;
138
+
139
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
140
+ uint64_t * rng_state;
141
+
142
+ bool is_bf16;
143
+ bool is_fp32;
144
+ bool is_e4m3;
145
+ bool is_causal;
146
+ bool is_local;
147
+
148
+ bool is_rotary_interleaved;
149
+
150
+ int num_splits; // For split-KV version
151
+ bool pack_gqa;
152
+
153
+ int * __restrict__ tile_count_semaphore;
154
+ // int * __restrict__ num_m_blocks_ptr;
155
+ // int * __restrict__ num_n_blocks_ptr;
156
+ int * __restrict__ num_splits_dynamic_ptr;
157
+ bool skip_scheduler_metadata_computation;
158
+
159
+ int arch;
160
+ int num_sm;
161
+
162
+ // The S extra matrix, (num_heads)
163
+ void *__restrict__ s_aux_ptr;
164
+ };
165
+
166
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
167
+
168
+ struct Flash_bwd_params : public Flash_fwd_params {
169
+ using index_t = int64_t;
170
+
171
+ // The dO and dQKV matrices.
172
+ void *__restrict__ do_ptr;
173
+ void *__restrict__ dq_ptr;
174
+ void *__restrict__ dk_ptr;
175
+ void *__restrict__ dv_ptr;
176
+
177
+ // To accumulate dQ
178
+ void *__restrict__ dq_accum_ptr;
179
+ void *__restrict__ dk_accum_ptr;
180
+ void *__restrict__ dv_accum_ptr;
181
+
182
+ // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
184
+ // void *__restrict__ dk_accum_ptr;
185
+ // void *__restrict__ dv_accum_ptr;
185
+
186
+ // The stride between rows of the dO, dQ, dK and dV matrices.
187
+ index_t do_batch_stride;
188
+ index_t do_row_stride;
189
+ index_t do_head_stride;
190
+ index_t dq_batch_stride;
191
+ index_t dk_batch_stride;
192
+ index_t dv_batch_stride;
193
+ index_t dq_row_stride;
194
+ index_t dk_row_stride;
195
+ index_t dv_row_stride;
196
+ index_t dq_head_stride;
197
+ index_t dk_head_stride;
198
+ index_t dv_head_stride;
199
+
200
+ // The pointer to the softmax d sum.
201
+ void *__restrict__ dsoftmax_sum;
202
+ void *__restrict__ softmax_lse_log2_ptr;
203
+
204
+ int *__restrict__ dq_semaphore;
205
+ int *__restrict__ dk_semaphore;
206
+ int *__restrict__ dv_semaphore;
207
+
208
+ bool deterministic;
209
+ index_t dq_accum_split_stride;
210
+ };
211
+
212
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
213
+
214
+ template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
215
+ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
216
+ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
217
+ template <int Arch, typename T, int kHeadDim, bool Has_softcap>
218
+ void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
219
+ template <typename T, typename Tpartial, int kBlockK>
220
+ void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
flash-attn/flash_api.cpp ADDED
@@ -0,0 +1,1623 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
6
+ #include <torch/nn/functional.h>
7
+ #include <torch/version.h> // For TORCH_VERSION* macros
8
+ #include <ATen/cuda/CUDAContext.h>
9
+ #include <c10/cuda/CUDAGuard.h>
10
+
11
+ #include <cutlass/numeric_types.h>
12
+
13
+ #include "flash.h"
14
+ #include "static_switch.h"
15
+ #include "tile_size.h"
16
+ #include "heuristics.h"
17
+ #include "cuda_check.h"
18
+
19
+ // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
20
+ // This is so that we can pass in torch.dtype as a parameter to the function.
21
+ #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
22
+
23
+ #include <pybind11/pybind11.h>
24
+ #include <pybind11/stl.h>
25
+
26
+ namespace pybind11::detail {
27
+
28
+ template <>
29
+ struct type_caster<at::ScalarType> {
30
+ public:
31
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
32
+ PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
33
+ // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
34
+ // cannot be default-initialized, we provide this constructor to explicitly
35
+ // initialize that field. The value doesn't matter as it will be overwritten
36
+ // after a successful call to load.
37
+ type_caster() : value(at::kFloat) {}
38
+ bool load(handle src, bool) {
39
+ PyObject* obj = src.ptr();
40
+ if (THPDtype_Check(obj)) {
41
+ value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
42
+ return true;
43
+ }
44
+ return false;
45
+ }
46
+ static handle cast(
47
+ const at::ScalarType& src,
48
+ return_value_policy /* policy */,
49
+ handle /* parent */) {
50
+ return Py_NewRef(torch::getTHPDtype(src));
51
+ }
52
+ };
53
+
54
+ } // namespace pybind11::detail
55
+
56
+ #endif
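With the caster above registered (older torch builds only), bound functions can take torch.dtype arguments directly. A minimal usage sketch; the binding name is hypothetical:

    // Sketch: at::ScalarType now maps to/from torch.dtype across the pybind11 boundary.
    // m.def("is_bf16", [](at::ScalarType dt) { return dt == at::ScalarType::BFloat16; });
    // Python: is_bf16(torch.bfloat16)  ->  True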
57
+
58
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
59
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
60
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
61
+
62
+ void set_params_fprop(Flash_fwd_params &params,
63
+ // sizes
64
+ const size_t b,
65
+ const size_t seqlen_q,
66
+ const size_t seqlen_k,
67
+ const size_t seqlen_q_rounded,
68
+ const size_t seqlen_k_rounded,
69
+ const size_t h,
70
+ const size_t h_k,
71
+ const size_t d,
72
+ const size_t d_rounded,
73
+ // device pointers
74
+ const at::Tensor q,
75
+ const at::Tensor k,
76
+ const at::Tensor v,
77
+ at::Tensor out,
78
+ void *cu_seqlens_q_d,
79
+ void *cu_seqlens_k_d,
80
+ void *seqused_q,
81
+ void *seqused_k,
82
+ void *softmax_lse_d,
83
+ float p_dropout,
84
+ float softmax_scale,
85
+ int window_size_left,
86
+ int window_size_right,
87
+ const float softcap=0.f,
88
+ const int sm_margin=0) {
89
+
90
+ // Reset the parameters
91
+ params = {};
92
+
93
+ params.is_bf16 = q.dtype() == torch::kBFloat16;
94
+ params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
95
+
96
+ // Set the pointers and strides.
97
+ params.q_ptr = q.data_ptr();
98
+ params.k_ptr = k.data_ptr();
99
+ params.v_ptr = v.data_ptr();
100
+ // All strides are in elements, not bytes.
101
+ params.q_row_stride = q.stride(-3);
102
+ params.k_row_stride = k.stride(-3);
103
+ params.v_row_stride = v.stride(-3);
104
+ params.q_head_stride = q.stride(-2);
105
+ params.k_head_stride = k.stride(-2);
106
+ params.v_head_stride = v.stride(-2);
107
+ params.v_dim_stride = v.stride(-1);
108
+ params.o_ptr = out.data_ptr();
109
+ params.o_row_stride = out.stride(-3);
110
+ params.o_head_stride = out.stride(-2);
111
+
112
+ if (cu_seqlens_q_d == nullptr) {
113
+ params.q_batch_stride = q.stride(0);
114
+ params.o_batch_stride = out.stride(0);
115
+ }
116
+ if (cu_seqlens_k_d == nullptr) {
117
+ params.k_batch_stride = k.stride(0);
118
+ params.v_batch_stride = v.stride(0);
119
+ }
120
+
121
+ params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
122
+ params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
123
+ params.seqused_q = static_cast<int *>(seqused_q);
124
+ params.seqused_k = static_cast<int *>(seqused_k);
125
+
126
+ // Softmax sum
127
+ params.softmax_lse_ptr = softmax_lse_d;
128
+
129
+ // Set the dimensions.
130
+ params.b = b;
131
+ params.h = h;
132
+ params.h_k = h_k;
133
+ params.seqlen_q = seqlen_q;
134
+ params.seqlen_k = seqlen_k;
135
+ params.seqlen_q_rounded = seqlen_q_rounded;
136
+ params.seqlen_k_rounded = seqlen_k_rounded;
137
+ params.d = d;
138
+ params.d_rounded = d_rounded;
139
+
140
+ // Set the different scale values.
141
+ params.scale_softmax = softmax_scale;
142
+ params.softcap = softcap;
143
+
144
+ // Set this to probability of keeping an element to simplify things.
145
+ params.p_dropout = 1.f - p_dropout;
146
+ // Convert p from float to int so we don't have to convert the random uint to float to compare.
147
+ // [Minor] We want to round down since when we do the comparison we use <= instead of <
148
+ // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
149
+ // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
150
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
151
+ params.rp_dropout = 1.f / params.p_dropout;
152
+ TORCH_CHECK(p_dropout < 1.f);
153
+ #ifdef FLASHATTENTION_DISABLE_DROPOUT
154
+ TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
155
+ #endif
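The uint8 conversion above lets the kernel compare a raw random byte against a precomputed threshold instead of converting each random value to float. A minimal sketch with illustrative names (keep_prob, rand_byte):

    // Sketch: keep an activation when its random byte is <= the precomputed threshold.
    // floor() rounds the threshold down, which is conservative since the comparison is <=.
    uint8_t threshold = static_cast<uint8_t>(std::floor(keep_prob * 255.0));  // == p_dropout_in_uint8_t
    bool keep = rand_byte <= threshold;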
156
+
157
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
158
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
159
+ params.is_causal = window_size_left < 0 && window_size_right == 0;
160
+ params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
161
+
162
+ // TODO: check this
163
+ if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
164
+ if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
165
+ params.window_size_left = window_size_left;
166
+ params.window_size_right = window_size_right;
167
+
168
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
169
+ params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
170
+
171
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
172
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
173
+ #endif
174
+ }
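For reference, the (window_size_left, window_size_right) convention set up above corresponds to the visibility predicate sketched below; the actual masking lives in the kernels, and this assumes 0-indexed query row i and key column j aligned to the bottom-right of the score matrix:

    // Sketch only: key column j is visible to query row i when it lies inside the window.
    // window_right == 0 with window_left < 0 reproduces the causal mask.
    inline bool key_visible(int i, int j, int seqlen_q, int seqlen_k,
                            int window_left, int window_right) {
        int diag = i + seqlen_k - seqlen_q;  // key index on the causal diagonal for row i
        bool ok_right = window_right < 0 || j <= diag + window_right;
        bool ok_left  = window_left  < 0 || j >= diag - window_left;
        return ok_right && ok_left;
    }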
175
+
176
+ void set_params_dgrad(Flash_bwd_params &params,
177
+ // sizes
178
+ const size_t b,
179
+ const size_t seqlen_q,
180
+ const size_t seqlen_k,
181
+ const size_t seqlen_q_rounded,
182
+ const size_t seqlen_k_rounded,
183
+ const size_t h,
184
+ const size_t h_k,
185
+ const size_t d,
186
+ const size_t d_rounded,
187
+ // device pointers
188
+ const at::Tensor q,
189
+ const at::Tensor k,
190
+ const at::Tensor v,
191
+ const at::Tensor out,
192
+ const at::Tensor dout,
193
+ at::Tensor dq,
194
+ at::Tensor dk,
195
+ at::Tensor dv,
196
+ void *cu_seqlens_q_d,
197
+ void *cu_seqlens_k_d,
198
+ void *seqused_q,
199
+ void *seqused_k,
200
+ void *dq_accum_d,
201
+ void *dk_accum_d,
202
+ void *dv_accum_d,
203
+ void *softmax_lse_d,
204
+ void *dsoftmax_sum_d,
205
+ float p_dropout,
206
+ float softmax_scale,
207
+ int window_size_left,
208
+ int window_size_right,
209
+ const float softcap=0.f,
210
+ bool deterministic=false,
211
+ int const sm_margin=0) {
212
+
213
+ set_params_fprop(params,
214
+ b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
215
+ q, k, v, out,
216
+ cu_seqlens_q_d,
217
+ cu_seqlens_k_d,
218
+ seqused_q,
219
+ seqused_k,
220
+ softmax_lse_d,
221
+ p_dropout,
222
+ softmax_scale,
223
+ window_size_left,
224
+ window_size_right,
225
+ softcap,
226
+ sm_margin);
227
+
228
+ // Set the pointers and strides.
229
+ params.do_ptr = dout.data_ptr();
230
+ params.do_row_stride = dout.stride(-3);
231
+ params.do_head_stride = dout.stride(-2);
232
+ params.dq_ptr = dq.data_ptr();
233
+ params.dk_ptr = dk.data_ptr();
234
+ params.dv_ptr = dv.data_ptr();
235
+ params.dq_row_stride = dq.stride(-3);
236
+ params.dk_row_stride = dk.stride(-3);
237
+ params.dv_row_stride = dv.stride(-3);
238
+ params.dq_head_stride = dq.stride(-2);
239
+ params.dk_head_stride = dk.stride(-2);
240
+ params.dv_head_stride = dv.stride(-2);
241
+
242
+ if (cu_seqlens_q_d == nullptr) {
243
+ params.do_batch_stride = dout.stride(0);
244
+ params.dq_batch_stride = dq.stride(0);
245
+ params.dk_batch_stride = dk.stride(0);
246
+ params.dv_batch_stride = dv.stride(0);
247
+ }
248
+
249
+ params.dq_accum_ptr = dq_accum_d;
250
+ params.dk_accum_ptr = dk_accum_d;
251
+ params.dv_accum_ptr = dv_accum_d;
252
+
253
+ // Softmax sum
254
+ params.dsoftmax_sum = dsoftmax_sum_d;
255
+
256
+ params.deterministic = deterministic;
257
+ }
258
+
259
+ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
260
+ // HEADDIM_SWITCH(params.d, [&] {
261
+ // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
262
+ // });
263
+ TORCH_CHECK(params.num_splits >= 1);
264
+ ARCH_SWITCH(params.arch, Arch, [&] {
265
+ SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
266
+ PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
267
+ PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
268
+ // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation time and binary size
269
+ static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
270
+ SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
271
+ if (!params.is_e4m3) {
272
+ if (params.is_bf16) {
273
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
274
+ if (params.d <= 64) {
275
+ if (params.dv > 256 && Arch == 90) {
276
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
277
+ } else if (params.dv > 64 && Arch == 90) {
278
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
279
+ } else {
280
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
281
+ }
282
+ }
283
+ #endif
284
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
285
+ if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
286
+ #endif
287
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
288
+ if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
289
+ #endif
290
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
291
+ if (params.d <= 192) {
292
+ if (params.dv <= 128 && Arch == 90) {
293
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
294
+ } else {
295
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
296
+ }
297
+ }
298
+ #endif
299
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
300
+ if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
301
+ #endif
302
+ } else {
303
+ #ifndef FLASHATTENTION_DISABLE_FP16
304
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
305
+ if (params.d <= 64) {
306
+ if (params.dv > 256 && Arch == 90) {
307
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
308
+ } else if (params.dv > 64 && Arch == 90) {
309
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310
+ } else {
311
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
312
+ }
313
+ }
314
+ #endif
315
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
316
+ if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
317
+ #endif
318
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
319
+ if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
320
+ #endif
321
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
322
+ if (params.d <= 192) {
323
+ if (params.dv <= 128 && Arch == 90) {
324
+ return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
325
+ } else {
326
+ return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
327
+ }
328
+ }
329
+ #endif
330
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
331
+ if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
332
+ #endif
333
+ #else
334
+ TORCH_CHECK(false, "This flash attention build does not support FP16.");
335
+ #endif
336
+ }
337
+ } else {
338
+ #ifndef FLASHATTENTION_DISABLE_FP8
339
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
340
+ if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
341
+ #endif
342
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
343
+ if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
344
+ #endif
345
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
346
+ if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
347
+ #endif
348
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
349
+ if (params.d <= 192) {
350
+ if (params.dv <= 128 && Arch == 90) {
351
+ return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
352
+ } else {
353
+ return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
354
+ }
355
+ }
356
+ #endif
357
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
358
+ if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
359
+ #endif
360
+ #else
361
+ TORCH_CHECK(false, "This flash attention build does not support FP8.");
362
+ #endif
363
+ }
364
+ });
365
+ });
366
+ });
367
+ });
368
+ });
369
+ }
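Each leaf of the nested switches above resolves the runtime flags into one compile-time specialization of run_mha_fwd_. A hedged sketch of a single leaf, with illustrative values (bf16, Q/K/V headdim 128 on Hopper, every optional feature off):

    // Sketch: the kind of call the dispatch bottoms out in.
    // run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128,
    //              /*Split=*/false, /*PagedKVNonTMA=*/false,
    //              /*Has_softcap=*/false, /*PackGQA=*/false>(params, stream);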
370
+
371
+ void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
372
+ #ifndef FLASHATTENTION_DISABLE_SPLIT
373
+ // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
374
+ // so that kBlockM is smaller and we have more parallelism.
375
+ if (params.is_fp32) {
376
+ if (params.dv <= 64) {
377
+ run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
378
+ } else {
379
+ run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
380
+ }
381
+ } else if (params.is_bf16) {
382
+ if (params.dv <= 64) {
383
+ run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
384
+ } else {
385
+ run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
386
+ }
387
+ } else {
388
+ if (params.dv <= 64) {
389
+ run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
390
+ } else {
391
+ run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
392
+ }
393
+ }
394
+ #else
395
+ TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
396
+ #endif
397
+ }
398
+
399
+ inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
400
+ if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
401
+ // This needs to match the kernel configs
402
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
403
+ int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
404
+ int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
405
+ // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
406
+ // at least for MLA.
407
+ return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
408
+ }
409
+
410
+ inline bool get_pack_gqa(Flash_fwd_params const& params) {
411
+ // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
412
+ // Has little effect on speed.
413
+ if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
414
+ #ifdef FLASHATTENTION_DISABLE_PACKGQA
415
+ return false;
416
+ #else
417
+ // params.page_table must already be set
418
+ if (params.h == params.h_k) { return false; }
419
+ // This needs to match the kernel configs
420
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
421
+ int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
422
+ return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
423
+ #endif
424
+ }
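PackGQA folds the h / h_k query heads that share a KV head into the M (row) dimension, which mainly pays off when seqlen_q alone would underfill a kBlockM tile. A worked example with illustrative numbers:

    // Sketch: decode with seqlen_q = 1, h = 32, h_k = 4.
    // qhead_per_khead = 32 / 4 = 8  ->  packed rows per batch = 1 * 8 = 8,
    // which fills a kBlockM-sized tile far better than 1 row per (batch, query-head) tile.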
425
+
426
+ inline int get_num_splits(Flash_fwd_params const& params) {
427
+ #ifdef FLASHATTENTION_DISABLE_SPLIT
428
+ return 1;
429
+ #else
430
+ // Always enable PackGQA for Split
431
+ // params.page_table must already be set
432
+ // This needs to match the kernel configs
433
+ bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
434
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
435
+ // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
436
+ // has not been set here. It's OK though because we might just underestimate kBlockN a bit
437
+ auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
438
+ int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
439
+ int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
440
+ int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
441
+ // If is_local, we're not going to load all of seqlen_k
442
+ int const seqlen_k_loaded = !params.is_local
443
+ ? params.seqlen_k
444
+ : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
445
+ int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
446
+ int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
447
+ int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
448
+ // Always enable PackGQA for Split
449
+ // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
450
+ // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
451
+ // that batch = 1.
452
+ int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
453
+ return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
454
+ #endif
455
+ }
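The heuristic above is driven by plain tile counts; the block numbers are ceiling divisions of the (GQA-packed) sequence lengths by the tile sizes. A minimal sketch with illustrative names:

    // Sketch: inputs to num_splits_heuristic are just ceil-divided tile counts.
    auto cdiv = [](int a, int b) { return (a + b - 1) / b; };
    int num_m_blocks = cdiv(seqlen_q * qhead_per_khead, kBlockM);  // PackGQA folds query heads into M
    int num_n_blocks = cdiv(seqlen_k_loaded, kBlockN);             // only the windowed K range if local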
456
+
457
+ inline int get_max_headdim() {
458
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
459
+ return 256;
460
+ #endif
461
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
462
+ return 192;
463
+ #endif
464
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
465
+ return 128;
466
+ #endif
467
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
468
+ return 96;
469
+ #endif
470
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
471
+ return 64;
472
+ #endif
473
+ return 0;
474
+ }
475
+
476
+ inline int round_up_headdim(int head_size) {
477
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
478
+ if (head_size <= 64) { return 64; }
479
+ #endif
480
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
481
+ if (head_size <= 96) { return 96; }
482
+ #endif
483
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
484
+ if (head_size <= 128) { return 128; }
485
+ #endif
486
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
487
+ if (head_size <= 192) { return 192; }
488
+ #endif
489
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
490
+ if (head_size <= 256) { return 256; }
491
+ #endif
492
+ return 256;
493
+ }
494
+
495
+ inline int round_up_headdimv(int head_size) {
496
+ if (head_size <= 64) { return 64; }
497
+ if (head_size <= 96) { return 96; }
498
+ if (head_size <= 128) { return 128; }
499
+ if (head_size <= 192) { return 192; }
500
+ if (head_size <= 256) { return 256; }
501
+ return 512;
502
+ }
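A few expected values for the rounding helpers above, assuming a full build with no FLASHATTENTION_DISABLE_HDIM* flags set:

    // Sketch: head dims round up to the next supported bucket; V additionally has a 512 bucket.
    assert(round_up_headdim(80)   == 96);
    assert(round_up_headdim(130)  == 192);
    assert(round_up_headdimv(288) == 512);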
503
+
504
+ // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
505
+ at::Tensor
506
+ mha_fwd_get_scheduler_metadata(
507
+ int batch_size,
508
+ int max_seqlen_q,
509
+ int max_seqlen_k,
510
+ int num_heads,
511
+ int num_heads_k,
512
+ int headdim,
513
+ int headdim_v,
514
+ at::ScalarType qkv_dtype,
515
+ const at::Tensor &seqused_k, // b
516
+ std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
517
+ std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
518
+ std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
519
+ std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
520
+ std::optional<const at::Tensor> &leftpad_k_, // b
521
+ std::optional<int> page_size,
522
+ int max_seqlen_k_new, // 0 means we're not appending new KV
523
+ bool is_causal,
524
+ int window_size_left,
525
+ int window_size_right,
526
+ bool has_softcap,
527
+ int num_splits,
528
+ std::optional<bool> pack_gqa_,
529
+ int const sm_margin
530
+ ) {
531
+
532
+ TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
533
+ "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
534
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
535
+
536
+ // Reset the parameters
537
+ Flash_fwd_params params{};
538
+ params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
539
+ params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
540
+ params.b = batch_size;
541
+ params.seqlen_q = max_seqlen_q;
542
+ params.seqlen_k = max_seqlen_k;
543
+ params.h = num_heads;
544
+ params.h_k = num_heads_k;
545
+ params.d = headdim;
546
+ params.dv = headdim_v;
547
+ params.d_rounded = round_up_headdim(headdim);
548
+ params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
549
+ params.seqlen_knew = max_seqlen_k_new;
550
+
551
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
552
+ params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
553
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
554
+ params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
555
+ params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
556
+ params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
557
+ params.seqused_k = seqused_k.data_ptr<int>();
558
+ params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
559
+ params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
560
+ if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
561
+ if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
562
+ // causal=true is the same as causal=false in this case
563
+ if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
564
+ // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
565
+ if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
566
+ is_causal = false;
567
+ }
568
+ }
569
+ if (is_causal) { window_size_right = 0; }
570
+
571
+ params.is_causal = window_size_left < 0 && window_size_right == 0;
572
+ params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
573
+ if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; }
574
+ if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
575
+ params.window_size_left = window_size_left;
576
+ params.window_size_right = window_size_right;
577
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
578
+ params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
579
+ params.softcap = has_softcap ? 1.0f : 0.0f;
580
+
581
+ params.page_size = page_size.has_value() ? page_size.value() : 1;
582
+ params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
583
+
584
+ bool const use_dynamic_split = params.b <= 992;
585
+ params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
586
+
587
+ params.pagedkv_tma = get_pagedkv_tma(params);
588
+ // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
589
+ params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
590
+ params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
591
+ // Always enable PackGQA for Split
592
+ params.pack_gqa = params.num_splits > 1;
593
+
594
+ bool is_varlen = true;
595
+
596
+ // Otherwise the kernel will be launched from cuda:0 device
597
+ // Cast to char to avoid compiler warning about narrowing
598
+ at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
599
+
600
+ auto opts = seqused_k.options();
601
+ // This needs to be set after get_num_splits
602
+ at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
603
+ bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
604
+ if (scheduler_needs_semaphore || use_dynamic_split) {
605
+ tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
606
+ if (scheduler_needs_semaphore) {
607
+ if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
608
+ params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
609
+ } else {
610
+ params.tile_count_semaphore = nullptr;
611
+ }
612
+ params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
613
+ }
614
+
615
+ if (params.num_splits_dynamic_ptr) {
616
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
617
+ auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
618
+ int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
619
+ int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
620
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
621
+ prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
622
+ CHECK_CUDA_KERNEL_LAUNCH();
623
+ }
624
+ return tile_count_semaphore;
625
+ }
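The returned tensor packs the scheduler state densely; a sketch of its int32 layout as the pointer arithmetic above assumes:

    // Sketch: layout of the scheduler-metadata tensor (all int32).
    //   [0]       tile-count semaphore          (used when scheduler_needs_semaphore)
    //   [1 .. b]  per-batch dynamic num_splits  (used when use_dynamic_split, i.e. b <= 992)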
626
+
627
+ // b: batch_size
628
+ // b_k: batch_size_k
629
+ // s_q: seqlen_q
630
+ // s_k: seqlen_k
631
+ // s_k_new: seqlen_k_new
632
+ // h: num_heads
633
+ // h_k: num_heads_k
634
+ // d: head_size
635
+ std::vector<at::Tensor>
636
+ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
637
+ const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
638
+ const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
639
+ std::optional<const at::Tensor> &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
640
+ std::optional<const at::Tensor> &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
641
+ std::optional<const at::Tensor> &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
642
+ std::optional<at::Tensor> &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
643
+ std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
644
+ std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
645
+ std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
646
+ std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
647
+ std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
648
+ std::optional<int> max_seqlen_q_,
649
+ // TODO: check if we need max_seqlen_k
650
+ std::optional<int> max_seqlen_k_,
651
+ std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
652
+ std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
653
+ std::optional<const at::Tensor> &leftpad_k_, // b
654
+ std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
655
+ std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
656
+ std::optional<const at::Tensor> &seqlens_rotary_, // b
657
+ std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
658
+ std::optional<at::Tensor> &k_descale_, // (b, h_k)
659
+ std::optional<at::Tensor> &v_descale_, // (b, h_k)
660
+ float const softmax_scale,
661
+ bool is_causal,
662
+ int window_size_left,
663
+ int window_size_right,
664
+ float const softcap,
665
+ bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
666
+ std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
667
+ int num_splits,
668
+ std::optional<bool> pack_gqa_,
669
+ int const sm_margin,
670
+ std::optional<const at::Tensor> &s_aux_ // (h)
671
+ ) {
672
+
673
+ auto dprops = at::cuda::getCurrentDeviceProperties();
674
+ bool is_sm8x = dprops->major >= 8;
675
+ TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
676
+
677
+ auto q_type = q.scalar_type();
678
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
679
+ "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
680
+ if (dprops->major < 9) {
681
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
682
+ "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
683
+ }
684
+ TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
685
+ TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
686
+
687
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
688
+
689
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
690
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
691
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
692
+
693
+ at::Tensor page_table;
694
+ const bool paged_KV = page_table_.has_value();
695
+ if (paged_KV) {
696
+ page_table = page_table_.value();
697
+ CHECK_DEVICE(page_table);
698
+ TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
699
+ TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
700
+ }
701
+
702
+ at::Tensor cu_seqlens_q;
703
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
704
+ if (is_varlen_q) {
705
+ cu_seqlens_q = cu_seqlens_q_.value();
706
+ CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
707
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
708
+ TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
709
+ }
710
+ at::Tensor cu_seqlens_k;
711
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
712
+ if (is_varlen_k) {
713
+ cu_seqlens_k = cu_seqlens_k_.value();
714
+ CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
715
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
716
+ TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
717
+ TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
718
+ TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
719
+ }
720
+
721
+ auto const sizes = q.sizes();
722
+ const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
723
+ int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
724
+ int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
725
+ int num_heads = q.size(-2);
726
+ int const head_size = q.size(-1);
727
+ int const head_size_v = v.size(-1);
728
+ int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
729
+ int const num_pages = !paged_KV ? 0 : k.size(0);
730
+ int const page_size = !paged_KV ? 1 : k.size(1);
731
+ int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
732
+ int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
733
+ int const num_heads_k = k.size(-2);
734
+ int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
735
+ if (!kv_batch_idx_.has_value()) {
736
+ TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
737
+ }
738
+ int const max_headdim = get_max_headdim();
739
+ TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
740
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
741
+ if (head_size_v != head_size) {
742
+ TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
743
+ (head_size <= 64 && head_size_v <= 512),
744
+ "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
745
+ "or (Q/K <= 64 and V <= 512).");
746
+ TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
747
+ if (head_size_v > 256) {
748
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
749
+ "HeaddimV > 256 requires fp16 and bf16 data type");
750
+ }
751
+ }
752
+
753
+ // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
754
+ // TODO: check this
755
+ if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
756
+ if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
757
+ // causal=true is the same as causal=false in this case
758
+ if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
759
+ // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
760
+ if ((head_size <= 64 || head_size > 128) || !paged_KV) {
761
+ is_causal = false;
762
+ }
763
+ }
764
+ if (is_causal) { window_size_right = 0; }
765
+ // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
766
+ // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
767
+ is_causal = window_size_left < 0 && window_size_right == 0;
768
+
769
+ if (!is_varlen_q) {
770
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
771
+ } else {
772
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
773
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
774
+ }
775
+ if (!paged_KV) {
776
+ if (!is_varlen_k) {
777
+ CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
778
+ CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
779
+ } else {
780
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
781
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
782
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
783
+ }
784
+ } else {
785
+ CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
786
+ CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
787
+ CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
788
+ }
789
+
790
+ if (seqused_q_.has_value()){
791
+ auto seqused_q = seqused_q_.value();
792
+ TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
793
+ CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
794
+ CHECK_SHAPE(seqused_q, batch_size);
795
+ }
796
+ if (seqused_k_.has_value()) {
797
+ auto seqused_k = seqused_k_.value();
798
+ TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
799
+ CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
800
+ CHECK_SHAPE(seqused_k, batch_size);
801
+ }
802
+
803
+ if (leftpad_k_.has_value()) {
804
+ auto leftpad_k = leftpad_k_.value();
805
+ TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
806
+ CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
807
+ CHECK_SHAPE(leftpad_k, batch_size);
808
+ }
809
+
810
+ // This is what we will template on
811
+ bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
812
+ #ifdef FLASHATTENTION_DISABLE_VARLEN
813
+ TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
814
+ #endif
815
+
816
+ int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
817
+ TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
818
+ TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
819
+
820
+ auto opts = q.options();
821
+ auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
822
+ at::Tensor out;
823
+ if (out_.has_value()) {
824
+ out = out_.value();
825
+ TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
826
+ CHECK_DEVICE(out);
827
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
828
+ if (!is_varlen_q) {
829
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
830
+ } else {
831
+ CHECK_SHAPE(out, total_q, num_heads, head_size_v);
832
+ }
833
+ } else {
834
+ out = !is_varlen_q
835
+ ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
836
+ : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
837
+ }
838
+
839
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
840
+ int const head_size_rounded = round_up_headdim(head_size);
841
+ int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
842
+ int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
843
+ int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
844
+
845
+ // Otherwise the kernel will be launched from cuda:0 device
846
+ // Cast to char to avoid compiler warning about narrowing
847
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
848
+
849
+ at::Tensor softmax_lse;
850
+ if (!is_varlen_q) {
851
+ softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
852
+ } else {
853
+ softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
854
+ }
855
+
856
+ Flash_fwd_params params;
857
+ set_params_fprop(params,
858
+ batch_size,
859
+ seqlen_q, seqlen_k,
860
+ seqlen_q_rounded, seqlen_k_rounded,
861
+ num_heads, num_heads_k,
862
+ head_size, head_size_rounded,
863
+ q, k, v, out,
864
+ !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
865
+ !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
866
+ seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
867
+ seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
868
+ softmax_lse.data_ptr(),
869
+ /*p_dropout=*/0.f,
870
+ softmax_scale,
871
+ window_size_left,
872
+ window_size_right,
873
+ softcap,
874
+ sm_margin);
875
+ params.total_q = total_q;
876
+ params.total_k = total_k;
877
+ params.b_k = batch_size_k;
878
+ params.dv = head_size_v;
879
+ params.dv_rounded = head_size_v_rounded;
880
+ if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
881
+ params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
882
+ }
883
+ if (paged_KV) {
884
+ params.page_table = page_table.data_ptr<int>();
885
+ params.page_table_batch_stride = page_table.stride(0);
886
+ }
887
+ params.page_size = page_size;
888
+ params.num_pages = num_pages;
889
+
890
+ if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
891
+ at::Tensor k_new, v_new;
892
+ TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
893
+ TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
894
+ TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
895
+ at::Tensor cu_seqlens_k_new;
896
+ bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
897
+ if (is_varlen_k_new) {
898
+ cu_seqlens_k_new = cu_seqlens_k_new_.value();
899
+ CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
900
+ TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
901
+ }
902
+ k_new = k_new_.value();
903
+ v_new = v_new_.value();
904
+ TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
905
+ TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
906
+ CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
907
+ TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
908
+ TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
909
+ // We don't need max_seqlen_k_new, so seqlen_k_new can be arbitrary when is_varlen_k_new is set
910
+ int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
911
+ int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
912
+ if (!is_varlen_k_new) {
913
+ CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
914
+ CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
915
+ } else {
916
+ CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
917
+ CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
918
+ CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
919
+ }
920
+ params.seqlen_knew = seqlen_k_new;
921
+ params.total_knew = total_k_new;
922
+ params.knew_ptr = k_new.data_ptr();
923
+ params.vnew_ptr = v_new.data_ptr();
924
+ // All strides are in elements, not bytes.
925
+ params.knew_row_stride = k_new.stride(-3);
926
+ params.vnew_row_stride = v_new.stride(-3);
927
+ params.knew_head_stride = k_new.stride(-2);
928
+ params.vnew_head_stride = v_new.stride(-2);
929
+ if (!is_varlen_k_new) {
930
+ params.knew_batch_stride = k_new.stride(0);
931
+ params.vnew_batch_stride = v_new.stride(0);
932
+ }
933
+ if (is_varlen_k_new) {
934
+ params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
935
+ }
936
+ }
937
+
938
+ // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
939
+ bool const use_dynamic_split = is_varlen && params.b <= 992;
940
+ // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
941
+ params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
942
+
943
+ params.pagedkv_tma = get_pagedkv_tma(params);
944
+ // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
945
+ params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
946
+ params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
947
+ // Always enable PackGQA for Split
948
+ params.pack_gqa = params.num_splits > 1;
949
+
950
+ // This needs to be set after get_num_splits
951
+ at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
952
+ // We don't use the persistent scheduler if Split and not Varlen
953
+ bool const scheduler_needs_semaphore = params.arch >= 90
954
+ ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
955
+ : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
956
+ if (scheduler_needs_semaphore || use_dynamic_split) {
957
+ int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
958
+ params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
959
+ if (scheduler_metadata_.has_value()) {
960
+ at::Tensor scheduler_metadata = scheduler_metadata_.value();
961
+ CHECK_DEVICE(scheduler_metadata);
962
+ CHECK_SHAPE(scheduler_metadata, metadata_size);
963
+ CHECK_CONTIGUOUS(scheduler_metadata);
964
+ TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
965
+ tile_count_semaphore = scheduler_metadata;
966
+ } else {
967
+ tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
968
+ }
969
+ if (scheduler_needs_semaphore && !use_dynamic_split) {
970
+ tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
971
+ }
972
+ params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
973
+ params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
974
+ }
975
+
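For orientation, the int32 scheduler-metadata buffer set up above packs the tile-count semaphore and the per-batch dynamic split counts back to back: element 0 is the semaphore and elements 1..b hold one split count per batch entry, which is how the two pointers are derived from the same tensor. A minimal host-side sketch of that layout, assuming both features are enabled (SchedulerMetadataView is an illustrative name, not from this codebase):

// A minimal sketch, assuming metadata_size = 1 + batch_size (semaphore present, dynamic splits enabled).
#include <cassert>
#include <vector>

struct SchedulerMetadataView {                      // illustrative helper, not part of this codebase
    int* base;                                      // tile_count_semaphore.data_ptr<int>()
    int  batch_size;                                // params.b
    int* semaphore()          { return base; }      // element 0: tile-count semaphore
    int* num_splits_dynamic() { return base + 1; }  // elements 1..batch_size: per-batch split counts
};

int main() {
    int const batch_size = 4;
    std::vector<int> metadata(1 + batch_size, 0);   // mirrors torch::empty({metadata_size}, kInt32)
    SchedulerMetadataView view{metadata.data(), batch_size};
    assert(view.num_splits_dynamic() == view.semaphore() + 1);  // matches the pointer setup above
    return 0;
}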
976
+ if (q_v_.has_value()) {
977
+ TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
978
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
979
+ "q_v is only supported for fp16 and bf16 data type");
980
+ TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
981
+ at::Tensor q_v = q_v_.value();
982
+ TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
983
+ CHECK_DEVICE(q_v);
984
+ TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
985
+ if (!is_varlen_q) {
986
+ CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
987
+ } else {
988
+ CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
989
+ }
990
+ params.qv_ptr = q_v.data_ptr();
991
+ // All strides are in elements, not bytes.
992
+ params.qv_row_stride = q_v.stride(-3);
993
+ params.qv_head_stride = q_v.stride(-2);
994
+ if (!is_varlen_q) {
995
+ params.qv_batch_stride = q_v.stride(0);
996
+ }
997
+ }
998
+
999
+ if (rotary_cos_.has_value()) {
1000
+ TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1001
+ auto rotary_cos = rotary_cos_.value();
1002
+ CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
1003
+ params.rotary_dim = rotary_cos.size(1) * 2;
1004
+ TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1005
+ TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1006
+ const int seqlen_ro = rotary_cos.size(0);
1007
+ if (paged_KV) {
1008
+ TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1009
+ }
1010
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1011
+ TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1012
+
1013
+ TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1014
+ auto rotary_sin = rotary_sin_.value();
1015
+ CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
1016
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1017
+ TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
1018
+ params.rotary_cos_ptr = rotary_cos.data_ptr();
1019
+ params.rotary_sin_ptr = rotary_sin.data_ptr();
1020
+ params.is_rotary_interleaved = is_rotary_interleaved;
1021
+ if (seqlens_rotary_.has_value()) {
1022
+ at::Tensor seqlens_rotary = seqlens_rotary_.value();
1023
+ CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
1024
+ TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
1025
+ CHECK_SHAPE(seqlens_rotary, batch_size);
1026
+ params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
1027
+ }
1028
+ } else {
1029
+ params.rotary_dim = 0;
1030
+ }
1031
+
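The rotary block above only records rotary_dim, the cos/sin tables, and the is_rotary_interleaved flag; the actual pairing of elements happens inside the kernels. As a hedged reference for what that flag conventionally means (non-interleaved pairs element i with i + rotary_dim/2, interleaved pairs 2i with 2i+1), here is a small CPU sketch; apply_rotary is an illustrative name, not a function from this codebase:

#include <vector>

// Illustrative helper: applies rotary embedding to the first rotary_dim elements of x at one
// position; cos_v/sin_v hold rotary_dim / 2 values for that position.
void apply_rotary(std::vector<float>& x, std::vector<float> const& cos_v,
                  std::vector<float> const& sin_v, int rotary_dim, bool interleaved) {
    for (int i = 0; i < rotary_dim / 2; ++i) {
        // Non-interleaved pairs (i, i + rotary_dim/2); interleaved pairs (2i, 2i + 1).
        int const i0 = interleaved ? 2 * i : i;
        int const i1 = interleaved ? 2 * i + 1 : i + rotary_dim / 2;
        float const x0 = x[i0], x1 = x[i1];
        x[i0] = x0 * cos_v[i] - x1 * sin_v[i];
        x[i1] = x0 * sin_v[i] + x1 * cos_v[i];
    }
}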
1032
+ if (kv_batch_idx_.has_value()) {
1033
+ auto kv_batch_idx = kv_batch_idx_.value();
1034
+ CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
1035
+ TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
1036
+ params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
1037
+ }
1038
+
1039
+ at::Tensor out_accum, softmax_lse_accum;
1040
+ auto outaccum_type = at::ScalarType::Float;
1041
+ if (params.num_splits > 1) {
1042
+ TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
1043
+ if (!is_varlen_q) {
1044
+ out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
1045
+ softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1046
+ params.oaccum_batch_stride = out_accum.stride(1);
1047
+ params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
1048
+ } else {
1049
+ out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
1050
+ softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
1051
+ }
1052
+ params.is_fp32 = false;
1053
+ params.oaccum_ptr = out_accum.data_ptr();
1054
+ params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
1055
+ params.oaccum_split_stride = out_accum.stride(0);
1056
+ params.oaccum_row_stride = out_accum.stride(-2);
1057
+ params.oaccum_head_stride = out_accum.stride(-3);
1058
+ params.lseaccum_split_stride = softmax_lse_accum.stride(0);
1059
+ params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
1060
+ }
1061
+
1062
+ if (q_type == at::ScalarType::Float8_e4m3fn) {
1063
+ if (q_descale_.has_value()) {
1064
+ auto q_descale = q_descale_.value();
1065
+ CHECK_DEVICE(q_descale);
1066
+ CHECK_SHAPE(q_descale, batch_size, num_heads_k);
1067
+ params.q_descale_ptr = q_descale.data_ptr<float>();
1068
+ params.q_descale_batch_stride = q_descale.stride(0);
1069
+ params.q_descale_head_stride = q_descale.stride(1);
1070
+ } else {
1071
+ params.q_descale_ptr = nullptr;
1072
+ }
1073
+ if (k_descale_.has_value()) {
1074
+ auto k_descale = k_descale_.value();
1075
+ CHECK_DEVICE(k_descale);
1076
+ CHECK_SHAPE(k_descale, batch_size, num_heads_k);
1077
+ params.k_descale_ptr = k_descale.data_ptr<float>();
1078
+ params.k_descale_batch_stride = k_descale.stride(0);
1079
+ params.k_descale_head_stride = k_descale.stride(1);
1080
+ } else {
1081
+ params.k_descale_ptr = nullptr;
1082
+ }
1083
+ if (v_descale_.has_value()) {
1084
+ auto v_descale = v_descale_.value();
1085
+ CHECK_DEVICE(v_descale);
1086
+ CHECK_SHAPE(v_descale, batch_size, num_heads_k);
1087
+ params.v_descale_ptr = v_descale.data_ptr<float>();
1088
+ params.v_descale_batch_stride = v_descale.stride(0);
1089
+ params.v_descale_head_stride = v_descale.stride(1);
1090
+ } else {
1091
+ params.v_descale_ptr = nullptr;
1092
+ }
1093
+ }
1094
+
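The descale tensors above are only validated and recorded here. Under the usual FP8 convention, which this sketch assumes rather than asserts for this kernel, real-valued operands are recovered as fp8_value * descale, so the attention logits pick up a factor of q_descale * k_descale and the PV product a factor of v_descale. A scalar illustration only (descaled_logit is an illustrative helper, not part of this API):

// Illustrative helper only; softmax_scale plays the role of params.scale_softmax.
float descaled_logit(float qk_dot_fp8, float q_descale, float k_descale, float softmax_scale) {
    return qk_dot_fp8 * q_descale * k_descale * softmax_scale;
}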
1095
+ if(s_aux_.has_value()) {
1096
+ auto s_aux = s_aux_.value();
1097
+ TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16,
1098
+ "We only support bf16 dtype for S extra.");
1099
+ CHECK_DEVICE(s_aux);
1100
+ CHECK_SHAPE(s_aux, num_heads);
1101
+ CHECK_CONTIGUOUS(s_aux);
1102
+ params.s_aux_ptr = s_aux.data_ptr();
1103
+ } else {
1104
+ params.s_aux_ptr = nullptr;
1105
+ }
1106
+
1107
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
1108
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1109
+ #endif
1110
+ #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1111
+ TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1112
+ #endif
1113
+ #ifdef FLASHATTENTION_DISABLE_SPLIT
1114
+ TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
1115
+ #endif
1116
+ #ifdef FLASHATTENTION_DISABLE_PACKGQA
1117
+ TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
1118
+ #endif
1119
+ #ifdef FLASHATTENTION_DISABLE_PAGEDKV
1120
+ TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
1121
+ #endif
1122
+ #ifdef FLASHATTENTION_DISABLE_APPENDKV
1123
+ TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
1124
+ #endif
1125
+
1126
+ if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
1127
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1128
+ run_mha_fwd(params, stream);
1129
+ if (params.num_splits > 1) {
1130
+ if (out_type == at::ScalarType::BFloat16) {
1131
+ // Since we want the output in BF16; otherwise fwd_combine will output FP16.
1132
+ params.is_bf16 = true;
1133
+ }
1134
+ // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
1135
+ // and seqlen = total_q, and don't need to dispatch to Varlen there.
1136
+ // However, with dynamic split, each row needs to know which batch it belongs to
1137
+ // to read the number of splits, so we just use the varlen version of combine kernel.
1138
+ // if (is_varlen_q && !seqused_q_.has_value()) {
1139
+ // if (is_varlen_q) {
1140
+ // params.b = 1;
1141
+ // params.seqlen_q = total_q;
1142
+ // }
1143
+ // This will zero out the semaphore if needed
1144
+ run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
1145
+ } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
1146
+ // need to zero out the semaphore in this case
1147
+ tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
1148
+ }
1149
+ } else if (total_q > 0 && num_heads_k > 0) {
1150
+ // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
1151
+ out.zero_();
1152
+ softmax_lse.fill_(std::numeric_limits<float>::infinity());
1153
+ }
1154
+
1155
+ // return {out, softmax_lse};
1156
+ return {out, softmax_lse, out_accum, softmax_lse_accum};
1157
+ }
1158
+
1159
+ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1160
+ #ifndef FLASHATTENTION_DISABLE_BACKWARD
1161
+ // FP16_SWITCH(!params.is_bf16, [&] {
1162
+ // HEADDIM_SWITCH(params.d, [&] {
1163
+ // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
1164
+ // });
1165
+ // });
1166
+ ARCH_SWITCH(params.arch, Arch, [&] {
1167
+ SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
1168
+ if (!params.is_bf16) {
1169
+ #ifndef FLASHATTENTION_DISABLE_FP16
1170
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
1171
+ if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
1172
+ #endif
1173
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
1174
+ if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
1175
+ #endif
1176
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
1177
+ if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
1178
+ #endif
1179
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
1180
+ if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
1181
+ #endif
1182
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
1183
+ if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
1184
+ #endif
1185
+ #else
1186
+ TORCH_CHECK(false, "This flash attention build does not support FP16.");
1187
+ #endif
1188
+ } else {
1189
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
1190
+ if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
1191
+ #endif
1192
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
1193
+ if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
1194
+ #endif
1195
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
1196
+ if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
1197
+ #endif
1198
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
1199
+ if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
1200
+ #endif
1201
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
1202
+ if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
1203
+ #endif
1204
+ }
1205
+ });
1206
+ });
1207
+ #endif
1208
+ }
1209
+
1210
+
1211
+ // b: batch_size
1212
+ // s_q: seqlen_q
1213
+ // s_k: seqlen_k
1214
+ // h: num_heads
1215
+ // h_k: num_heads_k
1216
+ // d: head_size
1217
+ std::vector<at::Tensor> mha_bwd(
1218
+ const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1219
+ const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1220
+ const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1221
+ const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1222
+ const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1223
+ const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
1224
+ std::optional<at::Tensor> &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1225
+ std::optional<at::Tensor> &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1226
+ std::optional<at::Tensor> &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1227
+ std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
1228
+ std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
1229
+ std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
1230
+ std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1231
+ std::optional<int> max_seqlen_q_,
1232
+ std::optional<int> max_seqlen_k_,
1233
+ float const softmax_scale,
1234
+ bool is_causal,
1235
+ int window_size_left,
1236
+ int window_size_right,
1237
+ float const softcap,
1238
+ bool const deterministic,
1239
+ int const sm_margin) {
1240
+
1241
+ #ifdef FLASHATTENTION_DISABLE_BACKWARD
1242
+ TORCH_CHECK(false, "This flash attention build does not support backward.");
1243
+ #endif
1244
+
1245
+ auto dprops = at::cuda::getCurrentDeviceProperties();
1246
+ bool is_sm8x = dprops->major >= 8;
1247
+ TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1248
+
1249
+ auto q_type = q.dtype();
1250
+ TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
1251
+ "FlashAttention only support fp16 and bf16 data type");
1252
+ TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
1253
+ TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
1254
+ TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
1255
+ TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
1256
+
1257
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1258
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1259
+
1260
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1261
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1262
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1263
+ TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1264
+ TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1265
+
1266
+ at::Tensor cu_seqlens_q;
1267
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
1268
+ if (is_varlen_q) {
1269
+ cu_seqlens_q = cu_seqlens_q_.value();
1270
+ CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
1271
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
1272
+ TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
1273
+ }
1274
+ at::Tensor cu_seqlens_k;
1275
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
1276
+ if (is_varlen_k) {
1277
+ cu_seqlens_k = cu_seqlens_k_.value();
1278
+ CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
1279
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
1280
+ TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
1281
+ }
1282
+ // This is what we will template on
1283
+ bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
1284
+ #ifdef FLASHATTENTION_DISABLE_VARLEN
1285
+ TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
1286
+ #endif
1287
+
1288
+ auto const sizes = q.sizes();
1289
+ int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
1290
+ int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
1291
+ int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
1292
+ int const num_heads = q.size(-2);
1293
+ int const head_size = q.size(-1);
1294
+ int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
1295
+ int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
1296
+ int const num_heads_k = k.size(-2);
1297
+ TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1298
+ int const max_headdim = get_max_headdim();
1299
+ TORCH_CHECK(head_size <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim));
1300
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1301
+
1302
+ // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
1303
+ if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
1304
+ if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
1305
+ if (is_causal) { window_size_right = 0; }
1306
+ // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_dgrad will set params.is_causal=true.
1307
+ // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
1308
+ is_causal = window_size_left < 0 && window_size_right == 0;
1309
+
1310
+ int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1311
+ int const head_size_rounded = round_up_headdim(head_size);
1312
+ // Very important that these match the kernel configs
1313
+ bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
1314
+ int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
1315
+ : (head_size_rounded <= 96 ? 64
1316
+ : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
1317
+ : 64));
1318
+ int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
1319
+ int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
1320
+ int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
1321
+ int const kBlockN_sm90 = head_size_rounded <= 128
1322
+ ? 128
1323
+ : (head_size_rounded <= 192 ? 96 : 80);
1324
+ int const kBlockN_sm80 = head_size_rounded <= 128
1325
+ ? 128
1326
+ : (head_size_rounded <= 192 ? 80 : 64);
1327
+ int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
1328
+ : (head_size_rounded <= 96 ? 128
1329
+ : (head_size_rounded <= 128 ? 96
1330
+ : (head_size_rounded <= 192 ? 64 : 64)));
1331
+ int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
1332
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1333
+ int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
1334
+ int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
1335
+ int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
1336
+ int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
1337
+
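The rounding above pads each sequence length up to a block multiple so the accumulator addresses stay aligned; total_q_padded_rounded additionally reserves one extra block per batch entry before rounding. A tiny worked example of the same arithmetic (values chosen for illustration only):

#include <cassert>

constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
    // Illustrative values: kBlockM = 64, seqlen_q = 1000, batch_size = 3, total_q = 2500.
    assert(round_multiple(1000, 64) == 1024);            // seqlen_q_rounded
    assert(round_multiple(2500 + 3 * 64, 64) == 2752);   // total_q_padded_rounded
    return 0;
}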
1338
+ if (!is_varlen_q) {
1339
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
1340
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
1341
+ CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
1342
+ } else {
1343
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
1344
+ CHECK_SHAPE(out, total_q, num_heads, head_size);
1345
+ CHECK_SHAPE(dout, total_q, num_heads, head_size);
1346
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1347
+ }
1348
+ if (!is_varlen_k) {
1349
+ CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
1350
+ CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
1351
+ } else {
1352
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1353
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size);
1354
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1355
+ }
1356
+
1357
+ if (seqused_q_.has_value()){
1358
+ auto seqused_q = seqused_q_.value();
1359
+ TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
1360
+ CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
1361
+ CHECK_SHAPE(seqused_q, batch_size);
1362
+ }
1363
+ if (seqused_k_.has_value()){
1364
+ auto seqused_k = seqused_k_.value();
1365
+ TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
1366
+ CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
1367
+ CHECK_SHAPE(seqused_k, batch_size);
1368
+ }
1369
+
1370
+ at::Tensor dq, dk, dv;
1371
+ if (dq_.has_value()) {
1372
+ dq = dq_.value();
1373
+ TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
1374
+ CHECK_DEVICE(dq);
1375
+ TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1376
+ if (!is_varlen_q) {
1377
+ CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
1378
+ } else {
1379
+ CHECK_SHAPE(dq, total_q, num_heads, head_size);
1380
+ }
1381
+ } else {
1382
+ dq = torch::empty_like(q);
1383
+ }
1384
+ if (dk_.has_value()) {
1385
+ dk = dk_.value();
1386
+ TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
1387
+ CHECK_DEVICE(dk);
1388
+ TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1389
+ if (!is_varlen_k) {
1390
+ CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
1391
+ } else {
1392
+ CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1393
+ }
1394
+ } else {
1395
+ dk = torch::empty_like(k);
1396
+ }
1397
+ if (dv_.has_value()) {
1398
+ dv = dv_.value();
1399
+ TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
1400
+ CHECK_DEVICE(dv);
1401
+ TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1402
+ if (!is_varlen_k) {
1403
+ CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
1404
+ } else {
1405
+ CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
1406
+ }
1407
+ } else {
1408
+ dv = torch::empty_like(v);
1409
+ }
1410
+
1411
+ // Otherwise the kernel will be launched from cuda:0 device
1412
+ // Cast to char to avoid compiler warning about narrowing
1413
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1414
+
1415
+ auto opts = q.options();
1416
+ // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1417
+ at::Tensor softmax_d, softmax_lse_log2;
1418
+ if (!is_varlen) {
1419
+ // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1420
+ softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1421
+ softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1422
+ } else {
1423
+ softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1424
+ softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1425
+ }
1426
+ at::Tensor dq_accum, dk_accum, dv_accum;
1427
+ if (!is_varlen) {
1428
+ dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1429
+ } else {
1430
+ dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1431
+ }
1432
+ if (num_heads_k != num_heads) { // MQA / GQA
1433
+ if (!is_varlen) {
1434
+ dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1435
+ dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1436
+ } else {
1437
+ dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1438
+ dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1439
+ }
1440
+ }
1441
+
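dk_accum and dv_accum only exist in the MQA/GQA case because each KV head serves num_heads / num_heads_k query heads, so its gradient is the sum of the contributions from every query head in its group and is accumulated in fp32 before being converted back to the working dtype. A minimal sketch of that reduction (reduce_group is an illustrative name, not a kernel in this codebase):

#include <cstddef>
#include <vector>

// Illustrative helper: sums the dK contributions of the query heads mapped to one KV head.
std::vector<float> reduce_group(std::vector<std::vector<float>> const& dk_per_q_head) {
    std::vector<float> dk(dk_per_q_head.at(0).size(), 0.f);   // fp32 accumulator, like dk_accum
    for (auto const& contrib : dk_per_q_head) {
        for (std::size_t i = 0; i < dk.size(); ++i) dk[i] += contrib[i];
    }
    return dk;
}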
1442
+ Flash_bwd_params params;
1443
+ set_params_dgrad(params,
1444
+ batch_size,
1445
+ seqlen_q, seqlen_k,
1446
+ seqlen_q_rounded, seqlen_k_rounded,
1447
+ num_heads, num_heads_k,
1448
+ head_size, head_size_rounded,
1449
+ q, k, v, out,
1450
+ dout, dq, dk, dv,
1451
+ !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
1452
+ !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
1453
+ seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
1454
+ seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
1455
+ dq_accum.data_ptr(),
1456
+ num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
1457
+ num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
1458
+ softmax_lse.data_ptr(),
1459
+ softmax_d.data_ptr(),
1460
+ /*p_dropout=*/0.f,
1461
+ softmax_scale,
1462
+ window_size_left,
1463
+ window_size_right,
1464
+ softcap,
1465
+ deterministic,
1466
+ sm_margin);
1467
+ params.total_q = total_q;
1468
+ params.total_k = total_k;
1469
+ params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
1470
+ params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now
1471
+
1472
+ // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
1473
+ // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
1474
+ // Will be zeroed out in the backward preprocess kernel
1475
+ at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
1476
+ params.dq_semaphore = dq_semaphore.data_ptr<int>();
1477
+ if (num_heads_k != num_heads && params.deterministic) {
1478
+ // TODO: do we need to zero them out?
1479
+ at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1480
+ at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1481
+ params.dk_semaphore = dk_semaphore.data_ptr<int>();
1482
+ params.dv_semaphore = dv_semaphore.data_ptr<int>();
1483
+ }
1484
+
1485
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
1486
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1487
+ #endif
1488
+ #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1489
+ TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1490
+ #endif
1491
+
1492
+ if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
1493
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1494
+ run_mha_bwd(params, stream);
1495
+ } else if (total_k > 0 && num_heads_k > 0) {
1496
+ // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1497
+ dk.zero_();
1498
+ dv.zero_();
1499
+ softmax_d.zero_();
1500
+ } else if (total_q > 0 && num_heads_k > 0) {
1501
+ dq.zero_();
1502
+ softmax_d.zero_();
1503
+ }
1504
+
1505
+ return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
1506
+ }
1507
+
1508
+ std::vector<at::Tensor>
1509
+ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
1510
+ const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads
1511
+ std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
1512
+ std::optional<at::ScalarType> out_dtype_
1513
+ ) {
1514
+
1515
+ auto dprops = at::cuda::getCurrentDeviceProperties();
1516
+ bool is_sm8x = dprops->major >= 8;
1517
+ TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
1518
+
1519
+ auto out_partial_type = out_partial.scalar_type();
1520
+ TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports the fp32 data type");
1521
+ TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports the fp32 data type");
1522
+
1523
+ CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
1524
+
1525
+ TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1526
+ TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
1527
+
1528
+ const auto sizes = out_partial.sizes();
1529
+
1530
+ const int num_splits = sizes[0];
1531
+ const int batch_size = sizes[1];
1532
+ const int seqlen = sizes[2];
1533
+ const int num_heads = sizes[3];
1534
+ const int head_size_og = sizes[4];
1535
+ TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
1536
+
1537
+ CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
1538
+ CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
1539
+
1540
+ int const alignment = 4;
1541
+ at::Tensor out_partial_padded;
1542
+ auto pad = [](at::Tensor x, int alignment) {
1543
+ return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
1544
+ };
1545
+ out_partial_padded = pad(out_partial, alignment);
1546
+
1547
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1548
+ const int head_size = round_multiple(head_size_og, alignment);
1549
+
1550
+ auto opts = out_partial.options();
1551
+ at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
1552
+ TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
1553
+ at::Tensor out;
1554
+ if (out_.has_value()) {
1555
+ out = out_.value();
1556
+ TORCH_CHECK(out.scalar_type() == out_type);
1557
+ CHECK_DEVICE(out);
1558
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1559
+ CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
1560
+ if (head_size_og % alignment != 0) {
1561
+ out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1562
+ }
1563
+ } else {
1564
+ out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1565
+ }
1566
+
1567
+ // Otherwise the kernel will be launched from cuda:0 device
1568
+ // Cast to char to avoid compiler warning about narrowing
1569
+ at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
1570
+
1571
+ auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
1572
+
1573
+ Flash_fwd_params params {}; // Need to reset the params to set everything to zero
1574
+ params.is_fp32 = out_type == at::ScalarType::Float;
1575
+ params.is_bf16 = out_type == at::ScalarType::BFloat16;
1576
+ params.oaccum_ptr = out_partial_padded.data_ptr();
1577
+ params.softmax_lseaccum_ptr = lse_partial.data_ptr();
1578
+ params.o_ptr = out.data_ptr();
1579
+ params.softmax_lse_ptr = softmax_lse.data_ptr();
1580
+ params.b = batch_size;
1581
+ params.h = num_heads;
1582
+ params.seqlen_q = seqlen;
1583
+ params.dv = head_size;
1584
+ params.num_splits = num_splits;
1585
+ params.oaccum_split_stride = out_partial_padded.stride(0);
1586
+ params.oaccum_row_stride = out_partial_padded.stride(2);
1587
+ params.oaccum_head_stride = out_partial_padded.stride(3);
1588
+ params.oaccum_batch_stride = out_partial_padded.stride(1);
1589
+ params.lseaccum_split_stride = lse_partial.stride(0);
1590
+ params.lseaccum_head_stride = lse_partial.stride(3);
1591
+ params.lseaccum_batch_stride = lse_partial.stride(1);
1592
+ params.o_row_stride = out.stride(1);
1593
+ params.o_head_stride = out.stride(2);
1594
+ params.o_batch_stride = out.stride(0);
1595
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1596
+
1597
+ if (seqlen > 0 && batch_size > 0) {
1598
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1599
+ run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
1600
+ }
1601
+
1602
+ at::Tensor out_padded = out;
1603
+ if (head_size_og % alignment != 0) {
1604
+ out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
1605
+ // if (out_.has_value()) { out_.value().copy_(out); }
1606
+ }
1607
+
1608
+ return {out, softmax_lse};
1609
+ }
1610
+
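mha_combine merges per-split partial outputs and their log-sum-exp values. The sketch below spells out the standard log-sum-exp merge that such a combine step performs for one (query, head) position, assuming the conventional definitions of out_partial and lse_partial; combine_splits is an illustrative name, not the kernel itself:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative helper: merges the outputs of all splits for one (query, head) position.
std::vector<float> combine_splits(std::vector<std::vector<float>> const& out_partial,
                                  std::vector<float> const& lse_partial) {
    // Merged LSE, computed stably relative to the running max.
    float lse_max = lse_partial.at(0);
    for (float lse : lse_partial) lse_max = std::max(lse_max, lse);
    float sum_exp = 0.f;
    for (float lse : lse_partial) sum_exp += std::exp(lse - lse_max);
    float const lse = lse_max + std::log(sum_exp);

    // out = sum_s exp(lse_s - lse) * out_s
    std::vector<float> out(out_partial.at(0).size(), 0.f);
    for (std::size_t s = 0; s < out_partial.size(); ++s) {
        float const w = std::exp(lse_partial[s] - lse);
        for (std::size_t i = 0; i < out.size(); ++i) out[i] += w * out_partial[s][i];
    }
    return out;
}

The fp32 accumulators allocated in the forward pass when num_splits > 1 play the role of out_partial and lse_partial in this merge.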
1611
+ #ifndef FLASHATTENTION_DISABLE_PYBIND
1612
+
1613
+ #include <torch/python.h>
1614
+
1615
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1616
+ m.doc() = "FlashAttention";
1617
+ m.def("fwd", &mha_fwd, "Forward pass");
1618
+ m.def("bwd", &mha_bwd, "Backward pass");
1619
+ m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
1620
+ m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
1621
+ }
1622
+
1623
+ #endif
flash-attn/flash_bwd_kernel_sm80.h ADDED
@@ -0,0 +1,173 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/kernel_hardware_info.h>
13
+
14
+ #include "utils.h"
15
+
16
+ namespace flash {
17
+
18
+ using namespace cute;
19
+
20
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
21
+ class FlashAttnBwdSm80 {
22
+
23
+ public:
24
+
25
+ // Type Aliases
26
+ static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
27
+ static constexpr bool Is_local = CollectiveMainloop_::Is_local;
28
+ static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
29
+ static constexpr bool Varlen = CollectiveMainloop_::Varlen;
30
+
31
+ // Mainloop derived types
32
+ using CollectiveMainloop = CollectiveMainloop_;
33
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
34
+ using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
35
+ using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
36
+ using ArchTag = typename CollectiveMainloop::ArchTag;
37
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
38
+ using MainloopParams = typename CollectiveMainloop::Params;
39
+ static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
40
+
41
+ // Epilogue derived types
42
+ using CollectiveEpilogue = CollectiveEpilogue_;
43
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
44
+ using EpilogueParams = typename CollectiveEpilogue::Params;
45
+
46
+ static_assert(ArchTag::kMinComputeCapability >= 80);
47
+
48
+ using TileScheduler = TileScheduler_;
49
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
50
+ using TileSchedulerParams = typename TileScheduler::Params;
51
+
52
+ static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
53
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
54
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
55
+
56
+ // Kernel level shared memory storage
57
+ struct SharedStorage {
58
+ struct TensorStorage : cute::aligned_struct<128> {
59
+ union {
60
+ typename CollectiveMainloop::TensorStorage mainloop;
61
+ typename CollectiveEpilogue::TensorStorage epilogue;
62
+ };
63
+ } tensors;
64
+
65
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
66
+
67
+ };
68
+
69
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
70
+
71
+ // Device side arguments
72
+ struct Arguments {
73
+ MainloopArguments mainloop{};
74
+ EpilogueArguments epilogue{};
75
+ cutlass::KernelHardwareInfo hw_info{};
76
+ TileSchedulerArguments scheduler{};
77
+ };
78
+
79
+ // Kernel entry point API
80
+ struct Params {
81
+ MainloopParams mainloop{};
82
+ EpilogueParams epilogue{};
83
+ cutlass::KernelHardwareInfo hw_info{};
84
+ TileSchedulerParams scheduler{};
85
+ };
86
+
87
+ //
88
+ // Methods
89
+ //
90
+
91
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
92
+ static
93
+ Params
94
+ to_underlying_arguments(Arguments const& args) {
95
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
96
+
97
+ // Get SM count if needed, otherwise use user supplied SM count
98
+ int sm_count = args.hw_info.sm_count;
99
+ if (sm_count <= 0) {
100
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
101
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
102
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
103
+ }
104
+
105
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
106
+
107
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
108
+ return {
109
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
110
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
111
+ hw_info,
112
+ TileScheduler::to_underlying_arguments(args.scheduler)
113
+ };
114
+ }
115
+
116
+ // Computes the kernel launch grid shape based on runtime parameters
117
+ static dim3
118
+ get_grid_shape(Params const& params) {
119
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
120
+ }
121
+
122
+ static dim3
123
+ get_block_shape() {
124
+ return dim3(MaxThreadsPerBlock, 1, 1);
125
+ }
126
+
127
+ CUTLASS_DEVICE
128
+ void
129
+ operator()(Params const& params, char* smem_buf) {
130
+
131
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
132
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
133
+
134
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
135
+
136
+ CollectiveMainloop mainloop;
137
+ CollectiveEpilogue epilogue;
138
+
139
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
140
+ // Initialize matmul objects.
141
+ TiledMmadKV tiled_mma_dKV;
142
+
143
+ scheduler.init_consumer();
144
+
145
+ int warp_idx = cutlass::canonical_warp_idx_sync();
146
+ CUTLASS_PRAGMA_NO_UNROLL
147
+ for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
148
+ work_tile_info.is_valid(params.scheduler);
149
+ work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
150
+
151
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
152
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
153
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
154
+
155
+ // dK and dV output accumulator.
156
+ Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
157
+ Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
158
+ bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
159
+ block_coord, shared_storage);
160
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
161
+ if (tile_valid) {
162
+ epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
163
+ threadIdx.x, block_coord);
164
+ } else {
165
+ epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
166
+ }
167
+ }
168
+
169
+ }
170
+
171
+ };
172
+
173
+ } // namespace flash
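The SM80 kernel above is persistent: every CTA keeps pulling work tiles from the TileScheduler until none remain, rather than being launched once per tile. A plain host C++ sketch of that loop shape, using a hypothetical StaticScheduler in place of the real scheduler types:

#include <cstdio>
#include <optional>

struct Tile { int n_block, bidh, bidb; };           // (KV block, head, batch), as in the kernel loop

struct StaticScheduler {                            // hypothetical stand-in for TileScheduler
    int num_tiles, next = 0;
    std::optional<Tile> get_next_work(int num_heads, int num_n_blocks) {
        if (next >= num_tiles) return std::nullopt;
        int const t = next++;
        return Tile{t % num_n_blocks, (t / num_n_blocks) % num_heads, t / (num_n_blocks * num_heads)};
    }
};

int main() {
    StaticScheduler sched{4 * 3 * 2};               // 4 batches x 3 heads x 2 KV blocks
    while (auto tile = sched.get_next_work(/*num_heads=*/3, /*num_n_blocks=*/2)) {
        std::printf("n_block=%d head=%d batch=%d\n", tile->n_block, tile->bidh, tile->bidb);
    }
    return 0;
}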
flash-attn/flash_bwd_kernel_sm90.h ADDED
@@ -0,0 +1,282 @@
1
+
2
+ /******************************************************************************
3
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
4
+ ******************************************************************************/
5
+
6
+ #pragma once
7
+
8
+ #include "cute/tensor.hpp"
9
+
10
+ #include <cutlass/cutlass.h>
11
+ #include <cutlass/arch/reg_reconfig.h>
12
+ #include <cutlass/array.h>
13
+ #include <cutlass/numeric_types.h>
14
+ #include <cutlass/numeric_conversion.h>
15
+ #include <cutlass/kernel_hardware_info.h>
16
+ #include "cutlass/pipeline/pipeline.hpp"
17
+
18
+ #include "utils.h"
19
+
20
+ namespace flash {
21
+
22
+ using namespace cute;
23
+
24
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
25
+ class FlashAttnBwdSm90 {
26
+
27
+ public:
28
+
29
+ // Type Aliases
30
+ static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
31
+ static constexpr bool Is_local = CollectiveMainloop_::Is_local;
32
+ static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
33
+ static constexpr bool Varlen = CollectiveMainloop_::Varlen;
34
+
35
+ // Mainloop derived types
36
+ using CollectiveMainloop = CollectiveMainloop_;
37
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
38
+ using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
39
+ using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
40
+ using ArchTag = typename CollectiveMainloop::ArchTag;
41
+ using ClusterShape = typename CollectiveMainloop::ClusterShape;
42
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
43
+ using MainloopParams = typename CollectiveMainloop::Params;
44
+ static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
45
+
46
+ // Epilogue derived types
47
+ using CollectiveEpilogue = CollectiveEpilogue_;
48
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
49
+ using EpilogueParams = typename CollectiveEpilogue::Params;
50
+
51
+ static_assert(ArchTag::kMinComputeCapability >= 90);
52
+
53
+ using TileScheduler = TileScheduler_;
54
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
55
+ using TileSchedulerParams = typename TileScheduler::Params;
56
+
57
+ static constexpr uint32_t NumLoadWarpGroups = 1;
58
+ static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
59
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
60
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
61
+ static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
62
+
63
+ /// Register requirement for Load and Math WGs
64
+ static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
65
+ static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
66
+ // If you want to print from the producer warp, you'd need to increase the number of registers
67
+ // Otherwise you'll get a CUDA error.
68
+ // static constexpr uint32_t LoadRegisterRequirement = 40;
69
+ // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
70
+
71
+ // Kernel level shared memory storage
72
+ struct SharedStorage {
73
+ struct TensorStorage : cute::aligned_struct<128> {
74
+ union {
75
+ typename CollectiveMainloop::TensorStorage mainloop;
76
+ typename CollectiveEpilogue::TensorStorage epilogue;
77
+ };
78
+ } tensors;
79
+
80
+ struct PipelineStorage : cute::aligned_struct<16> {
81
+ alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
82
+ alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
83
+ alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
84
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
+ } pipelines;
86
+
87
+ };
88
+
89
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
90
+
91
+ // Device side arguments
92
+ struct Arguments {
93
+ MainloopArguments mainloop{};
94
+ EpilogueArguments epilogue{};
95
+ cutlass::KernelHardwareInfo hw_info{};
96
+ TileSchedulerArguments scheduler{};
97
+ };
98
+
99
+ // Kernel entry point API
100
+ struct Params {
101
+ MainloopParams mainloop{};
102
+ EpilogueParams epilogue{};
103
+ cutlass::KernelHardwareInfo hw_info{};
104
+ TileSchedulerParams scheduler{};
105
+ };
106
+
107
+ //
108
+ // Methods
109
+ //
110
+
111
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
112
+ static
113
+ Params
114
+ to_underlying_arguments(Arguments const& args) {
115
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
116
+
117
+ // Get SM count if needed, otherwise use user supplied SM count
118
+ int sm_count = args.hw_info.sm_count;
119
+ if (sm_count <= 0) {
120
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
121
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
122
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
123
+ }
124
+
125
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
126
+
127
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
128
+ return {
129
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
130
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
131
+ hw_info,
132
+ TileScheduler::to_underlying_arguments(args.scheduler)
133
+ };
134
+ }
135
+
136
+ // Computes the kernel launch grid shape based on runtime parameters
137
+ static dim3
138
+ get_grid_shape(Params const& params) {
139
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
140
+ }
141
+
142
+ static dim3
143
+ get_block_shape() {
144
+ return dim3(MaxThreadsPerBlock, 1, 1);
145
+ }
146
+
147
+ CUTLASS_DEVICE
148
+ void
149
+ operator()(Params const& params, char* smem_buf) {
150
+
151
+ static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
152
+ static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
153
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
154
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
155
+
156
+ using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
157
+ using PipelineParams = typename MainloopPipeline::Params;
158
+ using PipelineState = typename MainloopPipeline::PipelineState;
159
+ using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
160
+ using PipelineParams_dO = typename MainloopPipeline_dO::Params;
161
+ using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
162
+ static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
163
+
164
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
165
+
166
+ int const lane_predicate = cute::elect_one_sync();
167
+ int const warp_idx = cutlass::canonical_warp_idx_sync();
168
+
169
+ // Issue Tma Descriptor Prefetch from a single thread
170
+ if (warp_idx == 0 && lane_predicate) {
171
+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
172
+ CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
173
+ }
174
+
175
+ // Obtain warp index
176
+ int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
177
+
178
+ PipelineParams pipeline_params;
179
+ pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
180
+ int warp_group_idx = cutlass::canonical_warp_group_idx();
181
+ pipeline_params.role = warp_group_idx == 0
182
+ ? MainloopPipeline::ThreadCategory::Producer
183
+ : MainloopPipeline::ThreadCategory::Consumer;
184
+ pipeline_params.is_leader = warp_group_thread_idx == 0;
185
+ pipeline_params.num_consumers = NumMmaThreads;
186
+
187
+ if (warp_idx == 0 && lane_predicate) {
188
+ shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
189
+ }
190
+ // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
191
+ MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
192
+ auto role_dO = warp_group_idx == 0
193
+ ? MainloopPipeline_dO::ThreadCategory::Producer
194
+ : MainloopPipeline_dO::ThreadCategory::Consumer;
195
+ PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
196
+ MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
197
+
198
+ CollectiveMainloop mainloop;
199
+ CollectiveEpilogue epilogue;
200
+
201
+ // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
202
+ if constexpr (size(ClusterShape{}) > 1) {
203
+ cute::cluster_arrive_relaxed();
204
+ cute::cluster_wait();
205
+ } else {
206
+ __syncthreads();
207
+ }
208
+
209
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
210
+
211
+ if (warp_group_idx == 0) { // Producer
212
+ cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
213
+
214
+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
215
+ if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
216
+ PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
217
+ PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
218
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
219
+ work_tile_info.is_valid(params.scheduler);
220
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
221
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
222
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
223
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
224
+ auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
225
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
226
+ };
227
+ mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
228
+ smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
229
+ }
230
+ mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
231
+ } else if (warp_idx_in_warpgroup == 1) {
232
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
233
+ work_tile_info.is_valid(params.scheduler);
234
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
235
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
236
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
237
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
238
+ mainloop.store_dq(params.mainloop, shared_storage, block_coord);
239
+ }
240
+ }
241
+ } else { // Consumer
242
+ cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
243
+ // Initialize matmul objects.
244
+ TiledMmadKV tiled_mma_dKV;
245
+
246
+ PipelineState smem_pipe_read;
247
+ PipelineState_dO smem_pipe_read_do;
248
+
249
+ mainloop.mma_init();
250
+ scheduler.init_consumer();
251
+
252
+ int work_idx = 0;
253
+ CUTLASS_PRAGMA_NO_UNROLL
254
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
255
+ work_tile_info.is_valid(params.scheduler);
256
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
257
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
258
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
259
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
260
+
261
+ // dK and dV output accumulator.
262
+ Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
263
+ Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
264
+ bool tile_valid = mainloop.mma(
265
+ params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
266
+ tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
267
+ if (tile_valid) {
268
+ epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
269
+ threadIdx.x - NumCopyThreads, block_coord);
270
+ } else {
271
+ epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
272
+ }
273
+
274
+ }
275
+ epilogue.store_tail();
276
+ }
277
+
278
+ }
279
+
280
+ };
281
+
282
+ } // namespace flash
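A plausible reading of the register-requirement constants in the SM90 kernel above, assuming one CTA per SM and Hopper's 64K 32-bit registers per SM with 128 threads per warp group: after warpgroup_reg_dealloc/alloc, the per-thread budgets of the load warp group plus the MMA warp groups have to fit the register file, which the two configurations do with little or no slack.

// Budget check under the stated assumptions: 128 threads per warp group, 64K registers per SM.
static_assert(128 * (24 + 2 * 240) <= 64 * 1024, "1 load WG + 2 MMA WGs fit the register file");
static_assert(128 * (32 + 3 * 160) <= 64 * 1024, "1 load WG + 3 MMA WGs fit the register file");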
flash-attn/flash_bwd_launch_template.h ADDED
@@ -0,0 +1,377 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/device_kernel.h" // For device_kernel
10
+ #include "cutlass/kernel_launch.h" // For kernel_launch
11
+ #include "cutlass/cluster_launch.hpp" // For ClusterLauncher
12
+
13
+ #include "static_switch.h"
14
+ #include "flash.h"
15
+ #include "flash_bwd_preprocess_kernel.h"
16
+ #include "flash_bwd_postprocess_kernel.h"
17
+ #include "tile_scheduler.hpp"
18
+ #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
19
+ #include "mainloop_bwd_sm80.hpp"
20
+ #include "epilogue_bwd.hpp"
21
+ #include "flash_bwd_kernel_sm90.h"
22
+ #include "flash_bwd_kernel_sm80.h"
23
+
24
+ using namespace cute;
25
+
26
+ template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
27
+ bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
28
+ int Stages_dO=2, int Stages_dS_or_QSm80=2,
29
+ bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
30
+ int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
31
+ bool V_in_regs=false>
32
+ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
33
+ static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
34
+ using ElementAccum = float;
35
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
36
+
37
+ int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
38
+ int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
39
+ bool const is_varlen_q = params.cu_seqlens_q;
40
+ bool const is_varlen_k = params.cu_seqlens_k;
41
+ int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
42
+ int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
43
+ int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
44
+ int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
45
+ int batch_q = !is_varlen_q ? params.b : 1;
46
+ int batch_k = !is_varlen_k ? params.b : 1;
47
+
48
+ using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
49
+ using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
50
+ typename PreprocessKernel::Arguments preprocess_args {
51
+ static_cast<Element const*>(params.o_ptr),
52
+ {seqlen_q, params.d, params.h, batch_q}, // shape_O
53
+ {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
54
+ static_cast<Element const*>(params.do_ptr),
55
+ {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
56
+ static_cast<float*>(params.dsoftmax_sum),
57
+ {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
58
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
59
+ static_cast<float*>(params.softmax_lse_ptr),
60
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
61
+ static_cast<float*>(params.softmax_lse_log2_ptr),
62
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
63
+ static_cast<ElementAccum*>(params.dq_accum_ptr),
64
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
65
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
66
+ params.b,
67
+ params.dq_semaphore,
68
+ params.cu_seqlens_q,
69
+ params.seqused_q
70
+ };
71
+ typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
72
+ int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
73
+ dim3 grid_m(num_m_block, params.h, params.b);
74
+ cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
75
+ CHECK_CUDA_KERNEL_LAUNCH();
76
+
77
+ using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
78
+ using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Cluster launch is not currently supported
79
+ // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
80
+ static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
81
+ static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
82
+ using CollectiveMainloop = std::conditional_t<
83
+ Arch >= 90,
84
+ flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
85
+ Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
86
+ SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
87
+ flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
88
+ Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
89
+ SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
90
+ >;
91
+ using CollectiveEpilogue = std::conditional_t<
92
+ !GQA,
93
+ flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
94
+ flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
95
+ >;
96
+ using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
97
+ using AttnKernel = std::conditional_t<
98
+ Arch >= 90,
99
+ flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
100
+ flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
101
+ >;
102
+
103
+ typename CollectiveMainloop::Arguments mainloop_args {
104
+ static_cast<Element const*>(params.q_ptr),
105
+ {seqlen_q, params.d, params.h, batch_q}, // shape_Q
106
+ {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
107
+ static_cast<Element const*>(params.k_ptr),
108
+ {seqlen_k, params.d, params.h_k, batch_k}, // shape_K
109
+ {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
110
+ static_cast<Element const*>(params.v_ptr),
111
+ {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
112
+ static_cast<Element const*>(params.do_ptr),
113
+ {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
114
+ static_cast<ElementAccum*>(params.dq_accum_ptr),
115
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
116
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
117
+ static_cast<float*>(params.softmax_lse_log2_ptr),
118
+ {seqlen_q_rounded, params.h, batch_q}, // shape_LSE
119
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
120
+ static_cast<float*>(params.dsoftmax_sum),
121
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
122
+ params.scale_softmax,
123
+ params.window_size_left, params.window_size_right,
124
+ params.softcap,
125
+ params.b,
126
+ params.dq_semaphore,
127
+ params.cu_seqlens_q, params.cu_seqlens_k,
128
+ params.seqused_q, params.seqused_k
129
+ };
130
+ // Handling the GQA case here is ugly, but a cleaner way to express it isn't obvious.
131
+ typename CollectiveEpilogue::Arguments epilogue_args {
132
+ static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
133
+ [&] {
134
+ if constexpr (!GQA) {
135
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
136
+ } else {
137
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
138
+ }
139
+ }(),
140
+ [&] {
141
+ if constexpr (!GQA) {
142
+ return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
143
+ } else {
144
+ return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
145
+ }
146
+ }(),
147
+ static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
148
+ [&] {
149
+ if constexpr (!GQA) {
150
+ return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
151
+ } else {
152
+ return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
153
+ }
154
+ }(),
155
+ params.h,
156
+ params.dk_semaphore,
157
+ params.dv_semaphore,
158
+ params.cu_seqlens_k,
159
+ params.seqused_k,
160
+ };
161
+
162
+ int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
163
+ num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
164
+ typename flash::TileSchedulerArguments scheduler_args {
165
+ num_blocks_n, params.h, params.b, 1 /*num_splits*/,
166
+ params.h / params.h_k,
167
+ params.seqlen_k,
168
+ params.seqlen_q, params.d, params.dv, sizeof(Element),
169
+ params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
170
+ };
171
+
172
+ int device;
173
+ cudaGetDevice(&device);
174
+ typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
175
+ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
176
+ });
177
+
178
+ dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
179
+ dim3 block_dims = AttnKernel::get_block_shape();
180
+ int smem_size = AttnKernel::SharedStorageSize;
181
+ // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
182
+ // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
183
+ // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
184
+ // int smem_size_dqacc = [&] {
185
+ // if constexpr (Arch >= 90) {
186
+ // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
187
+ // } else {
188
+ // return 0;
189
+ // }
190
+ // }();
191
+ // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
192
+ // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
193
+ // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
194
+ // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
195
+ // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
196
+ if constexpr (size(ClusterShape{}) > 1) {
197
+ void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
198
+ if (smem_size >= 48 * 1024) {
199
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
200
+ }
201
+ dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
202
+ cutlass::ClusterLauncher::launch(
203
+ grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
204
+ } else {
205
+ if (smem_size >= 48 * 1024) {
206
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
207
+ }
208
+ cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
209
+ }
210
+ CHECK_CUDA_KERNEL_LAUNCH();
211
+
212
+ using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
213
+ AttnKernel::CollectiveMainloop::NumMmaThreads,
214
+ typename AttnKernel::CollectiveMainloop::TiledMmadQ,
215
+ AttnKernel::CollectiveMainloop::dQ_swapAB
216
+ >;
217
+ typename PostprocessKernel::Arguments postprocess_args {
218
+ static_cast<ElementAccum const*>(params.dq_accum_ptr),
219
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
220
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
221
+ static_cast<Element*>(params.dq_ptr),
222
+ {seqlen_q, params.d, params.h, batch_q}, // shape_dQ
223
+ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
224
+ params.scale_softmax,
225
+ params.cu_seqlens_q,
226
+ params.seqused_q
227
+ };
228
+ typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
229
+ int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
230
+ dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
231
+ int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
232
+ if (smem_size_postprocess >= 48 * 1024) {
233
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
234
+ }
235
+ cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
236
+ CHECK_CUDA_KERNEL_LAUNCH();
237
+
238
+ if constexpr (GQA) {
239
+ using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
240
+ using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
241
+ AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
242
+ typename AttnKernel::CollectiveMainloop::TiledMmadKV,
243
+ AttnKernel::CollectiveMainloop::dKV_swapAB
244
+ >;
245
+ typename PostprocessKerneldKV::Arguments postprocess_dK_args {
246
+ static_cast<ElementAccum const*>(params.dk_accum_ptr),
247
+ {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
248
+ {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
249
+ static_cast<Element*>(params.dk_ptr),
250
+ {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
251
+ {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
252
+ 1.f,
253
+ params.cu_seqlens_k,
254
+ params.seqused_k
255
+ };
256
+ typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
257
+ typename PostprocessKerneldKV::Arguments postprocess_dV_args {
258
+ static_cast<ElementAccum const*>(params.dv_accum_ptr),
259
+ {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum
260
+ {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
261
+ static_cast<Element*>(params.dv_ptr),
262
+ {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV
263
+ {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
264
+ 1.f,
265
+ params.cu_seqlens_k,
266
+ params.seqused_k
267
+ };
268
+ typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
269
+ int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
270
+ dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
271
+ int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
272
+ if (smem_size_postprocess >= 48 * 1024) {
273
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
274
+ }
275
+ cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
276
+ CHECK_CUDA_KERNEL_LAUNCH();
277
+ cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
278
+ CHECK_CUDA_KERNEL_LAUNCH();
279
+ }
280
+
281
+ }
282
+
283
+ template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
284
+ int Stages_dO=2, int Stages_dS_or_QSm80=2,
285
+ bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
286
+ int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
287
+ bool V_in_regs=false>
288
+ void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
289
+ VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
290
+ BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
291
+ // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
292
+ // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
293
+ run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
294
+ // });
295
+ });
296
+ });
297
+ }
298
+
299
+
300
+ template<int Arch, typename T, bool Has_softcap>
301
+ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
302
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
303
+ if constexpr (Arch >= 90) {
304
+ if constexpr (Is_causal && Has_softcap) {
305
+ // register spill with 128 x 128
306
+ run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
307
+ } else {
308
+ // With ShuffleStats we no longer get register spilling when Has_softcap is set and a 128 x 128 block is used.
309
+ run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
310
+ }
311
+ } else if constexpr (Arch == 86 || Arch == 89) {
312
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
313
+ // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
314
+ // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
315
+ // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
316
+ } else {
317
+ run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
318
+ }
319
+ });
320
+ }
321
+
322
+ template<int Arch, typename T, bool Has_softcap>
323
+ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
324
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
325
+ if constexpr (Arch >= 90) {
326
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
327
+ } else if constexpr (Arch == 86 || Arch == 89) {
328
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
329
+ } else {
330
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
331
+ }
332
+ });
333
+ }
334
+
335
+ template<int Arch, typename T, bool Has_softcap>
336
+ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
337
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
338
+ if constexpr (Arch >= 90) {
339
+ if constexpr (Is_causal || Is_local || Has_softcap) {
340
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
341
+ } else {
342
+ run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
343
+ }
344
+ } else if constexpr (Arch == 86 || Arch == 89) {
345
+ run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
346
+ } else {
347
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
348
+ }
349
+ });
350
+ }
351
+
352
+ template<int Arch, typename T, bool Has_softcap>
353
+ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
354
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
355
+ if constexpr (Arch >= 90) {
356
+ run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
357
+ } else if constexpr (Arch == 86 || Arch == 89) {
358
+ run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
359
+ } else {
360
+ run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
361
+ }
362
+ });
363
+ }
364
+
365
+ template<int Arch, typename T, bool Has_softcap>
366
+ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
367
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
368
+ if constexpr (Arch >= 90) {
369
+ run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
370
+ } else if constexpr (Arch == 86 || Arch == 89) {
371
+ run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
372
+ // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
373
+ } else {
374
+ run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
375
+ }
376
+ });
377
+ }
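Editor's note: the launch template above chains three launches (preprocess fills dPsum and LSE*log2(e) and clears dQaccum, the main kernel accumulates dQaccum/dK/dV, the postprocess scales and converts dQaccum). As a reference for what that pipeline computes, here is a compact single-head CPU sketch of the standard softmax-attention backward formulas; the function name, row-major layout, and per-row recompute are illustrative, not the kernel's actual structure.

#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head reference: Q,dQ,dO are (M,D); K,V,dK,dV are (N,D); row-major.
void attention_bwd_reference(const std::vector<float>& Q, const std::vector<float>& K,
                             const std::vector<float>& V, const std::vector<float>& dO,
                             int M, int N, int D, float scale,
                             std::vector<float>& dQ, std::vector<float>& dK, std::vector<float>& dV) {
    dQ.assign(M * D, 0.f); dK.assign(N * D, 0.f); dV.assign(N * D, 0.f);
    std::vector<float> P(N), dP(N), s(N);
    for (int i = 0; i < M; ++i) {
        // Forward recompute of the softmax row P(i, :) from S = scale * Q K^T.
        float row_max = -INFINITY, denom = 0.f;
        for (int j = 0; j < N; ++j) {
            float acc = 0.f;
            for (int d = 0; d < D; ++d) { acc += Q[i * D + d] * K[j * D + d]; }
            s[j] = acc * scale;
            row_max = std::max(row_max, s[j]);
        }
        for (int j = 0; j < N; ++j) { P[j] = std::exp(s[j] - row_max); denom += P[j]; }
        for (int j = 0; j < N; ++j) { P[j] /= denom; }
        // dP = dO V^T and dPsum = rowsum(dP * P), which equals rowsum(dO * O) (what the preprocess stores).
        float dPsum = 0.f;
        for (int j = 0; j < N; ++j) {
            float acc = 0.f;
            for (int d = 0; d < D; ++d) { acc += dO[i * D + d] * V[j * D + d]; }
            dP[j] = acc;
            dPsum += acc * P[j];
        }
        // dS = P * (dP - dPsum); dQ = scale * dS K; dK += scale * dS^T Q; dV += P^T dO.
        for (int j = 0; j < N; ++j) {
            float const dS = P[j] * (dP[j] - dPsum);
            for (int d = 0; d < D; ++d) {
                dQ[i * D + d] += scale * dS * K[j * D + d];
                dK[j * D + d] += scale * dS * Q[i * D + d];
                dV[j * D + d] += P[j] * dO[i * D + d];
            }
        }
    }
}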
flash-attn/flash_bwd_postprocess_kernel.h ADDED
@@ -0,0 +1,256 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/numeric_conversion.h>
13
+ #include "cutlass/arch/barrier.h"
14
+
15
+ #include "seqlen.h"
16
+ #include "utils.h"
17
+
18
+ namespace flash {
19
+
20
+ using namespace cute;
21
+
22
+ template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
23
+ class FlashAttnBwdPostprocessConvertdQ {
24
+
25
+ public:
26
+
27
+ // Type Aliases
28
+ using TileShape_MK = TileShape_MK_;
29
+ using ArchTag = ArchTag_;
30
+
31
+ static_assert(ArchTag::kMinComputeCapability >= 75);
32
+ static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
33
+
34
+ static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
35
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
+
37
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
38
+ static constexpr int kHeadDim = get<1>(TileShape_MK{});
39
+ static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
40
+ static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
41
+ using R2SLayoutAtomdQaccum = std::conditional_t<
42
+ IsSm90,
43
+ Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
44
+ Layout<Shape<Int<kNThreads>>>
45
+ >;
46
+ using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
47
+ Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
48
+ using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
49
+ // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
50
+ using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
51
+ Layout<Shape<_4>>{})); // Val layout, 4 vals per read
52
+ // We don't do bound checking for the gmem -> smem load so we just assert here.
53
+ static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
54
+ static constexpr int SmemdQaccumSize = size(TileShape_MK{});
55
+ using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
56
+ using SmemLayoutdQaccum = std::conditional_t<
57
+ IsSm90,
58
+ Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
59
+ Layout<Shape<Int<kBlockM * kHeadDim>>>
60
+ >;
61
+
62
+ // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
63
+ // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
64
+ // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
65
+ static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
66
+ static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
67
+ static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
68
+ using SmemLayoutAtomdQ =
69
+ decltype(composition(Swizzle<kSwizzle, 3, 3>{},
70
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
71
+ Stride<Int<kBlockKSmem>, _1>>{}));
72
+ using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
73
+ using SmemLayoutdQt =
74
+ decltype(cute::composition(SmemLayoutdQ{},
75
+ make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
76
+ make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
77
+
78
+ using SmemCopyAtomdQ = Copy_Atom<
79
+ std::conditional_t<
80
+ IsSm90,
81
+ std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
82
+ AutoVectorizingCopyWithAssumedAlignment<128>
83
+ >,
84
+ Element>;
85
+
86
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
87
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
88
+ static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
89
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
90
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
91
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
92
+ using GmemTiledCopy = decltype(
93
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
94
+ GmemLayoutAtom{},
95
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
96
+
97
+ struct SharedStorage : cute::aligned_struct<128> {
98
+ cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
99
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
100
+ alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
101
+ };
102
+
103
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
104
+
105
+ using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
106
+ using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
107
+ using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
108
+ using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
109
+
110
+ // Device side arguments
111
+ struct Arguments {
112
+ ElementAccum const* ptr_dQaccum;
113
+ ShapedQaccum const shape_dQaccum;
114
+ StridedQaccum const stride_dQaccum;
115
+ Element* ptr_dQ;
116
+ ShapedQ const shape_dQ;
117
+ StridedQ const stride_dQ;
118
+ float const softmax_scale;
119
+ int const* cu_seqlens = nullptr;
120
+ int const* seqused = nullptr;
121
+ };
122
+
123
+ // Kernel entry point API
124
+ struct Params {
125
+ ElementAccum const* ptr_dQaccum;
126
+ ShapedQaccum const shape_dQaccum;
127
+ StridedQaccum const stride_dQaccum;
128
+ Element* ptr_dQ;
129
+ ShapedQ const shape_dQ;
130
+ StridedQ const stride_dQ;
131
+ float const softmax_scale;
132
+ int const* cu_seqlens = nullptr;
133
+ int const* seqused = nullptr;
134
+ };
135
+
136
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
137
+ static
138
+ Params
139
+ to_underlying_arguments(Arguments const& args) {
140
+ return {
141
+ args.ptr_dQaccum,
142
+ args.shape_dQaccum,
143
+ args.stride_dQaccum,
144
+ args.ptr_dQ,
145
+ args.shape_dQ,
146
+ args.stride_dQ,
147
+ args.softmax_scale,
148
+ args.cu_seqlens,
149
+ args.seqused
150
+ };
151
+ }
152
+
153
+ CUTLASS_DEVICE
154
+ void
155
+ operator()(Params const& params, char* smem_buf) {
156
+
157
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
158
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
159
+
160
+ Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
161
+ Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
162
+ Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
163
+ Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
164
+
165
+ int const thread_idx = threadIdx.x;
166
+ int const m_block = blockIdx.x;
167
+ int const bidh = blockIdx.y;
168
+ int const bidb = blockIdx.z;
169
+
170
+ flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
171
+ bool const is_varlen = params.cu_seqlens;
172
+ if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
173
+
174
+ // Step 1: load dQaccum from gmem to smem
175
+ Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
176
+ params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
177
+ Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
178
+ if constexpr (IsSm90) { // Use BulkCopy
179
+ static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
180
+ auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
181
+ // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
182
+ if (thread_idx == 0) {
183
+ shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
184
+ shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
185
+ copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
186
+ }
187
+ __syncthreads();
188
+ shared_storage.barrier_dQaccum.wait(0);
189
+ } else {
190
+ G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
191
+ auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
192
+ Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
193
+ Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
194
+ cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
195
+ __syncthreads();
196
+ }
197
+
198
+ // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
199
+
200
+ // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
201
+ R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
202
+ auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
203
+ Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
204
+ TiledMma tiled_mma_dQ;
205
+ Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
206
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
207
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
208
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
209
+ CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
210
+ Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
211
+ cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
212
+ #pragma unroll
213
+ for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
214
+ // Convert tdQrdQ from fp32 to fp16
215
+ Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
216
+ flash::convert_type_out(taccdQrdQaccum, rdQ);
217
+
218
+ // Step 3: Copy dQ from register to smem
219
+ auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
220
+ auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
221
+ Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
222
+ // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
223
+ // if (cute::thread0()) { print(smem_thr_copy_dQ); }
224
+ // if (cute::thread0()) { print(sdQ); }
225
+ Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
226
+ cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
227
+ __syncthreads();
228
+
229
+ // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
230
+ Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
231
+ Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
232
+ GmemTiledCopy gmem_tiled_copy_dQ;
233
+ auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
234
+ Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
235
+ Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
236
+
237
+ Tensor tdQrdQ = make_fragment_like(tdQsdQ);
238
+ Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
239
+ Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
240
+ #pragma unroll
241
+ for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
242
+ // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
243
+ static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
244
+ flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
245
+ gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
246
+
247
+ // Step 5: Copy dQ from register to gmem
248
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
249
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
250
+ gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
251
+ );
252
+ }
253
+
254
+ };
255
+
256
+ } // namespace flash
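Editor's note: stripped of the smem staging, STSM copies, and varlen/bound handling, the postprocess kernel above is an elementwise scale-and-convert from the fp32 dQ accumulator to the output dtype. A bare-bones CUDA sketch of that core operation, assuming a flat fp32 accumulator and an fp16 output; the kernel and buffer names here are illustrative only.

#include <cuda_fp16.h>

// dq[i] = (half)(dq_accum[i] * softmax_scale) over a flat buffer of n elements.
__global__ void convert_dq_accum_naive(const float* __restrict__ dq_accum,
                                       __half* __restrict__ dq,
                                       float softmax_scale, long long n) {
    long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { dq[i] = __float2half(dq_accum[i] * softmax_scale); }
}

// Launch sketch: convert_dq_accum_naive<<<(n + 255) / 256, 256, 0, stream>>>(dq_accum, dq, scale, n);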
flash-attn/flash_bwd_preprocess_kernel.h ADDED
@@ -0,0 +1,252 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/numeric_conversion.h>
13
+
14
+ #include "seqlen.h"
15
+ #include "utils.h"
16
+
17
+ namespace flash {
18
+
19
+ using namespace cute;
20
+
21
+ template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
22
+ class FlashAttnBwdPreprocess {
23
+
24
+ public:
25
+
26
+ // Type Aliases
27
+ using TileShape_MK = TileShape_MK_;
28
+ using ArchTag = ArchTag_;
29
+
30
+ static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
31
+ std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
32
+ std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
33
+
34
+ static constexpr uint32_t MaxThreadsPerBlock = 256;
35
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
+ static constexpr int SharedStorageSize = 0;
37
+
38
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
+ static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
41
+ static constexpr int kHeadDim = get<1>(TileShape_MK{});
42
+ // We want kBlockKGmem to be a power of 2 so that when we do the summing,
43
+ // it's just between threads in the same warp
44
+ static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
45
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
46
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
47
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
48
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
49
+ using GmemTiledCopy = decltype(
50
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
51
+ GmemLayoutAtom{},
52
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
53
+
54
+ static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
55
+ static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
56
+ using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
57
+ using GmemTiledCopyAccum = decltype(
58
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
59
+ GmemLayoutAtomAccum{},
60
+ Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
61
+
62
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
63
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
64
+ using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
65
+ using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
66
+ using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
67
+ using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
68
+
69
+ // Device side arguments
70
+ struct Arguments {
71
+ Element const* ptr_O;
72
+ ShapeO const shape_O;
73
+ StrideO const stride_O;
74
+ Element const* ptr_dO;
75
+ StrideO const stride_dO;
76
+ float* ptr_dPsum;
77
+ ShapedPsum const shape_dPsum;
78
+ StridedPsum const stride_dPsum;
79
+ float const* ptr_LSE;
80
+ StridedPsum const stride_LSE;
81
+ float *ptr_LSE_log2;
82
+ StridedPsum const stride_LSE_log2;
83
+ ElementAccum* ptr_dQaccum;
84
+ ShapedQaccum const shape_dQaccum;
85
+ StridedQaccum const stride_dQaccum;
86
+ int num_batch; // We need this to know the size of dq_semaphore in case of varlen
87
+ int* dq_semaphore;
88
+ int const* cu_seqlens = nullptr;
89
+ int const* seqused = nullptr;
90
+ };
91
+
92
+ // Kernel entry point API
93
+ struct Params {
94
+ Element const* ptr_O;
95
+ ShapeO const shape_O;
96
+ StrideO const stride_O;
97
+ Element const* ptr_dO;
98
+ StrideO const stride_dO;
99
+ float* ptr_dPsum;
100
+ ShapedPsum const shape_dPsum;
101
+ StridedPsum const stride_dPsum;
102
+ float const* ptr_LSE;
103
+ StridedPsum const stride_LSE;
104
+ float* ptr_LSE_log2;
105
+ StridedPsum const stride_LSE_log2;
106
+ ElementAccum* ptr_dQaccum;
107
+ ShapedQaccum const shape_dQaccum;
108
+ StridedQaccum const stride_dQaccum;
109
+ int num_batch;
110
+ int* dq_semaphore;
111
+ int const* cu_seqlens = nullptr;
112
+ int const* seqused = nullptr;
113
+ };
114
+
115
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
116
+ static
117
+ Params
118
+ to_underlying_arguments(Arguments const& args) {
119
+ return {
120
+ args.ptr_O,
121
+ args.shape_O,
122
+ args.stride_O,
123
+ args.ptr_dO,
124
+ args.stride_dO,
125
+ args.ptr_dPsum,
126
+ args.shape_dPsum,
127
+ args.stride_dPsum,
128
+ args.ptr_LSE,
129
+ args.stride_LSE,
130
+ args.ptr_LSE_log2,
131
+ args.stride_LSE_log2,
132
+ args.ptr_dQaccum,
133
+ args.shape_dQaccum,
134
+ args.stride_dQaccum,
135
+ args.num_batch,
136
+ args.dq_semaphore,
137
+ args.cu_seqlens,
138
+ args.seqused
139
+ };
140
+ }
141
+
142
+ CUTLASS_DEVICE
143
+ void
144
+ operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
145
+
146
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
147
+
148
+ int const thread_idx = threadIdx.x;
149
+ int const m_block = blockIdx.x;
150
+ int const bidh = blockIdx.y;
151
+ int const bidb = blockIdx.z;
152
+
153
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
154
+ bool const is_varlen = Varlen && params.cu_seqlens;
155
+ int const seqlen_o = seqlen_info.seqlen;
156
+ if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
157
+
158
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
159
+ Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
160
+ Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
161
+ Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
162
+
163
+ auto shape_LSE = select<0, 2, 3>(params.shape_O);
164
+ Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
165
+ Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
166
+ static_assert(kBlockM <= MaxThreadsPerBlock);
167
+ float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
168
+
169
+ GmemTiledCopy gmem_tiled_copy_O;
170
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
171
+
172
+ Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
173
+ Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
174
+ // Construct identity layout for gO
175
+ Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
176
+ // Repeat the partitioning with identity layouts
177
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
178
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
179
+ #pragma unroll
180
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
181
+
182
+ // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
183
+ Tensor tOrO = make_fragment_like(tOgO);
184
+ Tensor tOrdO = make_fragment_like(tOgdO);
185
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
186
+ gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
187
+ );
188
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
189
+ gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
190
+ );
191
+ // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
192
+
193
+ // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
194
+ Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
195
+ Tensor tOrO_l = make_tensor(tOrO.data(), l);
196
+ Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
197
+ flash::convert_type_out(tOrO_l, o_fp32);
198
+ Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
199
+ Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
200
+ flash::convert_type_out(tOrdO_l, do_fp32);
201
+ // Sum across the last dimension
202
+ Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
203
+ #pragma unroll
204
+ for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
205
+ float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
206
+ #pragma unroll
207
+ for (int ni = 1; ni < size<1>(o_fp32); ni++) {
208
+ dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
209
+ }
210
+ flash::SumOp<float> sum_op;
211
+ dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
212
+ }
213
+
214
+ Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
215
+ Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
216
+ if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
217
+ #pragma unroll
218
+ for (int mi = 0; mi < size(dP_sum); ++mi) {
219
+ int const row = get<0>(tOcO(_0{}, mi, _0{}));
220
+ gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
221
+ }
222
+ }
223
+
224
+ int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
225
+ Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
226
+ Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
227
+ if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
228
+ gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
229
+ }
230
+
231
+ if constexpr (Clear_dQaccum) {
232
+ Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
233
+ Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
234
+ GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
235
+ auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
236
+ Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
237
+ Tensor zero = make_fragment_like(tdQgdQaccum);
238
+ clear(zero);
239
+ cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
240
+ }
241
+
242
+ if (params.dq_semaphore != nullptr && thread_idx == 0) {
243
+ int const num_batch = params.num_batch;
244
+ int const num_head = get<2>(params.shape_O);
245
+ params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
246
+ }
247
+
248
+ }
249
+
250
+ };
251
+
252
+ } // namespace flash
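Editor's note: per query row, the preprocess kernel above writes dPsum = rowsum(dO * O), rescales the stored LSE to base 2 (so the main kernel can use exp2), and stores 0 for fully masked rows. A single-head CPU sketch of those outputs, with illustrative buffer names and a row-major (M, D) layout:

#include <cmath>
#include <vector>

void bwd_preprocess_reference(const std::vector<float>& O, const std::vector<float>& dO,
                              const std::vector<float>& lse, int M, int D,
                              std::vector<float>& dPsum, std::vector<float>& lse_log2) {
    constexpr float kLog2e = 1.4426950408889634f;  // log2(e)
    dPsum.assign(M, 0.f);
    lse_log2.assign(M, 0.f);
    for (int i = 0; i < M; ++i) {
        float acc = 0.f;
        for (int d = 0; d < D; ++d) { acc += dO[i * D + d] * O[i * D + d]; }
        dPsum[i] = acc;
        // Fully masked rows carry lse == -inf and are stored as 0, matching the kernel.
        lse_log2[i] = (lse[i] == -INFINITY) ? 0.f : lse[i] * kLog2e;
    }
}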
flash-attn/flash_fwd_combine.cu ADDED
@@ -0,0 +1,13 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions into different files to speed up compilation.
3
+
4
+ #include "flash_fwd_combine_launch_template.h"
5
+
6
+ template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
7
+ template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
8
+
9
+ template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
10
+ template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
11
+
12
+ template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
13
+ template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
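Editor's note: these instantiations cover the combine pass that merges split-KV partial results. The underlying reduction is the standard log-sum-exp merge; a CPU sketch for one query row, where O_partial[s] and lse_partial[s] are a hypothetical per-split output row and its log-sum-exp:

#include <algorithm>
#include <cmath>
#include <vector>

void combine_splits_reference(const std::vector<std::vector<float>>& O_partial,  // num_splits x D
                              const std::vector<float>& lse_partial,             // num_splits
                              std::vector<float>& O_out, float& lse_out) {
    int const num_splits = (int)lse_partial.size();
    int const D = (int)O_partial[0].size();
    float lse_max = -INFINITY;
    for (float l : lse_partial) { lse_max = std::max(lse_max, l); }
    O_out.assign(D, 0.f);
    if (lse_max == -INFINITY) { lse_out = -INFINITY; return; }  // every split was empty
    float denom = 0.f;
    for (float l : lse_partial) { denom += std::exp(l - lse_max); }
    lse_out = lse_max + std::log(denom);                        // global log-sum-exp
    for (int s = 0; s < num_splits; ++s) {
        float const w = std::exp(lse_partial[s] - lse_out);     // exp(lse_s - lse)
        for (int d = 0; d < D; ++d) { O_out[d] += w * O_partial[s][d]; }
    }
}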
flash-attn/flash_fwd_combine_kernel.h ADDED
@@ -0,0 +1,482 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/arch/memory.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+ #include <cutlass/numeric_conversion.h>
14
+
15
+ #include "cutlass/arch/grid_dependency_control.h"
16
+
17
+ #include "seqlen.h"
18
+ #include "utils.h"
19
+
20
+ namespace flash {
21
+
22
+ using namespace cute;
23
+
24
+ template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
25
+ bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
26
+ class FlashAttnFwdCombine {
27
+
28
+ public:
29
+
30
+ // Type Aliases
31
+ using TileShape_MK = TileShape_MK_;
32
+ using ArchTag = ArchTag_;
33
+ static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
34
+ static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
35
+ static_assert(AlignmentLSE >= 1);
36
+ static constexpr int kStages = 4;
37
+
38
+ static_assert(ArchTag::kMinComputeCapability >= 75);
39
+ static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
40
+
41
+ static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
42
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
43
+
44
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
45
+ static constexpr int kBlockK = get<1>(TileShape_MK{});
46
+
47
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
48
+ static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
49
+ static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
50
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
51
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
52
+ using GmemCopyAtom = std::conditional_t<
53
+ Has_cp_async,
54
+ cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
55
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
56
+ >;
57
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
58
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
59
+ static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
60
+ using GmemTiledCopyAccum = decltype(
61
+ make_tiled_copy(GmemCopyAtom{},
62
+ GmemLayoutAtom{},
63
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
64
+ using GmemTiledCopy = decltype(
65
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
66
+ GmemLayoutAtom{},
67
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
68
+
69
+ using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
70
+ static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
71
+ static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
72
+ static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
73
+ static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
74
+ static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
75
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
76
+ using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
77
+ Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
78
+ static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
79
+ using GmemCopyAtomLSE = std::conditional_t<
80
+ Has_cp_async,
81
+ cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
82
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
83
+ >;
84
+ using GmemTiledCopyLSE = decltype(
85
+ make_tiled_copy(GmemCopyAtomLSE{},
86
+ GmemLayoutAtomLSE{},
87
+ Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
88
+
89
+ // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
90
+ static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
91
+ // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
92
+ using SmemLSESwizzle = std::conditional_t<
93
+ kBlockMSmem == 8,
94
+ Swizzle<5, 0, 5>,
95
+ std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
96
+ >;
97
+ using SmemLayoutAtomLSE =
98
+ decltype(composition(SmemLSESwizzle{},
99
+ Layout<Shape<Int<8>, Int<kBlockMSmem>>,
100
+ Stride<Int<kBlockMSmem>, _1>>{}));
101
+ using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
102
+
103
+ using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
104
+ Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
105
+
106
+ // We want each column (kMaxSplits) to be processed by threads in the same warp.
107
+ // To reduce the number of shuffles, we want as few threads on the same column as possible.
108
+ // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column
109
+ // have have 64 such quads.
110
+ static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
111
+ static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
112
+ static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
113
+ using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
114
+ using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
115
+
116
+ using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
117
+ using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
118
+ using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
119
+ using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
120
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
121
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
122
+ using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
123
+ using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
124
+
125
+ struct SharedStorage : cute::aligned_struct<128> {
126
+ cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
127
+ cute::array_aligned<int, kBlockM> smem_max_valid_split;
128
+ cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
129
+ };
130
+
131
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
132
+
133
+ // Device side arguments
134
+ struct Arguments {
135
+ ElementPartial const* const ptr_O_partial;
136
+ ShapeOPartial const shape_O_partial;
137
+ StrideOPartial const stride_O_partial;
138
+ float const* const ptr_LSE_partial;
139
+ ShapeLSEPartial const shape_LSE_partial;
140
+ StrideLSEPartial const stride_LSE_partial;
141
+ Element* const ptr_O;
142
+ StrideO const stride_O;
143
+ float* const ptr_LSE;
144
+ StrideLSE const stride_LSE;
145
+ int const* const cu_seqlens = nullptr;
146
+ int const* const seqused = nullptr;
147
+ int const* const num_splits_dynamic_ptr = nullptr;
148
+ int* const semaphore_to_reset = nullptr;
149
+ };
150
+
151
+ // Kernel entry point API
152
+ struct Params {
153
+ ElementPartial const* const ptr_O_partial;
154
+ ShapeOPartial const shape_O_partial;
155
+ StrideOPartial const stride_O_partial;
156
+ float const* const ptr_LSE_partial;
157
+ ShapeLSEPartial const shape_LSE_partial;
158
+ StrideLSEPartial const stride_LSE_partial;
159
+ Element* const ptr_O;
160
+ StrideO const stride_O;
161
+ float* const ptr_LSE;
162
+ StrideLSE const stride_LSE;
163
+ cutlass::FastDivmod seqlen_divmod, head_divmod;
164
+ int const* const cu_seqlens = nullptr;
165
+ int const* const seqused = nullptr;
166
+ int const* const num_splits_dynamic_ptr = nullptr;
167
+ int* const semaphore_to_reset = nullptr;
168
+ };
169
+
170
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
171
+ static
172
+ Params
173
+ to_underlying_arguments(Arguments const& args) {
174
+ assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
175
+ return {
176
+ args.ptr_O_partial,
177
+ args.shape_O_partial,
178
+ args.stride_O_partial,
179
+ args.ptr_LSE_partial,
180
+ args.shape_LSE_partial,
181
+ args.stride_LSE_partial,
182
+ args.ptr_O,
183
+ args.stride_O,
184
+ args.ptr_LSE,
185
+ args.stride_LSE,
186
+ cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
187
+ args.cu_seqlens,
188
+ args.seqused,
189
+ args.num_splits_dynamic_ptr,
190
+ args.semaphore_to_reset
191
+ };
192
+ }
193
+
194
+ CUTLASS_DEVICE
195
+ void
196
+ operator()(Params const& params, char* smem_buf) {
197
+
198
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
199
+ Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
200
+ Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
201
+ Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
202
+
203
+ int const thread_idx = threadIdx.x;
204
+ int const m_block = blockIdx.x;
205
+ int const k_block = blockIdx.y;
206
+ int const batch = blockIdx.z;
207
+ int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
208
+
209
+ if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
210
+ cutlass::arch::wait_on_dependent_grids();
211
+ *params.semaphore_to_reset = 0;
212
+ }
213
+ if (num_splits <= 1) { return; }
214
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
215
+ int const offset = seqlen_info.offset;
216
+ int const seqlen = seqlen_info.seqlen;
217
+ int max_idx = seqlen * get<2>(params.shape_LSE_partial);
218
+ if constexpr (Varlen) {
219
+ if (m_block * kBlockM >= max_idx) { return; }
220
+ }
221
+
222
+ cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
223
+
224
+ // Step 1: load LSE_partial from gmem -> smem
225
+ Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
226
+ select<1, 0, 2, 3>(params.shape_LSE_partial),
227
+ select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
228
+ Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
229
+ GmemTiledCopyLSE gmem_tiled_copy_LSE;
230
+ auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
231
+ Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
232
+
233
+ // Construct identity layout for sLSE
234
+ Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
235
+ // Repeat the partitioning with identity layouts
236
+ Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
237
+
238
+ cutlass::arch::wait_on_dependent_grids();
239
+
240
+ #pragma unroll
241
+ for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
242
+ int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
243
+ int idx = m_block * kBlockM + mi;
244
+ if (idx < max_idx) {
245
+ int m_idx, bidh;
246
+ if constexpr (!Varlen) {
247
+ bidh = params.seqlen_divmod.divmod(m_idx, idx);
248
+ } else {
249
+ bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
250
+ }
251
+ Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
252
+ #pragma unroll
253
+ for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
254
+ int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
255
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
256
+ if (si < num_splits) {
257
+ cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
258
+ } else {
259
+ cute::fill(tLSEsLSE(_, s, m), -INFINITY);
260
+ }
261
+ }
262
+ } else {
263
+ // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
264
+ // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
265
+ }
266
+ }
267
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
268
+
269
+ // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
270
+ // We want these async loads to be in flight as we compute the LSE.
271
+ GmemTiledCopyAccum gmem_tiled_copy_O_partial;
272
+ auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
273
+ // Construct identity layout for gO
274
+ Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
275
+ // Repeat the partitioning with identity layouts
276
+ Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
277
+ Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
278
+ params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
279
+
280
+ // Precompute these values to avoid recomputing them in the loop
281
+ Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
282
+ Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
283
+ Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
284
+ #pragma unroll
285
+ for (int m = 0; m < size<1>(tOcO); ++m) {
286
+ int mi = get<0>(tOcO(_0{}, m, _0{}));
287
+ int idx = m_block * kBlockM + mi;
288
+ if constexpr (!Varlen) {
289
+ tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
290
+ } else {
291
+ tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
292
+ }
293
+ tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
294
+ if (idx >= max_idx) {
295
+ tObidh[m] = -1;
296
+ }
297
+ }
298
+
299
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
300
+ if constexpr (!(Is_even_K)) {
301
+ #pragma unroll
302
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
303
+ }
304
+
305
+ Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
306
+
307
+ auto load_O_partial = [&] (int split, int stage) {
308
+ Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
309
+ #pragma unroll
310
+ for (int m = 0; m < size<1>(tOcO); ++m) {
311
+ if (tObidh(m) >= 0) {
312
+ Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
313
+ Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
314
+ #pragma unroll
315
+ for (int k = 0; k < size<2>(tOcO); ++k) {
316
+ int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
317
+ if (Is_even_K || tOpO(k)) {
318
+ cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
319
+ }
320
+ }
321
+ }
322
+ }
323
+ };
324
+
325
+ for (int s = 0; s < kStages - 1; ++s) {
326
+ if (s < num_splits) { load_O_partial(s, s); }
327
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
328
+ }
329
+
330
+ // Step 3: load and transpose LSE_partial from smem -> rmem
331
+ if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
332
+ __syncthreads();
333
+
334
+ S2RTiledCopyLSE s2r_tiled_copy_LSE;
335
+ auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
336
+ Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
337
+ Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
338
+ cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
339
+
340
+ // Step 4: compute the final LSE along the split dimension
341
+ Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
342
+ Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
343
+ // We compute the max valid split for each row to short-circuit the computation later
344
+ Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
345
+ static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
346
+ #pragma unroll
347
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
348
+ float lse_max = ts2rrLSE(_0{}, _0{}, m);
349
+ #pragma unroll
350
+ for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
351
+ MaxOp<float> max_op;
352
+ lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
353
+ int max_valid_idx = -1;
354
+ #pragma unroll
355
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
356
+ if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
357
+ }
358
+ MaxOp<int> max_int_op;
359
+ max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
360
+ float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
361
+ float lse_sum_cur = 0.f;
362
+ #pragma unroll
363
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
364
+ float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
365
+ lse_sum_cur += scale;
366
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
367
+ // ts2rsLSE(_0{}, m, s) = scale;
368
+ ts2rrLSE(_0{}, s, m) = scale;
369
+ }
370
+ SumOp<float> sum_op;
371
+ lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
372
+ lse_sum(m) = logf(lse_sum_cur) + lse_max;
373
+ float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
374
+ #pragma unroll
375
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
376
+ }
377
+ // Store the scales exp(lse - lse_logsum) back to smem
378
+ cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
379
+
380
+ // Store max_valid_split to smem
381
+ #pragma unroll
382
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
383
+ if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem
384
+ int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
385
+ if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
386
+ }
387
+ }
388
+
389
+ // Step 5: store final LSE back to gmem
390
+ if (k_block == 0) {
391
+ auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
392
+ Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
393
+ #pragma unroll
394
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
395
+ if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem
396
+ int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
397
+ int idx = m_block * kBlockM + mi;
398
+ if (idx < max_idx) {
399
+ int m_idx, bidh;
400
+ if constexpr (!Varlen) {
401
+ bidh = params.seqlen_divmod.divmod(m_idx, idx);
402
+ } else {
403
+ bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
404
+ }
405
+ // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
406
+ mLSE(m_idx, bidh) = lse_sum(m);
407
+ }
408
+ }
409
+ }
410
+ }
411
+
412
+ // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
413
+ __syncthreads();
414
+ int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
415
+ #pragma unroll
416
+ for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
417
+ Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
418
+ Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
419
+ Tensor tOrO = make_fragment_like<float>(tOrOpartial);
420
+ clear(tOrO);
421
+ int stage_load = kStages - 1, stage_compute = 0;
422
+ #pragma unroll 4 // Already tuned for speed
423
+ for (int s = 0; s <= thr_max_valid_split; ++s) {
424
+ Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
425
+ #pragma unroll
426
+ for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
427
+
428
+ if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
429
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
430
+ stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
431
+ if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
432
+ // We don't need __syncthreads() because each thread is just reading its own data from smem
433
+ cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
434
+ tOsOpartial(_, _, _, stage_compute), tOrOpartial);
435
+ stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
436
+
437
+ #pragma unroll
438
+ for (int m = 0; m < size<1>(tOrOpartial); ++m) {
439
+ if (tObidh(m) >= 0 && scale(m) > 0.f) {
440
+ #pragma unroll
441
+ for (int k = 0; k < size<2>(tOrOpartial); ++k) {
442
+ if (Is_even_K || tOpO(k)) {
443
+ Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
444
+ flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
445
+ #pragma unroll
446
+ for (int i = 0; i < size<0>(tOrOpartial); ++i) {
447
+ tOrO(i, m, k) += scale(m) * rOpartial[i];
448
+ }
449
+ }
450
+ }
451
+ }
452
+ }
453
+ }
454
+
455
+ // Step 7: Write the final O to gmem
456
+ Tensor rO = make_tensor_like<Element>(tOrO);
457
+ flash::convert_type_out(tOrO, rO);
458
+ auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
459
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
460
+ shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
461
+ Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
462
+ GmemTiledCopy gmem_tiled_copy_O;
463
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
464
+
465
+ #pragma unroll
466
+ for (int m = 0; m < size<1>(tOcO); ++m) {
467
+ if (tObidh(m) >= 0) {
468
+ #pragma unroll
469
+ for (int k = 0; k < size<2>(tOcO); ++k) {
470
+ int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
471
+ if (Is_even_K || tOpO(k)) {
472
+ cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
473
+ }
474
+ }
475
+ }
476
+ }
477
+
478
+ }
479
+
480
+ };
481
+
482
+ } // namespace flash
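
The combine kernel above reduces the split-KV partial results with a numerically stable log-sum-exp: for each output row it takes the max over the partial LSEs, turns each partial LSE into a weight exp(lse_s - lse_max), normalizes by the sum of the weights, accumulates the weighted partial outputs (Step 6), and writes log(sum_s exp(lse_s - lse_max)) + lse_max as the final LSE (Steps 4-5). A minimal host-side sketch of that math; the std::vector types and the function name are illustrative assumptions, not the kernel's API:

```cpp
// Host-side reference of the split-combine math in Steps 4-6 above: a numerically
// stable log-sum-exp over the partial LSEs, then a weighted sum of the partial
// outputs. Sketch for illustration only.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// lse_partial: per-split LSE values for one output row.
// o_partial:   per-split partial outputs for that row, each of length d.
// Writes the combined LSE to *lse_out and returns the combined output row.
std::vector<float> combine_splits_reference(std::vector<float> const& lse_partial,
                                            std::vector<std::vector<float>> const& o_partial,
                                            float* lse_out) {
    int const num_splits = int(lse_partial.size());
    int const d = int(o_partial.at(0).size());
    float const neg_inf = -std::numeric_limits<float>::infinity();
    float lse_max = neg_inf;
    for (float lse : lse_partial) { lse_max = std::max(lse_max, lse); }
    float const lse_max_cur = lse_max == neg_inf ? 0.f : lse_max;  // in case all splits are empty
    std::vector<float> scale(num_splits);
    float sum = 0.f;
    for (int s = 0; s < num_splits; ++s) {
        scale[s] = std::exp(lse_partial[s] - lse_max_cur);
        sum += scale[s];
    }
    *lse_out = std::log(sum) + lse_max;                            // final LSE
    float const inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
    std::vector<float> o(d, 0.f);
    for (int s = 0; s < num_splits; ++s) {
        for (int i = 0; i < d; ++i) { o[i] += scale[s] * inv_sum * o_partial[s][i]; }
    }
    return o;
}
```

Rows where every split is empty (all partial LSEs equal to -INFINITY) stay well defined: the guarded lse_max_cur and inv_sum keep the output at zero and the final LSE at -INFINITY, mirroring the inv_sum guard in Step 4.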
flash-attn/flash_fwd_combine_launch_template.h ADDED
@@ -0,0 +1,80 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/cutlass.h"
10
+ #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80
11
+ #include "cutlass/device_kernel.h" // For device_kernel
12
+ #include "cutlass/kernel_launch.h" // For kernel_launch
13
+
14
+ #include "static_switch.h"
15
+ #include "flash.h"
16
+ #include "flash_fwd_combine_kernel.h"
17
+
18
+ using namespace cute;
19
+
20
+ template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
21
+ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
22
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
23
+ using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
24
+ using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
25
+ IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
26
+
27
+ typename CombineKernel::Arguments args {
28
+ static_cast<ElementPartial const*>(params.oaccum_ptr),
29
+ {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial
30
+ {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial
31
+ static_cast<float*>(params.softmax_lseaccum_ptr),
32
+ {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial
33
+ {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial
34
+ static_cast<Element*>(params.o_ptr),
35
+ {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
36
+ static_cast<float*>(params.softmax_lse_ptr),
37
+ {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
38
+ params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
39
+ };
40
+
41
+ typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
42
+ int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
43
+ int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
44
+ dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
45
+ auto kernel = cutlass::device_kernel<CombineKernel>;
46
+ int smem_size = CombineKernel::SharedStorageSize;
47
+ if (smem_size >= 48 * 1024) {
48
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
49
+ }
50
+ // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
51
+ cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
52
+ CHECK_CUDA_KERNEL_LAUNCH();
53
+ }
54
+
55
+ template<typename T, typename Tpartial, int kBlockK>
56
+ void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
57
+ // We want kBlockM to be as small as possible to maximize parallelism.
58
+ // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
59
+ static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
60
+ static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
61
+ ARCH_SWITCH(params.arch, Arch, [&] {
62
+ BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
63
+ if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
64
+ if (params.num_splits <= 16) {
65
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
66
+ return;
67
+ }
68
+ }
69
+ if (params.num_splits <= 32) {
70
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
71
+ } else if (params.num_splits <= 64) {
72
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
73
+ } else if (params.num_splits <= 128) {
74
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
75
+ } else {
76
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
77
+ }
78
+ });
79
+ });
80
+ }
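
run_mha_fwd_combine_ above picks kBlockM so that the combine grid has as many m-blocks as possible for a given head dimension, then selects the smallest kLogMaxSplits bucket that can hold params.num_splits (the kMaxSplits = 16 case is only legal when kBlockM >= 16, as its comment notes). A hypothetical host-side helper that mirrors that selection, purely for illustration; the real choice happens at compile time through the ARCH_SWITCH/BOOL_SWITCH dispatch:

```cpp
// Hypothetical mirror of the tile-size selection in run_mha_fwd_combine_ above:
// kBlockM from kBlockK, and the kLogMaxSplits bucket from num_splits.
#include <cassert>
#include <cstdio>
#include <utility>

std::pair<int, int> combine_tile_config(int kBlockK, int num_splits) {
    assert(kBlockK % 32 == 0);
    int const kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    int kLogMaxSplits;
    if (kBlockM >= 16 && num_splits <= 16) { kLogMaxSplits = 4; }   // kMaxSplits = 16
    else if (num_splits <= 32)             { kLogMaxSplits = 5; }
    else if (num_splits <= 64)             { kLogMaxSplits = 6; }
    else if (num_splits <= 128)            { kLogMaxSplits = 7; }
    else                                   { kLogMaxSplits = 8; }   // kMaxSplits = 256
    return {kBlockM, kLogMaxSplits};
}

int main() {
    // E.g. hdim 128 (kBlockK = 128) with 48 splits gives kBlockM = 8 and kMaxSplits = 64.
    auto [kBlockM, kLogMaxSplits] = combine_tile_config(/*kBlockK=*/128, /*num_splits=*/48);
    std::printf("kBlockM = %d, kMaxSplits = %d\n", kBlockM, 1 << kLogMaxSplits);
    return 0;
}
```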
flash-attn/flash_fwd_kernel_sm80.h ADDED
@@ -0,0 +1,215 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/kernel_hardware_info.h>
13
+
14
+ #include "seqlen.h"
15
+ #include "utils.h"
16
+ #include "softmax.h"
17
+
18
+ namespace flash {
19
+
20
+ using namespace cute;
21
+
22
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
23
+ class FlashAttnFwdSm80 {
24
+
25
+ public:
26
+
27
+ // Type Aliases
28
+ using CollectiveMainloop = CollectiveMainloop_;
29
+ using CollectiveEpilogue = CollectiveEpilogue_;
30
+ static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
31
+ static constexpr bool Is_local = CollectiveMainloop::Is_local;
32
+ static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
33
+ static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
34
+ static constexpr bool Varlen = CollectiveMainloop::Varlen;
35
+ static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
36
+ static constexpr bool Split = CollectiveMainloop::Split;
37
+ static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
38
+ static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
39
+ static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
40
+ static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
41
+ static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
42
+ using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
43
+
44
+ // Mainloop derived types
45
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
46
+ using TiledMma = typename CollectiveMainloop::TiledMma;
47
+ using ArchTag = typename CollectiveMainloop::ArchTag;
48
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
49
+ using MainloopParams = typename CollectiveMainloop::Params;
50
+
51
+ // Epilogue derived types
52
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
53
+ using EpilogueParams = typename CollectiveEpilogue::Params;
54
+
55
+ static_assert(ArchTag::kMinComputeCapability >= 80);
56
+
57
+ using TileScheduler = TileScheduler_;
58
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
59
+ using TileSchedulerParams = typename TileScheduler::Params;
60
+
61
+ static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
62
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
63
+ static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
64
+
65
+ // Kernel level shared memory storage
66
+ // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k (not smem_q)
67
+ // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
68
+ static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
69
+ - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
70
+ - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
71
+ static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
72
+ struct SharedStorage {
73
+ struct TensorStorage : cute::aligned_struct<128> {
74
+ union {
75
+ struct {
76
+ cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
77
+ typename CollectiveMainloop::TensorStorage mainloop;
78
+ };
79
+ // We want smem_o to line up with the start of smem_v
80
+ typename CollectiveEpilogue::TensorStorage epilogue;
81
+ };
82
+ } tensors;
83
+
84
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
+
86
+ };
87
+
88
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
89
+
90
+ // Device side arguments
91
+ struct Arguments {
92
+ MainloopArguments mainloop{};
93
+ EpilogueArguments epilogue{};
94
+ cutlass::KernelHardwareInfo hw_info{};
95
+ TileSchedulerArguments scheduler{};
96
+ };
97
+
98
+ // Kernel entry point API
99
+ struct Params {
100
+ MainloopParams mainloop{};
101
+ EpilogueParams epilogue{};
102
+ cutlass::KernelHardwareInfo hw_info{};
103
+ TileSchedulerParams scheduler{};
104
+ };
105
+
106
+ //
107
+ // Methods
108
+ //
109
+
110
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
111
+ static
112
+ Params
113
+ to_underlying_arguments(Arguments const& args) {
114
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
115
+
116
+ // Get SM count if needed, otherwise use user supplied SM count
117
+ int sm_count = args.hw_info.sm_count;
118
+ if (sm_count <= 0) {
119
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
120
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
121
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
122
+ }
123
+
124
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
125
+
126
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
127
+ return {
128
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
129
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
130
+ hw_info,
131
+ TileScheduler::to_underlying_arguments(args.scheduler)
132
+ };
133
+ }
134
+
135
+ // Computes the kernel launch grid shape based on runtime parameters
136
+ static dim3
137
+ get_grid_shape(Params const& params) {
138
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
139
+ }
140
+
141
+ static dim3
142
+ get_block_shape() {
143
+ return dim3(MaxThreadsPerBlock, 1, 1);
144
+ }
145
+
146
+ CUTLASS_DEVICE
147
+ void
148
+ operator()(Params const& params, char* smem_buf) {
149
+
150
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
151
+
152
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
153
+
154
+ CollectiveMainloop mainloop;
155
+ CollectiveEpilogue epilogue;
156
+
157
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
158
+ // Initialize matmul objects.
159
+ TiledMma tiled_mma;
160
+
161
+ scheduler.init_consumer();
162
+
163
+ int warp_idx = cutlass::canonical_warp_idx_sync();
164
+ CUTLASS_PRAGMA_NO_UNROLL
165
+ for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
166
+ work_tile_info.is_valid(params.scheduler);
167
+ work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
168
+ // Attention output (GEMM-II) accumulator.
169
+ Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
170
+ float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
171
+ // If there's tanh softcap, the scaling will be done before tanh.
172
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
173
+ int const bidb = get<2>(block_coord);
174
+ if constexpr (Is_FP8 && !Has_softcap) {
175
+ int const bidh = get<1>(block_coord);
176
+ int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
177
+ float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
178
+ float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
179
+ softmax_scale_log2 *= q_descale * k_descale;
180
+ }
181
+ flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
182
+
183
+ SeqlenInfo_t seqlen_info{
184
+ bidb,
185
+ get<0>(params.mainloop.shape_Q),
186
+ !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
187
+ get<0>(params.mainloop.shape_K_new),
188
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
189
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
190
+ params.mainloop.seqlens_rotary
191
+ };
192
+ if constexpr (AppendKV) {
193
+ bool tile_new_valid = mainloop.store_kv_new(
194
+ params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
195
+ if (tile_new_valid) { __syncthreads(); }
196
+ }
197
+ bool tile_valid = mainloop.mma(
198
+ params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
199
+ shared_storage);
200
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
201
+ if (tile_valid) {
202
+ // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
203
+ epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
204
+ threadIdx.x, block_coord);
205
+ } else {
206
+ // Write 0 to gO and -inf to gLSE.
207
+ epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
208
+ }
209
+ }
210
+
211
+ }
212
+
213
+ };
214
+
215
+ } // namespace flash
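
FlashAttnFwdSm80 above is a persistent kernel: get_grid_shape hands the tile scheduler an upper bound of sm_count * MinBlocksPerMultiprocessor blocks, and each resident block keeps pulling work tiles until the scheduler runs out. A small sketch of that sizing, where the SM count is an assumed example value rather than anything queried from this code:

```cpp
// Persistent-grid sizing sketch for the SM80 path: the block count depends on the
// device, not on the problem size; each resident block then loops over work tiles
// handed out by the TileScheduler. The SM count below is an illustrative value.
#include <cstdio>

int main() {
    int sm_count = 108;                                     // e.g. an A100-class device; the real code uses KernelHardwareInfo
    int num_threads = 128;                                  // CUTE_STATIC_V(size(TiledMma{})) for a 128-thread tiled MMA
    int min_blocks_per_sm = (num_threads == 128) ? 2 : 1;   // MinBlocksPerMultiprocessor above
    int max_blocks = sm_count * min_blocks_per_sm;          // upper bound passed to TileScheduler::get_grid_shape
    std::printf("persistent grid: up to %d blocks of %d threads\n", max_blocks, num_threads);
    return 0;
}
```

The factor of two for 128-thread configurations simply requests two resident blocks per SM for extra latency hiding, provided the register and shared-memory budgets allow it.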
flash-attn/flash_fwd_kernel_sm90.h ADDED
@@ -0,0 +1,468 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/arch/reg_reconfig.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+ #include <cutlass/numeric_conversion.h>
14
+ #include <cutlass/kernel_hardware_info.h>
15
+ #include "cutlass/pipeline/pipeline.hpp"
16
+
17
+ #include "cutlass/arch/grid_dependency_control.h"
18
+
19
+ #include "seqlen.h"
20
+ #include "utils.h"
21
+ #include "softmax.h"
22
+
23
+ namespace flash {
24
+
25
+ using namespace cute;
26
+
27
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
28
+ class FlashAttnFwdSm90 {
29
+
30
+ public:
31
+
32
+ // Type Aliases
33
+ using CollectiveMainloop = CollectiveMainloop_;
34
+ using CollectiveEpilogue = CollectiveEpilogue_;
35
+ static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
36
+ static constexpr bool Is_local = CollectiveMainloop::Is_local;
37
+ static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
38
+ static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
39
+ static constexpr bool Varlen = CollectiveMainloop::Varlen;
40
+ static constexpr bool Split = CollectiveMainloop::Split;
41
+ static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
42
+ static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
43
+ static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
44
+ static constexpr bool HasQv = CollectiveMainloop::HasQv;
45
+ static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
46
+ static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
47
+ static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
48
+ static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
49
+ static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
50
+ static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
51
+ static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
52
+ static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
53
+ using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
54
+
55
+ using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux;
56
+
57
+ // Mainloop derived types
58
+ using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
59
+ using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
60
+ using ArchTag = typename CollectiveMainloop::ArchTag;
61
+ using ClusterShape = typename CollectiveMainloop::ClusterShape;
62
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
63
+ using MainloopParams = typename CollectiveMainloop::Params;
64
+ using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
65
+
66
+ // Epilogue derived types
67
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
68
+ using EpilogueParams = typename CollectiveEpilogue::Params;
69
+
70
+ static_assert(ArchTag::kMinComputeCapability >= 90);
71
+
72
+ using TileScheduler = TileScheduler_;
73
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
74
+ using TileSchedulerParams = typename TileScheduler::Params;
75
+
76
+ static constexpr uint32_t NumLoadWarpGroups = 1;
77
+ static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
78
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
79
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
80
+ static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
81
+
82
+ /// Register requirement for Load and Math WGs
83
+ // If we use cp.async to load K and V, we need more registers for the producer WG.
84
+ static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
85
+ static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
86
+ // If you want to print from the producer warp, you'd need to increase the number of registers
87
+ // Otherwise you'll get a CUDA error.
88
+ // static constexpr uint32_t LoadRegisterRequirement = 40;
89
+ // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
90
+
91
+ // Kernel level shared memory storage
92
+ // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
93
+ // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
94
+ static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
95
+ static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
96
+ struct SharedStorage {
97
+ struct TensorStorage : cute::aligned_struct<128, _1> {
98
+ union {
99
+ struct {
100
+ cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
101
+ typename CollectiveMainloop::TensorStorage mainloop;
102
+ };
103
+ // We want smem_o to line up with the start of smem_v
104
+ typename CollectiveEpilogue::TensorStorage epilogue;
105
+ };
106
+ } tensors;
107
+ struct PipelineStorage : cute::aligned_struct<16, _1> {
108
+ alignas(16) BarrierQ barrier_Q;
109
+ alignas(16) BarrierQ barrier_Qv;
110
+ alignas(16) cutlass::arch::ClusterBarrier barrier_O;
111
+ alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
112
+ alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
113
+ alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
114
+ alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
115
+ alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
116
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
117
+ } pipelines;
118
+
119
+ };
120
+
121
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
122
+
123
+ // Device side arguments
124
+ struct Arguments {
125
+ MainloopArguments mainloop{};
126
+ EpilogueArguments epilogue{};
127
+ cutlass::KernelHardwareInfo hw_info{};
128
+ TileSchedulerArguments scheduler{};
129
+ };
130
+
131
+ // Kernel entry point API
132
+ struct Params {
133
+ MainloopParams mainloop{};
134
+ EpilogueParams epilogue{};
135
+ cutlass::KernelHardwareInfo hw_info{};
136
+ TileSchedulerParams scheduler{};
137
+ };
138
+
139
+ //
140
+ // Methods
141
+ //
142
+
143
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
144
+ static
145
+ Params
146
+ to_underlying_arguments(Arguments const& args) {
147
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
148
+
149
+ // Get SM count if needed, otherwise use user supplied SM count
150
+ int sm_count = args.hw_info.sm_count;
151
+ if (sm_count <= 0) {
152
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
153
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
154
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
155
+ }
156
+
157
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
158
+
159
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
160
+ return {
161
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
162
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
163
+ hw_info,
164
+ TileScheduler::to_underlying_arguments(args.scheduler)
165
+ };
166
+ }
167
+
168
+ // Computes the kernel launch grid shape based on runtime parameters
169
+ static dim3
170
+ get_grid_shape(Params const& params) {
171
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
172
+ }
173
+
174
+ static dim3
175
+ get_block_shape() {
176
+ return dim3(MaxThreadsPerBlock, 1, 1);
177
+ }
178
+
179
+ CUTLASS_DEVICE
180
+ void
181
+ operator()(Params const& params, char* smem_buf) {
182
+
183
+ static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
184
+ static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
185
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
186
+
187
+ using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
188
+ using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
189
+ using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
190
+ using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
191
+ using PipelineState = typename CollectiveMainloop::PipelineState;
192
+ using PipelineParamsK = typename MainloopPipelineK::Params;
193
+ using PipelineParamsV = typename MainloopPipelineV::Params;
194
+ using PipelineParamsVt = typename MainloopPipelineVt::Params;
195
+ using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
196
+
197
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
198
+
199
+ int const lane_predicate = cute::elect_one_sync();
200
+ int const warp_idx = cutlass::canonical_warp_idx_sync();
201
+
202
+ // Issue Tma Descriptor Prefetch from a single thread
203
+ if (warp_idx == 0 && lane_predicate) {
204
+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
205
+ CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
206
+ }
207
+
208
+ // Obtain warp index
209
+ int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
210
+ int warp_group_idx = cutlass::canonical_warp_group_idx();
211
+
212
+ if (warp_idx == 0 && lane_predicate) {
213
+ shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
214
+ if constexpr (HasQv) {
215
+ shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
216
+ }
217
+ shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
218
+ }
219
+
220
+ // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
221
+ PipelineParamsK pipeline_params_k;
222
+ pipeline_params_k.role = warp_group_idx == 0
223
+ ? MainloopPipelineK::ThreadCategory::Producer
224
+ : MainloopPipelineK::ThreadCategory::Consumer;
225
+ if constexpr (Use_TMA_KV) {
226
+ pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
227
+ pipeline_params_k.is_leader = warp_group_thread_idx == 0;
228
+ pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
229
+ } else {
230
+ pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
231
+ pipeline_params_k.producer_arv_count = NumProducerThreads;
232
+ }
233
+
234
+ static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
235
+ PipelineParamsVt pipeline_params_vt = pipeline_params_k;
236
+ if constexpr (Use_TMA_KV && !SameHeadDim) {
237
+ pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
238
+ if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
239
+ } else {
240
+ if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
241
+ }
242
+
243
+ MainloopPipelineK pipeline_k = [&] {
244
+ if constexpr (Use_TMA_KV) {
245
+ return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
246
+ } else {
247
+ return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
248
+ }
249
+ }();
250
+ // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
251
+ MainloopPipelineV pipeline_v = [&] {
252
+ if constexpr (!Transpose_V) {
253
+ static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
254
+ if constexpr (Use_TMA_KV) {
255
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
256
+ } else {
257
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
258
+ }
259
+ } else {
260
+ PipelineParamsV pipeline_params_v;
261
+ pipeline_params_v.role = warp_group_idx == 0
262
+ ? MainloopPipelineV::ThreadCategory::Producer
263
+ : MainloopPipelineV::ThreadCategory::Consumer;
264
+ pipeline_params_v.producer_arv_count = NumProducerThreads;
265
+ pipeline_params_v.consumer_arv_count = NumMmaThreads;
266
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
267
+ }
268
+ }();
269
+ // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
270
+ // the producer WG will read from pipeline_vt and write to pipeline_v.
271
+ // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
272
+ // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
273
+ // However, the thread role isn't used in the pipeline implementation.
274
+ MainloopPipelineVt pipeline_vt = [&] {
275
+ if constexpr (Use_TMA_KV) {
276
+ pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
277
+ return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
278
+ } else {
279
+ pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
280
+ return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
281
+ }
282
+ }();
283
+
284
+ PipelineParamsKVNew pipeline_params_kv_new;
285
+ pipeline_params_kv_new.role = warp_group_idx == 0
286
+ ? MainloopPipelineKVNew::ThreadCategory::Producer
287
+ : MainloopPipelineKVNew::ThreadCategory::Consumer;
288
+ pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
289
+ pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
290
+ pipeline_params_kv_new.num_consumers = NumMmaThreads;
291
+ auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
292
+ if constexpr (!SameHeadDim) {
293
+ pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
294
+ }
295
+ auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
296
+
297
+ CollectiveMainloop mainloop;
298
+ CollectiveEpilogue epilogue;
299
+
300
+ const int num_heads = get<2>(params.mainloop.shape_Q);
301
+ Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads));
302
+ Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});
303
+
304
+ if(params.mainloop.ptr_S_aux && threadIdx.x < num_heads) {
305
+ sS_aux(threadIdx.x) = gS_aux(threadIdx.x);
306
+ }
307
+
308
+ // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
309
+ if constexpr (size(ClusterShape{}) > 1) {
310
+ cute::cluster_arrive_relaxed();
311
+ cute::cluster_wait();
312
+ } else {
313
+ __syncthreads();
314
+ }
315
+
316
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
317
+
318
+ if (warp_group_idx == 0) { // Producer
319
+ cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
320
+
321
+ // The pipelines for AppendKV and main attention are different, since e.g. main attention
322
+ // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
323
+ // KV_new. Since the pipeline states are different, we have to manually sync to make
324
+ // sure the two pipelines don't race when accessing smem_k and smem_v.
325
+ PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
326
+ PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
327
+ int work_idx = 0;
328
+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
329
+ static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
330
+ if constexpr (SingleProducerWarp) {
331
+ if (warp_idx_in_warpgroup != 0) { return; }
332
+ }
333
+ if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
334
+
335
+ cutlass::arch::wait_on_dependent_grids();
336
+
337
+ // Load Q, K, V
338
+ for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
339
+ work_tile_info.is_valid(params.scheduler);
340
+ work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
341
+
342
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
343
+ SeqlenInfo_t seqlen_info{
344
+ get<2>(block_coord) /*bidb*/,
345
+ get<0>(params.mainloop.shape_Q),
346
+ !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
347
+ get<0>(params.mainloop.shape_K_new),
348
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
349
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
350
+ params.mainloop.seqlens_rotary
351
+ };
352
+ if constexpr (AppendKV) {
353
+ bool tile_new_valid = mainloop.load_kv_new(
354
+ params.mainloop, pipeline_k_new, pipeline_v_new,
355
+ smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
356
+ if (tile_new_valid) {
357
+ // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
358
+ cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
359
+ // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
360
+ }
361
+ }
362
+ auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
363
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
364
+ };
365
+ // pipeline_vt won't be used if we don't need to transpose V.
366
+ mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
367
+ shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
368
+ }
369
+ mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
370
+ } else { // Consumer
371
+ cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
372
+
373
+ // Initialize matmul objects.
374
+ TiledMmaPV tiled_mma_pv;
375
+
376
+ PipelineState smem_pipe_read;
377
+ PipelineState smem_pipe_read_new;
378
+ // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
379
+ // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
380
+
381
+ scheduler.init_consumer();
382
+ mainloop.mma_init();
383
+
384
+ int work_idx = 0;
385
+ CUTLASS_PRAGMA_NO_UNROLL
386
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
387
+ work_tile_info.is_valid(params.scheduler);
388
+ // get_next_work will be called before the epilogue
389
+ ) {
390
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
391
+ int const bidb = get<2>(block_coord);
392
+ SeqlenInfo_t seqlen_info{
393
+ bidb,
394
+ get<0>(params.mainloop.shape_Q),
395
+ !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
396
+ get<0>(params.mainloop.shape_K_new),
397
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
398
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
399
+ params.mainloop.seqlens_rotary
400
+ };
401
+ if constexpr (AppendKV) {
402
+ bool tile_new_valid = mainloop.store_kv_new(
403
+ params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
404
+ threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
405
+ if (tile_new_valid) {
406
+ // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
407
+ // We need this sync so that the gmem write from the consumers is visible to the producer
408
+ // that might do TMA read after that.
409
+ asm volatile ("fence.proxy.async.global;");
410
+ cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
411
+ // arrive is enough, we don't need sync. The producer will sync, which means
412
+ // after that sync we're guaranteed that the AppendKV pipeline has finished
413
+ // loading and consuming smem_k and smem_v.
414
+ // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
415
+ }
416
+ }
417
+ // If there's tanh softcap, the scaling will be done before tanh.
418
+ float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
419
+ if constexpr (Is_FP8 && !Has_softcap) {
420
+ int const bidh = get<1>(block_coord);
421
+ int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
422
+ float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
423
+ float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
424
+ softmax_scale_log2 *= q_descale * k_descale;
425
+ }
426
+ flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
427
+ // Attention output (GEMM-II) accumulator.
428
+ Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
429
+ bool tile_valid;
430
+ if constexpr (!LargeHeadDimV) {
431
+ tile_valid = mainloop.mma(
432
+ params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
433
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
434
+ } else { // mma_pv might not compile if !LargeHeadDimV
435
+ if (warp_group_idx == 1) {
436
+ tile_valid = mainloop.mma(
437
+ params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
438
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
439
+ } else {
440
+ tile_valid = mainloop.mma_pv(
441
+ params.mainloop, pipeline_v, smem_pipe_read,
442
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
443
+ }
444
+ }
445
+ // Do this here before the epilogue so that the next tile is ready to go.
446
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
447
+ if constexpr (Split && Varlen) {
448
+ if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
449
+ cutlass::arch::launch_dependent_grids();
450
+ }
451
+ }
452
+ if (tile_valid) {
453
+ // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
454
+ epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
455
+ threadIdx.x - MmaThreadOffset, block_coord);
456
+ } else {
457
+ // Write 0 to gO and -inf to gLSE.
458
+ epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
459
+ }
460
+ }
461
+ epilogue.store_tail();
462
+ }
463
+
464
+ }
465
+
466
+ };
467
+
468
+ } // namespace flash
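
The consumer loop above follows the persistent tile-scheduler pattern: take an initial work tile, run the mainloop MMA for it, request the next tile before the epilogue so scheduling overlaps the store, then either write the output or zero-fill when the tile is empty. Below is a minimal host-side C++ sketch of that control flow with toy stand-ins for the scheduler, mainloop, and epilogue (none of these names are the real CUTLASS/FA3 types).

    #include <cstdio>
    #include <optional>

    struct WorkTile { int m_block; };

    struct ToyScheduler {                       // stand-in for the tile scheduler
        int num_tiles = 3, next = 0;
        std::optional<WorkTile> get_initial_work() { return get_next_work(); }
        std::optional<WorkTile> get_next_work() {
            return next < num_tiles ? std::optional<WorkTile>{WorkTile{next++}} : std::nullopt;
        }
    };

    int main() {
        ToyScheduler scheduler;
        for (auto tile = scheduler.get_initial_work(); tile; ) {
            bool tile_valid = (tile->m_block % 2 == 0);   // mainloop.mma(...) in the real kernel
            auto next = scheduler.get_next_work();        // fetched before the epilogue, as above
            if (tile_valid) { std::printf("store O/LSE for m_block=%d\n", tile->m_block); }
            else            { std::printf("store_zero for m_block=%d\n", tile->m_block); }
            tile = next;
        }
        return 0;
    }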
flash-attn/flash_fwd_launch_template.h ADDED
@@ -0,0 +1,231 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/cutlass.h"
10
+ #include "cutlass/device_kernel.h" // For device_kernel
11
+ #include <cutlass/kernel_hardware_info.h>
12
+ #include "cutlass/cluster_launch.hpp"
13
+ #include "cutlass/kernel_launch.h"
14
+
15
+ #include "static_switch.h"
16
+ #include "flash.h"
17
+ #include "tile_size.h"
18
+ #include "tile_scheduler.hpp"
19
+ #include "flash_fwd_kernel_sm90.h"
20
+ #include "flash_fwd_kernel_sm80.h"
21
+ #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
22
+ #include "mainloop_fwd_sm80.hpp"
23
+ #include "epilogue_fwd.hpp"
24
+ #include "heuristics.h"
25
+
26
+ using namespace cute;
27
+
28
+ template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
29
+ bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
30
+ bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg>
31
+ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
32
+ static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
33
+ static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
34
+ static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
35
+ static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
36
+ static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
37
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
38
+ using ElementS = cutlass::bfloat16_t;
39
+
40
+ // Can't use structured binding since it's not compatible with constexpr
41
+ static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
42
+ static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
43
+ static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
44
+ static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
45
+ static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
46
+ static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
47
+ static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
48
+ static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
49
+ static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
50
+
51
+ using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
52
+ using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
53
+ using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
54
+ using CollectiveMainloop = std::conditional_t<
55
+ Arch >= 90,
56
+ flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor, ElementS>,
57
+ flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split, ElementS>
58
+ >;
59
+ using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;
60
+
61
+ static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
62
+ using SchedulerPersistent = std::conditional_t<Varlen,
63
+ flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
64
+ std::conditional_t<!Is_causal && !Is_local,
65
+ flash::StaticPersistentTileScheduler<Split>,
66
+ flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
67
+ >
68
+ >;
69
+ using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
70
+ // If Split then we probably don't have enough work for PersistentScheduler to be useful.
71
+ // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
72
+ // since we'll avoid launching a bunch of thread blocks that immediately exit.
73
+ // On Sm80, noncausal persistent seems a bit slower.
74
+ static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
75
+ using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
76
+ using AttnKernel = std::conditional_t<
77
+ Arch >= 90,
78
+ flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
79
+ flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
80
+ >;
81
+
82
+ bool const is_varlen_q = params.cu_seqlens_q;
83
+ bool const is_varlen_k = params.cu_seqlens_k;
84
+ bool const is_varlen_k_new = params.cu_seqlens_knew;
85
+ int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
86
+ int batch_q = !is_varlen_q ? params.b : 1;
87
+ int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
88
+ typename CollectiveMainloop::StrideV v_strides =
89
+ cute::conditional_return<!V_colmajor>(
90
+ make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
91
+ make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
92
+ typename CollectiveMainloop::Arguments mainloop_args {
93
+ static_cast<Element const*>(params.q_ptr),
94
+ {seqlen_q, params.d, params.h, batch_q}, // shape_Q
95
+ {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
96
+ static_cast<Element*>(params.k_ptr),
97
+ {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
98
+ params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K
99
+ {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
100
+ static_cast<Element*>(params.v_ptr),
101
+ params.dv, // headdim_v
102
+ v_strides, // stride_V
103
+ static_cast<Element const*>(params.knew_ptr),
104
+ {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new
105
+ {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new
106
+ static_cast<Element const*>(params.vnew_ptr),
107
+ {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
108
+ static_cast<Element const*>(params.qv_ptr),
109
+ {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv
110
+ static_cast<Element const*>(params.rotary_cos_ptr),
111
+ {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter
112
+ {params.rotary_dim / 2, _1{}}, // stride_rotary_cos
113
+ static_cast<Element const*>(params.rotary_sin_ptr),
114
+ {params.rotary_dim / 2, _1{}}, // stride_rotary_sin
115
+ params.is_rotary_interleaved,
116
+ params.page_table,
117
+ // if page_size is not set, avoid dividing by zero
118
+ {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
119
+ {params.page_table_batch_stride, _1{}}, // stride_page_table
120
+ params.scale_softmax,
121
+ params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
122
+ {params.q_descale_batch_stride, params.q_descale_head_stride},
123
+ {params.k_descale_batch_stride, params.k_descale_head_stride},
124
+ {params.v_descale_batch_stride, params.v_descale_head_stride},
125
+ params.window_size_left, params.window_size_right,
126
+ params.softcap,
127
+ params.num_splits,
128
+ params.kv_batch_idx,
129
+ params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
130
+ params.seqused_q, params.seqused_k,
131
+ params.leftpad_k, params.seqlens_rotary,
132
+ static_cast<ElementS const*>(params.s_aux_ptr)
133
+ };
134
+ typename CollectiveEpilogue::Arguments epilogue_args {
135
+ static_cast<ElementOut*>(params.o_ptr),
136
+ {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O
137
+ {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
138
+ static_cast<float*>(params.oaccum_ptr),
139
+ {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
140
+ static_cast<float*>(params.softmax_lse_ptr),
141
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE
142
+ static_cast<float*>(params.softmax_lseaccum_ptr),
143
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial
144
+ params.h_k,
145
+ params.cu_seqlens_q, params.seqused_q
146
+ };
147
+
148
+ int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
149
+ int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
150
+ num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
151
+ typename flash::TileSchedulerArguments scheduler_args {
152
+ num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
153
+ params.h / params.h_k,
154
+ params.seqlen_q,
155
+ params.seqlen_k, params.d, params.dv, sizeof(Element),
156
+ params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
157
+ // params.num_m_blocks_ptr,
158
+ params.num_splits_dynamic_ptr,
159
+ };
160
+
161
+ if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
162
+ prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
163
+ CHECK_CUDA_KERNEL_LAUNCH();
164
+ }
165
+
166
+ int device;
167
+ CHECK_CUDA(cudaGetDevice(&device));
168
+ typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
169
+ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
170
+ });
171
+
172
+ dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
173
+ dim3 block_dims = AttnKernel::get_block_shape();
174
+ int smem_size = AttnKernel::SharedStorageSize;
175
+ // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
176
+ // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
177
+ // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
178
+ // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
179
+ // Get the ptr to kernel function.
180
+ if constexpr (size(ClusterShape{}) > 1) {
181
+ void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
182
+ if (smem_size >= 48 * 1024) {
183
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
184
+ }
185
+ dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
186
+ cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
187
+ cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
188
+ } else {
189
+ auto kernel = cutlass::device_kernel<AttnKernel>;
190
+ if (smem_size >= 48 * 1024) {
191
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
192
+ }
193
+ // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
194
+ cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
195
+ Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
196
+ }
197
+ CHECK_CUDA_KERNEL_LAUNCH();
198
+ }
199
+
200
+ template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
201
+ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
202
+ static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
203
+ static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
204
+ using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
205
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
206
+ VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
207
+ static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
208
+ VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
209
+ BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] {
210
+ // Avoid over-compilation by making sure this only gets set if it is actually used, i.e. we currently only support one MMA warpgroup for head dim 128 on Hopper.
211
+ static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;
212
+
213
+ // Only needed here to decide if we should use cluster
214
+ static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128;
215
+
216
+ static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
217
+ BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
218
+ static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
219
+ APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
220
+ // Only use Cluster if the number of tiles along seqlen_q is even and we are not in varlen mode
221
+ CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
222
+ static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
223
+ run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor, Use_one_mma_wg>(params, stream);
224
+ });
225
+ });
226
+ });
227
+ });
228
+ });
229
+ });
230
+ });
231
+ }
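
run_flash_fwd ends by sizing the grid, block, and shared memory from the kernel type and, whenever dynamic shared memory exceeds the default 48 KB per-block limit, opting in via cudaFuncSetAttribute before launching. Here is a standalone CUDA sketch of that launch pattern; the kernel and sizes below are placeholders, not the FA3 kernel.

    #include <cuda_runtime.h>
    #include <cstdio>

    // Placeholder kernel standing in for cutlass::device_kernel<AttnKernel>.
    __global__ void toy_kernel(float* out) {
        extern __shared__ float smem[];
        smem[threadIdx.x] = float(threadIdx.x);
        __syncthreads();
        if (threadIdx.x == 0) { out[blockIdx.x] = smem[0]; }
    }

    int main() {
        int smem_size = 64 * 1024;  // > 48 KB, so we must raise the per-kernel limit first
        if (smem_size >= 48 * 1024) {
            cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        }
        float* out;
        cudaMalloc(&out, sizeof(float));
        dim3 grid(1), block(128);
        toy_kernel<<<grid, block, smem_size, /*stream=*/0>>>(out);
        cudaError_t err = cudaDeviceSynchronize();
        std::printf("launch status: %s\n", cudaGetErrorString(err));
        cudaFree(out);
        return 0;
    }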
flash-attn/flash_prepare_scheduler.cu ADDED
@@ -0,0 +1,124 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include "cutlass/fast_math.h"
6
+ #include "cutlass/barrier.h"
7
+ #include "cutlass/arch/barrier.h"
8
+
9
+ #include "cutlass/arch/grid_dependency_control.h"
10
+
11
+ #include "flash.h"
12
+
13
+ namespace flash {
14
+
15
+ __global__ void prepare_varlen_num_blocks_kernel(
16
+ int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
17
+ int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
18
+ int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
19
+ int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
20
+ cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
21
+ int* const tile_count_semaphore,
22
+ // int* const num_m_blocks_ptr,
23
+ int* const num_splits_dynamic_ptr,
24
+ bool enable_pdl) {
25
+
26
+ static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
27
+ static constexpr int kSmemSize = 1;
28
+ // Assume that there's only one block in the grid
29
+ __shared__ int total_blocks_smem[kSmemSize];
30
+
31
+ // There's only 1 block in the grid, so might as well start launching the main attn kernel
32
+ if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
33
+
34
+ if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
35
+ __syncthreads();
36
+
37
+ if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
38
+
39
+ int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
40
+
41
+ auto get_num_m_blocks = [&](int bidb_start) {
42
+ int batch_idx = lane + bidb_start;
43
+ int seqlen;
44
+ if (seqused_q) {
45
+ seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
46
+ } else if (cu_seqlens_q) {
47
+ int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
48
+ int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
49
+ seqlen = next_cu_seqlen - cur_cu_seqlen;
50
+ } else {
51
+ seqlen = seqlen_q_static;
52
+ }
53
+ seqlen *= qhead_per_khead;
54
+ return batch_idx < num_batch && lane < kNumBatchPerWarp
55
+ ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
56
+ };
57
+
58
+ auto get_num_n_blocks = [&](int bidb_start) {
59
+ int batch_idx = lane + bidb_start;
60
+ int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
61
+ int seqlen;
62
+ if (seqused_k) {
63
+ seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
64
+ } else if (cu_seqlens_k) {
65
+ int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
66
+ int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
67
+ seqlen = next_cu_seqlen - cur_cu_seqlen;
68
+ } else {
69
+ seqlen = seqlen_k_static;
70
+ }
71
+ int seqlen_new;
72
+ if (cu_seqlens_k_new) {
73
+ int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
74
+ int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
75
+ seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
76
+ } else {
77
+ seqlen_new = seqlen_k_new_static;
78
+ }
79
+ // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
80
+ seqlen = seqlen - leftpad_k + seqlen_new;
81
+ return batch_idx < num_batch && lane < kNumBatchPerWarp
82
+ ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
83
+ };
84
+
85
+ int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
86
+ int bidb_start = kNumBatchPerWarp * warp_idx;
87
+ int num_m_blocks = get_num_m_blocks(bidb_start);
88
+ int num_n_blocks = get_num_n_blocks(bidb_start);
89
+
90
+ int total_blocks = num_m_blocks * num_n_blocks;
91
+ // Warp sum
92
+ #pragma unroll
93
+ for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
94
+ total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
95
+ }
96
+ if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
97
+ __syncthreads();
98
+ total_blocks = total_blocks_smem[0];
99
+ // 10% margin
100
+ int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
101
+ // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM
102
+ int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
103
+ if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
104
+ num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
105
+ // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
106
+ }
107
+ }
108
+
109
+ } // flash
110
+
111
+ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
112
+ int blockM, int blockN, bool enable_pdl) {
113
+ // Only support batch <= 992 (32 warps, each with 31 batches)
114
+ int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
115
+ flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
116
+ params.seqlen_q, params.seqlen_k, params.seqlen_knew,
117
+ params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
118
+ params.seqused_q, params.seqused_k, params.leftpad_k,
119
+ params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
120
+ cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
121
+ params.tile_count_semaphore,
122
+ // params.num_m_blocks_ptr,
123
+ params.num_splits_dynamic_ptr, enable_pdl);
124
+ }
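
The kernel above assigns up to 31 batches to each warp, turns cu_seqlens differences into per-batch block counts using __shfl_down_sync from the neighboring lane, and then reduces the per-lane block counts with a warp shuffle tree before the shared-memory atomicAdd. The following self-contained CUDA sketch shows just that shuffle-tree reduction idiom, independent of the real scheduler kernel.

    #include <cuda_runtime.h>
    #include <cstdio>

    // Sum one value per lane across a warp with a shuffle tree, mirroring the
    // "Warp sum" loop in prepare_varlen_num_blocks_kernel.
    __global__ void warp_sum_kernel(int* out) {
        int lane = threadIdx.x % 32;
        int val = lane + 1;  // lanes hold 1..32, so the warp total is 528
        #pragma unroll
        for (int offset = 16; offset >= 1; offset /= 2) {
            val += __shfl_down_sync(0xffffffff, val, offset);
        }
        if (lane == 0) { *out = val; }  // only lane 0 holds the full sum
    }

    int main() {
        int* d_out; int h_out = 0;
        cudaMalloc(&d_out, sizeof(int));
        warp_sum_kernel<<<1, 32>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        std::printf("warp sum = %d (expected 528)\n", h_out);
        cudaFree(d_out);
        return 0;
    }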
flash-attn/heuristics.h ADDED
@@ -0,0 +1,65 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <vector>
8
+ #include "flash.h"
9
+
10
+ inline bool use_one_mma_wg(Flash_fwd_params const& params) {
11
+ return params.arch >= 90 && params.d == 128 &&
12
+ params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64;
13
+ };
14
+
15
+ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
16
+ // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
17
+ if (varlen_q) return true;
18
+ // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
19
+ auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
20
+ float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
21
+ float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
22
+ return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
23
+ };
24
+
25
+ // Find the number of splits that maximizes the occupancy. For example, if we have
26
+ // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
27
+ // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
28
+ // splits as that would incur more HBM reads/writes.
29
+ // So we find the best efficiency, then find the smallest number of splits that gets 85%
30
+ // of the best efficiency.
31
+ inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
32
+ // If we have enough to almost fill the SMs, then just use 1 split
33
+ // However, in the case of super long seqlen where each head of KV doesn't even fit into
34
+ // L2 (we assume that L2 size is 50MB), we want to split.
35
+ if (total_mblocks >= 0.8f * num_SMs) {
36
+ int const size_l2 = 50 * 1024 * 1024;
37
+ // Only split if there are enough queries to go over the KV at least twice
38
+ // Don't split if causal
39
+ if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
40
+ return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
41
+ } else {
42
+ return 1;
43
+ }
44
+ }
45
+ // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
46
+ if (num_n_blocks <= 4) { return 1; }
47
+ max_splits = std::min({max_splits, num_SMs, num_n_blocks});
48
+ float max_efficiency = 0.f;
49
+ std::vector<float> efficiency;
50
+ efficiency.reserve(max_splits);
51
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
52
+ float n_waves = float(total_mblocks * num_splits) / num_SMs;
53
+ float eff = n_waves / ceil(n_waves);
54
+ // printf("num_splits = %d, eff = %f\n", num_splits, eff);
55
+ if (eff > max_efficiency) { max_efficiency = eff; }
56
+ efficiency.push_back(eff);
57
+ }
58
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
59
+ if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
60
+ // printf("num_splits chosen = %d\n", num_splits);
61
+ return num_splits;
62
+ }
63
+ }
64
+ return 1;
65
+ }
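
To make the wave-efficiency reasoning in num_splits_heuristic concrete, here is a small standalone snippet that replays its efficiency loop for the numbers quoted in its comment (48 m-blocks on 108 SMs; max_splits = 4 is chosen only for the illustration). It reproduces the stated result: 2 splits (efficiency 0.89) beat 3 splits (0.67), and 2 is the smallest split count within 85% of the best efficiency.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        int total_mblocks = 48, num_SMs = 108, max_splits = 4;
        float eff[5] = {}, max_eff = 0.f;
        for (int s = 1; s <= max_splits; ++s) {
            float n_waves = float(total_mblocks * s) / num_SMs;     // fractional waves over the SMs
            eff[s] = n_waves / std::ceil(n_waves);
            max_eff = std::max(max_eff, eff[s]);
            std::printf("splits=%d efficiency=%.2f\n", s, eff[s]);  // 0.44, 0.89, 0.67, 0.89
        }
        for (int s = 1; s <= max_splits; ++s) {
            if (eff[s] >= 0.85f * max_eff) { std::printf("chosen: %d splits\n", s); return 0; }  // 2
        }
        return 0;
    }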
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
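
The instantiation files that follow are auto-generated and all share this shape: each translation unit pins one (arch, dtype, head dim, softcap) combination of run_mha_bwd_ so the template instantiations compile in parallel, and a runtime dispatcher (flash_api.cpp in this commit) routes to the matching specialization. Below is a toy, self-contained C++ illustration of that split-per-TU specialization pattern; the names are invented for the example and are not the actual FA3 API.

    #include <cstdio>

    // Each "generated file" would hold exactly one explicit specialization like these.
    template <int Arch, int kHeadDim, bool Has_softcap>
    void run_toy_bwd();

    template <> void run_toy_bwd<90, 128, false>() { std::puts("sm90, hdim 128, no softcap"); }
    template <> void run_toy_bwd<90, 128, true>()  { std::puts("sm90, hdim 128, softcap"); }

    // Runtime values select the compile-time specialization.
    void dispatch(int arch, int head_dim, bool softcap) {
        if (arch >= 90 && head_dim == 128) {
            softcap ? run_toy_bwd<90, 128, true>() : run_toy_bwd<90, 128, false>();
        }
    }

    int main() {
        dispatch(90, 128, false);
        dispatch(90, 128, true);
        return 0;
    }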
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim128_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim128_fp16_sm90.cu"
6
+ #include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim192_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim192_fp16_sm90.cu"
6
+ #include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim256_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif