File size: 623 Bytes
1f83cde 59bdff8 1f83cde 59bdff8 1f83cde d76b04d 1f83cde |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
#pragma once
#include <torch/torch.h>
std::vector<torch::Tensor>
get_mla_metadata(
torch::Tensor &seqlens_k,
const int64_t num_heads_per_head_k,
const int64_t num_heads_k
);
/// Attention forward pass over a KV cache for MLA.
///
/// NOTE(review): `block_table` suggests a paged KV-cache layout. The
/// tensor shapes and dtypes are not visible from this header — TODO confirm
/// against the definition.
///
/// @param q                       Query tensor. It is passed by non-const
///                                reference, as in the existing definition's
///                                signature.
/// @param kcache                  Key cache tensor.
/// @param vcache_                 Optional value cache. When absent, the
///                                implementation presumably derives values
///                                from `kcache` (confirm).
/// @param head_size_v             Head dimension of the values.
/// @param seqlens_k               Per-sequence key lengths.
/// @param block_table             Block/page table into the caches.
/// @param softmax_scale           Scale applied before the softmax.
/// @param is_causal               Whether causal masking is applied.
/// @param tile_scheduler_metadata Scheduling metadata (presumably produced by
///                                get_mla_metadata — confirm).
/// @param num_splits              Split counts (presumably produced by
///                                get_mla_metadata — confirm).
/// @return A vector of output tensors, as defined by the implementation.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
    torch::Tensor &q,
    const torch::Tensor &kcache,
    const c10::optional<torch::Tensor> &vcache_,
    const int64_t head_size_v,
    const torch::Tensor &seqlens_k,
    const torch::Tensor &block_table,
    // `torch::kFloat` is a ScalarType *value*, not a C++ type, so it cannot
    // declare a parameter. TORCH_LIBRARY operator schemas map `float` to C++
    // `double` (just as `int` maps to `int64_t`, used above), so the scalar
    // is declared as `double`.
    const double softmax_scale,
    bool is_causal,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits
);