|
import torch |
|
import random |
|
import torch.nn.functional as F |
|
|
|
import flash_mla |
|
|
|
|
|
|
|
|
|
def test_flash_mla():
    """Build a small paged KV-cache layout and print the MLA scheduling metadata.

    The function fixes a set of problem dimensions, creates per-sequence cache
    lengths (all equal to ``mean_sk`` unless ``varlen`` is enabled, in which
    case they are sampled from a normal distribution and clamped to at least
    ``s_q``), lays out a block table and a random key cache on the CPU, fills
    every position past a sequence's length with NaN, and finally calls
    ``flash_mla.get_mla_metadata`` with the lengths moved to ``"cuda"``.

    The function prints its intermediate values and then always fails on the
    final ``assert False``.
    """
    # Problem dimensions; the names are echoed verbatim by the print below.
    b, s_q, mean_sk = 16, 16, 16
    h_q, h_kv = 16, 1
    d, dv = 576, 512
    causal, varlen = True, False

    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")

    # Per-sequence cache lengths, one int32 entry per batch element.
    seq_lengths = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for idx in range(b):
            sampled = random.normalvariate(mean_sk, mean_sk / 2)
            # Never let a cache be shorter than the query length.
            seq_lengths[idx] = max(sampled, s_q)

    total_seqlens = seq_lengths.sum().item()
    mean_seqlens = seq_lengths.float().mean().int().item()
    max_seqlen = seq_lengths.max().item()
    print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    # Round the longest sequence up to the next multiple of 256.
    padded_len = (max_seqlen + 255) // 256 * 256

    # NOTE(review): the query tensor is not used anywhere below in this
    # function; it is still created so the random stream is consumed in the
    # same order as before.
    queries = torch.randn(b, s_q, h_q, d)

    page_size = 64
    pages_per_seq = padded_len // page_size
    # Sequence i owns the consecutive page ids in row i.
    page_table = torch.arange(b * pages_per_seq, dtype=torch.int32).view(
        b, pages_per_seq
    )

    paged_k = torch.randn(page_table.numel(), page_size, h_kv, d)
    print(paged_k.shape)

    # View the cache as (batch, padded position, head, channel) and poison
    # everything past each sequence's length with NaN.
    per_seq_k = paged_k.view(b, padded_len, h_kv, d)
    for row, valid_len in enumerate(seq_lengths.tolist()):
        per_seq_k[row, valid_len:] = float("nan")

    # The values are the leading ``dv`` channels of the keys (a slice of the
    # same tensor, not a copy).
    paged_v = paged_k[..., :dv]
    print(paged_k.shape, paged_v.shape)

    seq_lengths = seq_lengths.to("cuda")

    scheduler_meta, split_counts = flash_mla.get_mla_metadata(
        seqlens_k=seq_lengths,
        s_q=s_q * h_q // h_kv,
        h_kv=h_kv,
    )
    print(scheduler_meta, split_counts)

    # NOTE(review): unconditional failure -- presumably a debugging stop so
    # that pytest shows the captured prints above; confirm before relying on
    # this test's pass/fail status.
    assert False
|
|