#pragma once #include std::vector get_mla_metadata( torch::Tensor &seqlens_k, const int64_t num_heads_per_head_k, const int64_t num_heads_k ); std::vector mha_fwd_kvcache_mla( torch::Tensor &q, const torch::Tensor &kcache, // TODO: fix for optional // std::optional &vcache_, const torch::Tensor &vcache_, const int64_t head_size_v, const torch::Tensor &seqlens_k, const torch::Tensor &block_table, // TODO:should be float const double softmax_scale, // TODO: fix for mutable bool const bool is_causal_, const torch::Tensor &tile_scheduler_metadata, const torch::Tensor &num_splits, // TODO: remove when resolved const int64_t unknown_param = 0 );