drbh
commited on
Commit
·
5cb0596
1
Parent(s):
8acf152
fix: readability refactors
Browse files- flash_mla/flash_mla_api.cu +6 -43
flash_mla/flash_mla_api.cu
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
#include <ATen/cuda/CUDAContext.h>
|
| 2 |
#include <c10/cuda/CUDAGuard.h>
|
| 3 |
#include <torch/all.h>
|
| 4 |
-
|
| 5 |
-
|
| 6 |
#include <cutlass/fast_math.h>
|
| 7 |
|
| 8 |
#include "flash_mla.h"
|
|
@@ -12,42 +10,6 @@
|
|
| 12 |
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 13 |
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 14 |
|
| 15 |
-
|
| 16 |
-
//
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
// #include <cmath>
|
| 20 |
-
|
| 21 |
-
// #include "cute/tensor.hpp"
|
| 22 |
-
#include <cute/tensor.hpp>
|
| 23 |
-
|
| 24 |
-
// __global__ void relu_kernel(float *__restrict__ out,
|
| 25 |
-
// float const *__restrict__ input,
|
| 26 |
-
// const int d) {
|
| 27 |
-
// const int64_t token_idx = blockIdx.x;
|
| 28 |
-
// for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
| 29 |
-
// auto x = input[token_idx * d + idx];
|
| 30 |
-
// out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
|
| 31 |
-
// }
|
| 32 |
-
// }
|
| 33 |
-
|
| 34 |
-
// void relu(torch::Tensor &out,
|
| 35 |
-
// torch::Tensor const &input)
|
| 36 |
-
// {
|
| 37 |
-
// TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
|
| 38 |
-
// input.scalar_type() == at::ScalarType::Float,
|
| 39 |
-
// "relu_kernel only supports float32");
|
| 40 |
-
|
| 41 |
-
// int d = input.size(-1);
|
| 42 |
-
// int64_t num_tokens = input.numel() / d;
|
| 43 |
-
// dim3 grid(num_tokens);
|
| 44 |
-
// dim3 block(std::min(d, 1024));
|
| 45 |
-
// const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 46 |
-
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 47 |
-
// relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
|
| 48 |
-
// input.data_ptr<float>(), d);
|
| 49 |
-
// }
|
| 50 |
-
|
| 51 |
std::vector<at::Tensor>
|
| 52 |
get_mla_metadata(
|
| 53 |
at::Tensor &seqlens_k,
|
|
@@ -98,16 +60,17 @@ mha_fwd_kvcache_mla(
|
|
| 98 |
|
| 99 |
// TODO: fix for optional
|
| 100 |
// std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
|
|
|
| 101 |
|
| 102 |
-
const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
| 103 |
const int64_t head_size_v,
|
| 104 |
-
const at::Tensor &seqlens_k,
|
| 105 |
-
const at::Tensor &block_table,
|
|
|
|
| 106 |
// TODO: should be float
|
| 107 |
const double softmax_scale,
|
| 108 |
const bool is_causal_,
|
| 109 |
-
const at::Tensor &tile_scheduler_metadata,
|
| 110 |
-
const at::Tensor &num_splits,
|
| 111 |
|
| 112 |
// TODO: remove this once determined why build is adding this parameter
|
| 113 |
const int64_t unknown_param
|
|
|
|
| 1 |
#include <ATen/cuda/CUDAContext.h>
|
| 2 |
#include <c10/cuda/CUDAGuard.h>
|
| 3 |
#include <torch/all.h>
|
|
|
|
|
|
|
| 4 |
#include <cutlass/fast_math.h>
|
| 5 |
|
| 6 |
#include "flash_mla.h"
|
|
|
|
| 10 |
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 11 |
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
std::vector<at::Tensor>
|
| 14 |
get_mla_metadata(
|
| 15 |
at::Tensor &seqlens_k,
|
|
|
|
| 60 |
|
| 61 |
// TODO: fix for optional
|
| 62 |
// std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
| 63 |
+
const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
| 64 |
|
|
|
|
| 65 |
const int64_t head_size_v,
|
| 66 |
+
const at::Tensor &seqlens_k, // batch_size
|
| 67 |
+
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
| 68 |
+
|
| 69 |
// TODO: should be float
|
| 70 |
const double softmax_scale,
|
| 71 |
const bool is_causal_,
|
| 72 |
+
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
| 73 |
+
const at::Tensor &num_splits, // batch_size + 1
|
| 74 |
|
| 75 |
// TODO: remove this once determined why build is adding this parameter
|
| 76 |
const int64_t unknown_param
|