drbh committed
Commit 5cb0596 · Parent: 8acf152

fix: readability refactors

Files changed (1)
  1. flash_mla/flash_mla_api.cu +6 -43
flash_mla/flash_mla_api.cu CHANGED
@@ -1,8 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
-
-
 #include <cutlass/fast_math.h>
 
 #include "flash_mla.h"
@@ -12,42 +10,6 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
-
-//
-
-
-// #include <cmath>
-
-// #include "cute/tensor.hpp"
-#include <cute/tensor.hpp>
-
-// __global__ void relu_kernel(float *__restrict__ out,
-//                             float const *__restrict__ input,
-//                             const int d) {
-//   const int64_t token_idx = blockIdx.x;
-//   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-//     auto x = input[token_idx * d + idx];
-//     out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
-//   }
-// }
-
-// void relu(torch::Tensor &out,
-//           torch::Tensor const &input)
-// {
-//   TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
-//               input.scalar_type() == at::ScalarType::Float,
-//               "relu_kernel only supports float32");
-
-//   int d = input.size(-1);
-//   int64_t num_tokens = input.numel() / d;
-//   dim3 grid(num_tokens);
-//   dim3 block(std::min(d, 1024));
-//   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-//   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-//   relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
-//                                           input.data_ptr<float>(), d);
-// }
-
 std::vector<at::Tensor>
 get_mla_metadata(
     at::Tensor &seqlens_k,
@@ -98,16 +60,17 @@ mha_fwd_kvcache_mla(
 
     // TODO: fix for optional
     // std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
+    const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
 
-    const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
     const int64_t head_size_v,
-    const at::Tensor &seqlens_k, // batch_size
-    const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
+    const at::Tensor &seqlens_k, // batch_size
+    const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
+
     // TODO: should be float
     const double softmax_scale,
     const bool is_causal_,
-    const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
-    const at::Tensor &num_splits, // batch_size + 1
+    const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
+    const at::Tensor &num_splits, // batch_size + 1
 
     // TODO: remove this once determined why build is adding this parameter
     const int64_t unknown_param
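For context, a minimal sketch of how the CHECK_SHAPE and CHECK_CONTIGUOUS macros retained above are typically applied to the paged-KV arguments. The helper name check_mla_inputs and the choice to derive batch_size from block_table are illustrative, not code from this commit.

#include <torch/all.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

// Hypothetical helper, not part of flash_mla_api.cu: cross-checks the
// tensors against the shape comments in the signature above. TORCH_CHECK
// throws a c10::Error carrying the stringified message when a condition
// is false.
void check_mla_inputs(const at::Tensor &seqlens_k,
                      const at::Tensor &block_table) {
    const int64_t batch_size = block_table.size(0);
    const int64_t max_num_blocks_per_seq = block_table.size(1);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_CONTIGUOUS(block_table);
    CHECK_SHAPE(seqlens_k, batch_size);                           // seqlens_k: batch_size
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); // batch_size x max_num_blocks_per_seq
}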
 
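The "TODO: fix for optional" note above points at making vcache_ optional. Below is a sketch of the common PyTorch-extension pattern for that, assuming the fall-back-to-key-cache behavior used by similar attention kernels; nothing here is part of this commit.

#include <optional>
#include <torch/all.h>

// Assumed future shape of the parameter (this commit keeps the plain
// const reference): when the caller passes no value cache, reuse the key
// cache. at::Tensor is a reference-counted handle, so returning it by
// value is cheap.
at::Tensor resolve_vcache(const std::optional<at::Tensor> &vcache_,
                          const at::Tensor &kcache) {
    return vcache_.has_value() ? vcache_.value() : kcache;
}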
 
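Finally, a hedged sketch of the call flow the refactored signature implies: get_mla_metadata derives scheduling tensors from the key sequence lengths, and those feed the trailing parameters of mha_fwd_kvcache_mla. The extra get_mla_metadata arguments, its return order, and the leading query/key-cache parameters of mha_fwd_kvcache_mla are not visible in these hunks and are assumptions.

#include <vector>
#include <torch/all.h>
#include "flash_mla.h"

// Sketch only: q, kcache, num_heads_per_head_k, and num_heads_k are assumed
// names; the diff shows neither the full get_mla_metadata signature nor the
// mha_fwd_kvcache_mla parameters that precede vcache_.
void mla_decode_step(at::Tensor &q, at::Tensor &kcache,
                     const at::Tensor &vcache, int64_t head_size_v,
                     at::Tensor &seqlens_k, const at::Tensor &block_table,
                     double softmax_scale, bool is_causal,
                     int64_t num_heads_per_head_k, int64_t num_heads_k) {
    // Assumed return order: {tile_scheduler_metadata, num_splits}.
    std::vector<at::Tensor> meta =
        get_mla_metadata(seqlens_k, num_heads_per_head_k, num_heads_k);
    const at::Tensor &tile_scheduler_metadata = meta[0]; // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits = meta[1];              // batch_size + 1

    // Trailing arguments follow the refactored order shown in the hunk above.
    mha_fwd_kvcache_mla(q, kcache,
                        vcache, head_size_v,
                        seqlens_k, block_table,
                        softmax_scale, is_causal,
                        tile_scheduler_metadata, num_splits,
                        /*unknown_param=*/0);
}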