drbh
feat: rebuild outputs
f475609
import torch
from ._ops import ops
def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
    """Call the compiled ``get_mla_metadata`` op and return its result.

    This is a thin Python entry point: the three arguments are passed
    positionally and unchanged, and whatever the op produces is handed
    straight back to the caller.

    NOTE(review): the meaning of each parameter is inferred from its name
    only (``seqlens_k`` presumably per-sequence key lengths, ``s_q`` a
    query length, ``h_kv`` a KV head count) -- confirm against the op's
    registered schema. The return type is not annotated here and is
    determined entirely by the op.
    """
    metadata = ops.get_mla_metadata(seqlens_k, s_q, h_kv)
    return metadata
def mha_fwd_kvcache_mla(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache_: torch.Tensor,
    head_size_v: int,
    seqlens_k: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    is_causal_: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> torch.Tensor:
    """Call the compiled ``mha_fwd_kvcache_mla`` op and return its result.

    All ten arguments are forwarded positionally, in declaration order and
    without modification; no validation or conversion happens on the
    Python side.

    NOTE(review): parameter semantics (tensor shapes, dtypes, cache layout)
    are not visible in this file -- consult the op's schema. Presumably
    ``tile_scheduler_metadata`` and ``num_splits`` come from
    ``get_mla_metadata``; TODO confirm.

    NOTE(review): the return annotation says ``torch.Tensor``, but the
    value is whatever the op returns -- verify that the op really yields a
    single tensor.
    """
    # Order matters: the op receives these purely by position.
    forwarded_args = (
        q,
        kcache,
        vcache_,
        head_size_v,
        seqlens_k,
        block_table,
        softmax_scale,
        is_causal_,
        tile_scheduler_metadata,
        num_splits,
    )
    return ops.mha_fwd_kvcache_mla(*forwarded_args)