flash-mla / torch-ext /torch_binding.h
drbh
fix: adjust sig types
59bdff8
raw
history blame
623 Bytes
#pragma once
#include <torch/torch.h>
std::vector<torch::Tensor>
get_mla_metadata(
torch::Tensor &seqlens_k,
const int64_t num_heads_per_head_k,
const int64_t num_heads_k
);
/// @brief Attention forward pass over a KV cache (presumably FlashMLA's
///        multi-head latent attention decode kernel — confirm).
///
/// Only declared in this header; the definition is provided elsewhere.
/// Scalar parameter types follow the torch op-schema -> C++ mapping:
/// schema `int` -> int64_t, `float` -> double, `Tensor?` -> optional<Tensor>.
///
/// @param q                        Query tensor (non-const reference).
/// @param kcache                   Key cache tensor.
/// @param vcache_                  Optional value cache tensor.
/// @param head_size_v              Value head size.
/// @param seqlens_k                Key sequence-length tensor.
/// @param block_table              Block table tensor. NOTE(review): presumably
///                                 maps sequences to KV-cache blocks — confirm.
/// @param softmax_scale            Softmax scaling factor. Must be a C++ type:
///                                 the old `torch::kFloat` is a ScalarType
///                                 enumerator value, not a type, and did not
///                                 compile. `double` is what a schema `float`
///                                 argument binds to.
/// @param is_causal                Causal-masking flag.
/// @param tile_scheduler_metadata  Scheduler metadata tensor (presumably the
///                                 output of get_mla_metadata — confirm).
/// @param num_splits               Split-count tensor (presumably the output of
///                                 get_mla_metadata — confirm).
/// @return A vector of output tensors. NOTE(review): exact contents are not
///         visible from this header — confirm against the definition.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
    torch::Tensor &q,
    const torch::Tensor &kcache,
    const c10::optional<torch::Tensor> &vcache_,
    const int64_t head_size_v,
    const torch::Tensor &seqlens_k,
    const torch::Tensor &block_table,
    const double softmax_scale,
    bool is_causal,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits
);