Update triton_flash_blocksparse_attn.py
Browse filesAdd suggestion similar to https://huggingface.co/THUDM/cogagent-chat-hf/blob/d519da3b191401234f4bd86ce1c287c61bc276a3/util.py#L210 to avoid error
```ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)```
- triton_flash_blocksparse_attn.py +25 -24
triton_flash_blocksparse_attn.py
CHANGED
|
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
|
|
| 611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
| 612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
| 613 |
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
|
|
|
| 638 |
if inference:
|
| 639 |
L, m = None, None
|
| 640 |
|
|
|
|
| 611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
| 612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
| 613 |
|
| 614 |
+
with torch.cuda.device(q.device.index):
|
| 615 |
+
_fwd_kernel[grid](
|
| 616 |
+
q, k, v, sm_scale,
|
| 617 |
+
layout_crow_indices,
|
| 618 |
+
layout_col_indices,
|
| 619 |
+
layout_crow_indices.stride(0), layout_crow_indices.stride(1),
|
| 620 |
+
layout_col_indices.stride(0), layout_col_indices.stride(1),
|
| 621 |
+
tmp, L, m,
|
| 622 |
+
o,
|
| 623 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
| 624 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
| 625 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
| 626 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
| 627 |
+
q.shape[0], q.shape[1], k.shape[2],
|
| 628 |
+
k.shape[2] - q.shape[2],
|
| 629 |
+
q_rounded_len,
|
| 630 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
| 631 |
+
BLOCK_DMODEL=BLOCK_DMODEL,
|
| 632 |
+
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
|
| 633 |
+
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
|
| 634 |
+
INFERENCE=inference,
|
| 635 |
+
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
|
| 636 |
+
num_warps=num_warps,
|
| 637 |
+
num_stages=num_stages,
|
| 638 |
+
)
|
| 639 |
if inference:
|
| 640 |
L, m = None, None
|
| 641 |
|