drbh committed
Commit · 5cb0596 · 1 Parent(s): 8acf152
fix: readability refactors
flash_mla/flash_mla_api.cu +6 -43
flash_mla/flash_mla_api.cu
CHANGED
@@ -1,8 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
-
-
 #include <cutlass/fast_math.h>
 
 #include "flash_mla.h"
@@ -12,42 +10,6 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
-
-//
-
-
-// #include <cmath>
-
-// #include "cute/tensor.hpp"
-#include <cute/tensor.hpp>
-
-// __global__ void relu_kernel(float *__restrict__ out,
-//                             float const *__restrict__ input,
-//                             const int d) {
-//   const int64_t token_idx = blockIdx.x;
-//   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-//     auto x = input[token_idx * d + idx];
-//     out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
-//   }
-// }
-
-// void relu(torch::Tensor &out,
-//           torch::Tensor const &input)
-// {
-//   TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
-//               input.scalar_type() == at::ScalarType::Float,
-//               "relu_kernel only supports float32");
-
-//   int d = input.size(-1);
-//   int64_t num_tokens = input.numel() / d;
-//   dim3 grid(num_tokens);
-//   dim3 block(std::min(d, 1024));
-//   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-//   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-//   relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
-//                                           input.data_ptr<float>(), d);
-// }
-
 std::vector<at::Tensor>
 get_mla_metadata(
     at::Tensor &seqlens_k,
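Aside (not part of the commit): the CHECK_SHAPE macro kept in this hunk is variadic, so each of the shape annotations added in the next hunk corresponds to a single check. As an illustration, assuming a batch_size variable is in scope:

CHECK_SHAPE(seqlens_k, batch_size);
// expands (via the variadic macro above) to:
// TORCH_CHECK(seqlens_k.sizes() == torch::IntArrayRef({batch_size}),
//             "seqlens_k must have shape (batch_size)");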
@@ -98,16 +60,17 @@ mha_fwd_kvcache_mla(
 
     // TODO: fix for optional
     // std::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
+    const at::Tensor &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
 
-    const at::Tensor &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
     const int64_t head_size_v,
-    const at::Tensor &seqlens_k,
-    const at::Tensor &block_table,
+    const at::Tensor &seqlens_k,  // batch_size
+    const at::Tensor &block_table,  // batch_size x max_num_blocks_per_seq
+
     // TODO: should be float
     const double softmax_scale,
     const bool is_causal_,
-    const at::Tensor &tile_scheduler_metadata,
-    const at::Tensor &num_splits,
+    const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
+    const at::Tensor &num_splits,  // batch_size + 1
 
     // TODO: remove this once determined why build is adding this parameter
     const int64_t unknown_param
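For reference, a minimal sketch (not part of this commit) of how the shape annotations added above could be turned into runtime checks using the CHECK_SHAPE / CHECK_CONTIGUOUS macros already defined in flash_mla_api.cu. The helper name check_mla_inputs and the num_sm_parts / tile_scheduler_metadata_size parameters are illustrative assumptions, not names from the source.

#include <torch/all.h>

// Same check macros as defined near the top of flash_mla_api.cu.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

// Hypothetical helper (not in the commit): enforces the shapes that the new
// comments document for mha_fwd_kvcache_mla's arguments.
static void check_mla_inputs(const at::Tensor &seqlens_k,                // batch_size
                             const at::Tensor &block_table,              // batch_size x max_num_blocks_per_seq
                             const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
                             const at::Tensor &num_splits,               // batch_size + 1
                             int64_t num_sm_parts,
                             int64_t tile_scheduler_metadata_size) {
    const int64_t batch_size = seqlens_k.size(0);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_CONTIGUOUS(block_table);
    CHECK_SHAPE(seqlens_k, batch_size);
    CHECK_SHAPE(block_table, batch_size, block_table.size(1));
    CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, tile_scheduler_metadata_size);
    CHECK_SHAPE(num_splits, batch_size + 1);
}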