|
#pragma once |
|
|
|
#include <torch/torch.h> |
|
|
|
/// @brief Computes tile-scheduler metadata for the MLA (Multi-head Latent
/// Attention) decode kernel.
///
/// @param seqlens_k            Per-batch key sequence lengths. Taken by
///                             non-const reference — presumably read (and
///                             possibly reinterpreted) by the kernel setup;
///                             TODO(review): confirm whether it is mutated.
/// @param num_heads_per_head_k Number of query heads sharing each KV head
///                             (NOTE(review): presumably num_heads_q /
///                             num_heads_k — confirm against callers).
/// @param num_heads_k          Number of key/value heads.
/// @return Scheduling tensors consumed by mha_fwd_kvcache_mla — presumably
///         the tile_scheduler_metadata and num_splits arguments it takes.
std::vector<torch::Tensor>
get_mla_metadata(
    torch::Tensor &seqlens_k,
    const int64_t num_heads_per_head_k,
    const int64_t num_heads_k
);
|
|
|
/// @brief Multi-head attention forward pass over a paged KV cache for MLA
/// (Multi-head Latent Attention) decoding.
///
/// @param q             Query tensor (layout not visible here — see the
///                      implementation for expected shape).
/// @param kcache        Paged key cache.
/// @param vcache_       Optional value cache; when absent, values are
///                      presumably derived from kcache (MLA shared-latent
///                      layout) — TODO(review): confirm.
/// @param head_size_v   Head dimension of the value/output tensors.
/// @param seqlens_k     Per-batch key sequence lengths.
/// @param block_table   Maps logical KV blocks to physical cache pages.
/// @param softmax_scale Scale applied to Q·K^T before the softmax.
///                      FIX: was declared `const torch::kFloat`, but
///                      torch::kFloat is a c10::ScalarType *value*, not a
///                      type, so the declaration could not compile. `double`
///                      is the floating-point scalar type of the Torch
///                      custom-op ABI, matching the int64_t integer params.
/// @param is_causal     Whether to apply a causal attention mask.
/// @param tile_scheduler_metadata Scheduling tensor produced by
///                      get_mla_metadata.
/// @param num_splits    Per-batch split counts produced by get_mla_metadata.
/// @return Output tensors — presumably the attention output (and e.g. the
///         softmax LSE); confirm against the implementation.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
    torch::Tensor &q,
    const torch::Tensor &kcache,
    const c10::optional<torch::Tensor> &vcache_,
    const int64_t head_size_v,
    const torch::Tensor &seqlens_k,
    const torch::Tensor &block_table,
    const double softmax_scale,
    bool is_causal,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits
);