File size: 726 Bytes
1f83cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Operator registrations for this extension's library namespace.
//
// Each operator is declared with a schema string via `ops.def` and then bound
// to its C++ entry point for the CUDA dispatch key via `ops.impl`. The schema
// strings are written one argument per line; adjacent string literals are
// concatenated by the compiler, so every registered schema is a single string.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // get_mla_metadata: schema declaration + CUDA binding.
  ops.def(
      "get_mla_metadata("
      "Tensor! seqlens_k, "
      "int num_heads_per_head_k, "
      "int num_heads_k"
      ") -> Tensor[]");
  ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);

  // mha_fwd_kvcache_mla: schema declaration + CUDA binding.
  // TODO: remove last unknown_param when resolved
  ops.def(
      "mha_fwd_kvcache_mla("
      "Tensor! q, "
      "Tensor! kcache, "
      "Tensor! vcache_, "
      "int head_size_v, "
      "Tensor! seqlens_k, "
      "Tensor! block_table, "
      "float softmax_scale, "
      "bool is_causal_, "
      "Tensor! tile_scheduler_metadata, "
      "Tensor! num_splits, "
      "int unknown_param"
      ") -> Tensor[]");
  ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
}

// Expands the REGISTER_EXTENSION macro for this extension's name.
// NOTE(review): the macro is not defined in this file — presumably it comes
// from "registration.h" and emits the module entry point; confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)