import torch | |
from ._ops import ops | |
def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
    """Compute MLA scheduling metadata by delegating to the compiled op.

    This is a pure pass-through: the three arguments are handed, in order,
    to ``ops.get_mla_metadata`` and whatever that op returns is returned
    unchanged.

    Args:
        seqlens_k: Tensor forwarded as the op's first argument (by its name,
            presumably per-sequence key/cache lengths -- confirm against the
            compiled op's schema).
        s_q: Integer forwarded as the second argument (presumably the query
            sequence length -- verify).
        h_kv: Integer forwarded as the third argument (presumably the number
            of KV heads -- verify).

    Returns:
        The compiled op's result, unmodified. NOTE(review): the exact return
        type is not visible here; it may be the
        ``(tile_scheduler_metadata, num_splits)`` pair consumed by
        ``mha_fwd_kvcache_mla`` -- confirm before relying on it.
    """
    metadata_op = ops.get_mla_metadata
    return metadata_op(seqlens_k, s_q, h_kv)
def mha_fwd_kvcache_mla(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache_: torch.Tensor,
    head_size_v: int,
    seqlens_k: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    is_causal_: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> torch.Tensor:
    """Run the compiled ``mha_fwd_kvcache_mla`` op and return its result.

    This wrapper performs no validation or transformation of its own: every
    argument is forwarded positionally, in signature order, to
    ``ops.mha_fwd_kvcache_mla``.

    Args:
        q: Query tensor (shape/dtype not established here -- see the
            compiled op's schema).
        kcache: Key cache tensor.
        vcache_: Value cache tensor. The trailing underscore mirrors the
            underlying op's parameter naming; its exact semantics (e.g.
            whether it may be optional) are not shown here -- verify.
        head_size_v: Integer head size for values, forwarded as-is.
        seqlens_k: Tensor forwarded as the key sequence lengths argument.
        block_table: Tensor forwarded as the block table argument
            (presumably a paged-cache block mapping -- confirm).
        softmax_scale: Float forwarded as the softmax scale.
        is_causal_: Boolean causal-masking flag, forwarded as-is.
        tile_scheduler_metadata: Scheduling metadata tensor (presumably
            produced by ``get_mla_metadata`` -- verify).
        num_splits: Split-count tensor (presumably also from
            ``get_mla_metadata`` -- verify).

    Returns:
        Whatever the compiled op returns. NOTE(review): the ``torch.Tensor``
        annotation is not verified against the op; it may return multiple
        tensors (e.g. output plus a softmax LSE) -- confirm.
    """
    # Order matters: the op is called positionally, so this tuple must stay
    # in exactly the same order as the signature above.
    op_args = (
        q,
        kcache,
        vcache_,
        head_size_v,
        seqlens_k,
        block_table,
        softmax_scale,
        is_causal_,
        tile_scheduler_metadata,
        num_splits,
    )
    return ops.mha_fwd_kvcache_mla(*op_args)