|
import torch |
|
|
|
from ._ops import ops |
|
|
|
|
|
def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
    """Forward the arguments to ``ops.get_mla_metadata`` and return its result.

    This wrapper performs no validation or conversion of its own; the three
    arguments are handed to the op positionally, in the order given.

    Args:
        seqlens_k: Tensor passed as the first op argument (presumably the
            per-sequence key lengths -- confirm against the op definition).
        s_q: Integer passed as the second op argument.
        h_kv: Integer passed as the third op argument.

    Returns:
        Whatever ``ops.get_mla_metadata`` returns, unmodified.
    """
    metadata = ops.get_mla_metadata(seqlens_k, s_q, h_kv)
    return metadata
|
|
|
|
|
def mha_fwd_kvcache_mla(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache_: torch.Tensor,
    head_size_v: int,
    seqlens_k: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    is_causal_: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> torch.Tensor:
    """Forward the arguments to ``ops.mha_fwd_kvcache_mla`` and return its result.

    This wrapper adds no logic of its own: every parameter is passed to the
    op positionally, in exactly the order it appears in this signature.

    Args:
        q: Tensor passed as op argument 1.
        kcache: Tensor passed as op argument 2.
        vcache_: Tensor passed as op argument 3.
        head_size_v: Integer passed as op argument 4.
        seqlens_k: Tensor passed as op argument 5.
        block_table: Tensor passed as op argument 6.
        softmax_scale: Float passed as op argument 7.
        is_causal_: Boolean passed as op argument 8.
        tile_scheduler_metadata: Tensor passed as op argument 9
            (presumably produced by ``get_mla_metadata`` -- confirm with
            the caller).
        num_splits: Tensor passed as op argument 10.

    Returns:
        The value returned by ``ops.mha_fwd_kvcache_mla``, unmodified.
    """
    # Positional order must mirror the signature; the op is called without
    # keyword arguments.
    output = ops.mha_fwd_kvcache_mla(
        q,
        kcache,
        vcache_,
        head_size_v,
        seqlens_k,
        block_table,
        softmax_scale,
        is_causal_,
        tile_scheduler_metadata,
        num_splits,
    )
    return output
|
|