flash-mla / torch-ext /torch_binding.h
drbh
feat: build flash mla with kernel builder
1f83cde
raw
history blame
799 Bytes
#pragma once
#include <torch/torch.h>
std::vector<torch::Tensor>
get_mla_metadata(
torch::Tensor &seqlens_k,
const int64_t num_heads_per_head_k,
const int64_t num_heads_k
);
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
torch::Tensor &q,
const torch::Tensor &kcache,
// TODO: fix for optional
// std::optional<torch::Tensor> &vcache_,
const torch::Tensor &vcache_,
const int64_t head_size_v,
const torch::Tensor &seqlens_k,
const torch::Tensor &block_table,
// TODO:should be float
const double softmax_scale,
// TODO: fix for mutable bool
const bool is_causal_,
const torch::Tensor &tile_scheduler_metadata,
const torch::Tensor &num_splits,
// TODO: remove when resolved
const int64_t unknown_param = 0
);