"""Python entry points that forward to the compiled MLA ops in ``._ops``."""

import torch

from ._ops import ops


def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
    """Call the compiled ``get_mla_metadata`` op and return its result.

    All three arguments are passed through positionally and unchanged.

    Args:
        seqlens_k: Tensor forwarded as the op's first argument
            (presumably per-sequence key lengths -- TODO confirm).
        s_q: Integer forwarded as the op's second argument.
        h_kv: Integer forwarded as the op's third argument.

    Returns:
        Whatever ``ops.get_mla_metadata`` returns, unmodified.
    """
    metadata = ops.get_mla_metadata(seqlens_k, s_q, h_kv)
    return metadata


def mha_fwd_kvcache_mla(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache_: torch.Tensor,
    head_size_v: int,
    seqlens_k: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    is_causal_: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> torch.Tensor:
    """Call the compiled ``mha_fwd_kvcache_mla`` op and return its result.

    This wrapper performs no validation or conversion; every argument is
    handed to the op positionally, in the order of this signature.

    Args:
        q: Tensor forwarded as the op's first argument.
        kcache: Tensor forwarded as the op's second argument.
        vcache_: Tensor forwarded as the op's third argument.
        head_size_v: Integer forwarded to the op.
        seqlens_k: Tensor forwarded to the op.
        block_table: Tensor forwarded to the op.
        softmax_scale: Float forwarded to the op.
        is_causal_: Boolean forwarded to the op.
        tile_scheduler_metadata: Tensor forwarded to the op (presumably
            produced by ``get_mla_metadata`` -- TODO confirm).
        num_splits: Tensor forwarded to the op (presumably produced by
            ``get_mla_metadata`` -- TODO confirm).

    Returns:
        Whatever ``ops.mha_fwd_kvcache_mla`` returns, unmodified.
        NOTE(review): annotated as a single ``torch.Tensor``; confirm
        against the op schema that it does not return several tensors.
    """
    # Keep the order in lockstep with the signature: the op receives its
    # arguments positionally.
    kernel_args = (
        q,
        kcache,
        vcache_,
        head_size_v,
        seqlens_k,
        block_table,
        softmax_scale,
        is_causal_,
        tile_scheduler_metadata,
        num_splits,
    )
    return ops.mha_fwd_kvcache_mla(*kernel_args)