kernel
danieldk (HF Staff) committed
Commit 01fbc17 · 1 Parent(s): fbbc5a8

Sync with vLLM and add `Llama4TextMoe` layer

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. build.toml +2 -0
  2. moe/moe_align_sum_kernels.cu +39 -17
  3. moe/moe_wna16.cu +346 -0
  4. moe/moe_wna16_utils.h +200 -0
  5. tests/kernels/test_moe.py +229 -125
  6. tests/kernels/utils.py +7 -7
  7. torch-ext/moe/__init__.py +2 -18
  8. torch-ext/moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  9. torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  10. torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  11. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  12. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  13. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  14. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  15. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  16. torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  17. torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  18. torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  19. torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  20. torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  21. torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  22. torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  23. torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  24. torch-ext/moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  25. torch-ext/moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. torch-ext/moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  27. torch-ext/moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  28. torch-ext/moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  29. torch-ext/moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  30. torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  32. torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  34. torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  35. torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  36. torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  37. torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  38. torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  40. torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  41. torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  42. torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  43. torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  45. torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  46. torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  47. torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +52 -52
  48. torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  49. torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  50. torch-ext/moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
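Most of the new files above are tuned Triton configurations keyed by expert count (E), intermediate size (N), device name, and optionally dtype/block shape; each file maps a benchmarked token count to a tile configuration. As a rough illustration of how such a file is consumed (a hypothetical helper with an assumed nearest-key heuristic, not the lookup shipped with the kernel):

    import json

    def load_tuned_config(path: str, num_tokens: int) -> dict:
        # Pick the tuned Triton config whose benchmarked batch size is closest
        # to the current number of routed tokens (illustrative heuristic only).
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        key = min(configs, key=lambda bench_m: abs(bench_m - num_tokens))
        return configs[key]

    cfg = load_tuned_config(
        "torch-ext/moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json",
        num_tokens=96)
    # e.g. {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, ...}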
build.toml CHANGED
@@ -31,6 +31,8 @@ src = [
     "cuda_compat.h",
     "dispatch_utils.h",
     "moe/moe_align_sum_kernels.cu",
+    "moe/moe_wna16.cu",
+    "moe/moe_wna16_utils.h",
     "moe/topk_softmax_kernels.cu",
 ]
 depends = ["torch"]
moe/moe_align_sum_kernels.cu CHANGED
@@ -3,7 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 
 #include <ATen/ATen.h>
-#include <THC/THCAtomics.cuh>
+#include <ATen/cuda/Atomic.cuh>
 
 #include "../cuda_compat.h"
 #include "../dispatch_utils.h"
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
 }
 
 // taken from
-// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
 template <typename scalar_t>
 __global__ void sgl_moe_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[32][8];
-  __shared__ int32_t local_offsets[256];
 
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int warp_id = threadIdx.x / 32;
   const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
 
+  // Initialize shared_counts for this warp's experts
   for (int i = 0; i < experts_per_warp; ++i) {
     if (my_expert_start + i < num_experts) {
       shared_counts[warp_id][i] = 0;
     }
   }
 
+  __syncthreads();
+
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Single thread computes cumulative sum and total tokens
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Assign expert IDs to blocks
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
   }
+}
 
-  __syncthreads();
-
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+// taken from
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
+template <typename scalar_t>
+__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
+                                          int32_t* sorted_token_ids,
+                                          int32_t* cumsum_buffer,
+                                          size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
 }
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  TORCH_CHECK(num_experts == 256,
+              "sgl_moe_align_block_size kernel only supports deepseek v3.");
+
   VLLM_DISPATCH_INTEGRAL_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
+        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-        // torch::Tensor token_cnts_buffer =
-        //     torch::empty({(num_experts + 1) * num_experts}, options_int);
        torch::Tensor cumsum_buffer =
-            torch::empty({num_experts + 1}, options_int);
+            torch::zeros({num_experts + 1}, options_int);
 
-        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
-        kernel<<<1, 1024, 0, stream>>>(
+        auto align_kernel =
+            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
+        align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+        const int block_threads = 256;
+        const int num_blocks =
+            (topk_ids.numel() + block_threads - 1) / block_threads;
+        const int max_blocks = 65535;
+        const int actual_blocks = std::min(num_blocks, max_blocks);
+        auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
+        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
       });
 }
 
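For readers following the split into an alignment kernel and a separate token-sort kernel, here is a minimal PyTorch sketch of the bookkeeping the two kernels now share: per-expert counting, padding to block_size, cumulative offsets and per-block expert ids in the first kernel, and the atomic scatter of token indices in the second. All names are illustrative; this is a CPU reference, not part of the kernel API.

    import torch

    def ref_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                                 block_size: int):
        # CPU reference for the alignment + sort performed by the two CUDA kernels.
        flat = topk_ids.flatten()
        numel = flat.numel()
        counts = torch.bincount(flat, minlength=num_experts)
        # Pad each expert's token count up to a multiple of block_size.
        padded = ((counts + block_size - 1) // block_size) * block_size
        cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
        cumsum[1:] = torch.cumsum(padded, 0)
        total_post_pad = int(cumsum[-1])
        # Padding slots point one past the last real token; the CUDA path gets the
        # same effect because callers pre-fill sorted_token_ids with numel.
        sorted_token_ids = torch.full((total_post_pad,), numel, dtype=torch.int32)
        expert_ids = torch.zeros(total_post_pad // block_size, dtype=torch.int32)
        for e in range(num_experts):
            expert_ids[int(cumsum[e]) // block_size:int(cumsum[e + 1]) // block_size] = e
        # Scatter step, the serial equivalent of sgl_moe_token_sort_kernel's
        # atomicAdd on cumsum_buffer.
        offsets = cumsum[:-1].clone()
        for i in range(numel):
            e = int(flat[i])
            sorted_token_ids[int(offsets[e])] = i
            offsets[e] += 1
        return sorted_token_ids, expert_ids, total_post_pad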
moe/moe_wna16.cu ADDED
@@ -0,0 +1,346 @@

#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "moe_wna16_utils.h"

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
    const scalar_t* __restrict__ input, scalar_t* __restrict__ output,

    const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
    const uint32_t* __restrict__ qzeros,

    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_token_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ num_tokens_post_pad,

    uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
    uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
    uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
    bool mul_topk_weight) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    return;
  } else {
#endif

    using Dtype = ScalarType<scalar_t>;
    using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;

    if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;

    const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
    const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;

    const int32_t expert_id = expert_ids[blockIdx.x];

    int32_t num_valid_tokens = 0;
    extern __shared__ uint16_t block_input_tmp[];
    scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
    scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);

    // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
    for (int m = 0; m < BLOCK_SIZE_M; m++) {
      const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
      const int32_t token_index = sorted_token_ids[offset_m];
      if (token_index / top_k >= size_m) break;

      num_valid_tokens = m + 1;
      if (blockIdx.z == 0 && offset_n < size_n)
        output[token_index * size_n + offset_n] = Dtype::int2num(0);

      if (expert_id != -1) {
        int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
        for (int i = 0; i < k_per_thread; i++) {
          int k = BLOCK_SIZE_N * i + threadIdx.x;
          if (k >= BLOCK_SIZE_K) break;
          if (offset_k + k >= size_k) break;

          // load input to shared memory
          // use a special layout to fit the layout of dequanted-weight
          int origin_k;
          if constexpr (bit == 4) {
            // [0, 4, 1, 5, 2, 6, 3, 7]
            int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
          } else {
            // [0, 2, 1, 3]
            int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
          }

          origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
          block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
        }
      }
    }

    if (expert_id == -1) return;
    __syncthreads();
    if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;

    float res[64];  // assume BLOCK_SIZE_M <= 64
    scalar_t2 res2;
    scalar_t2 scale_f2;
    scalar_t2 qzero_f2;

    // note that (size_n * size_k * expert_id) may greater than 2 ** 31
    constexpr int8_t pack_factor = 32 / bit;
    const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
    const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
    const scalar_t* expert_scales = scales + expert_offset / group_size;
    const uint32_t* expert_qzeros =
        qzeros + expert_offset / group_size / pack_factor;

    // load 4*int32 one time: 4 int32 = 128 bit = 1 float4
    // weight would be loaded in loop
    uint32_t expert_qweight_tmp[4];
    float4* expert_qweight_tmp_float4 =
        reinterpret_cast<float4*>(expert_qweight_tmp);

    // load all required scales one time
    scalar_t expert_scales_groups[GROUPS];
    int scales_offset_tmp =
        (offset_n * size_k + offset_k) / group_size / GROUPS;
    if constexpr (GROUPS == 1) {
      *expert_scales_groups = expert_scales[scales_offset_tmp];
    } else if constexpr (GROUPS == 2) {
      float* expert_scales_groups_tmp =
          reinterpret_cast<float*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 4) {
      float2* expert_scales_groups_tmp =
          reinterpret_cast<float2*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 8) {
      float4* expert_scales_groups_tmp =
          reinterpret_cast<float4*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
    }

    // load all required qzeros one time
    uint8_t expert_qzeros_groups[GROUPS];
    if (!has_zp) {
      if constexpr (bit == 4) {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
      } else {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
      }
    } else {
      int qzeros_offset_tmp =
          (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
          offset_k / group_size / GROUPS;
      if constexpr (GROUPS == 1) {
        uint8_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint8_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 2) {
        uint16_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint16_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 4) {
        uint32_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint32_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 8) {
        uint64_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint64_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
      }
    }

    for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
      int k = offset_k + tmp_k * pack_factor;
      if (k >= size_k) break;
      const int32_t weight_offset = offset_n * size_k + k;

      if (tmp_k % 4 == 0) {
        *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
            expert_qweight)[weight_offset / pack_factor / 4];
      }

      if (tmp_k % (group_size / pack_factor) == 0) {
        scalar_t scale_f =
            expert_scales_groups[tmp_k / (group_size / pack_factor)];
        scale_f2 = Dtype::num2num2(scale_f);

        if (has_zp) {
          uint8_t qzero =
              expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
          if constexpr (bit == 4) {
            qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
          }
          qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
        }
      }

      scalar_t2 weight_half2[16 / bit];
      dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);

      for (int m = 0; m < num_valid_tokens; m++) {
        res2 = {};

#pragma unroll
        for (int i = 0; i < 16 / bit; i++) {
          int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
          res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
                         block_input_half2[offset_input], res2);
        }

        if (tmp_k == 0) {
          res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        } else {
          res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        }
      }
    }

    for (int m = 0; m < num_valid_tokens; ++m) {
      const int32_t token_index =
          sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
      if (mul_topk_weight) {
        res[m] *= topk_weights[token_index];
      }
      atomicAdd(&output[token_index * size_n + offset_n],
                Dtype::float2num(res[m]));
    }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  }
#endif
}

template <typename scalar_t>
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
                        const uint32_t* b_qweight, const scalar_t* b_scales,
                        const uint32_t* b_qzeros, const float* topk_weights,
                        const int32_t* sorted_token_ids,
                        const int32_t* expert_ids,
                        const int32_t* num_tokens_post_pad, int num_experts,
                        int group_size, int num_token_blocks, int top_k,
                        int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
                        int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
                        bool has_zp, bool mul_topk_weight) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_SIZE_N;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = num_token_blocks;
  gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
  gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);

  auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
  if (bit == 4) {
    if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
    }
  } else {
    if (BLOCK_SIZE_K / group_size == 1) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
    } else if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
    }
  }

  const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
      input, output, b_qweight, b_scales, b_qzeros, topk_weights,
      sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
      group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
      BLOCK_SIZE_K, has_zp, mul_topk_weight);
}

torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto options =
      torch::TensorOptions().dtype(input.dtype()).device(input.device());

  const int num_experts = b_qweight.size(0);
  const int size_m = input.size(0);
  const int size_n = b_qweight.size(1);
  const int size_k = input.size(1);
  const int group_size = size_k / b_scales.size(2);

  int64_t EM = sorted_token_ids.size(0);
  if (size_m <= BLOCK_SIZE_M) {
    EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
  }
  const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;

  const uint32_t* b_qzeros_ptr;
  if (b_qzeros.has_value())
    b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
  const float* topk_weights_ptr;
  if (topk_weights.has_value())
    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();

  int groups_per_block_row = BLOCK_SIZE_K / group_size;
  TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
  TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
              "size_k must divisible by BLOCK_SIZE_K");
  TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
              "BLOCK_SIZE_K must divisible by group_size");
  TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64");
  TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
                  groups_per_block_row == 4 || groups_per_block_row == 8,
              "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");

  if (input.scalar_type() == at::ScalarType::Half) {
    run_moe_wna16_gemm<half>(
        (const half*)input.data_ptr<at::Half>(),
        (half*)output.data_ptr<at::Half>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    run_moe_wna16_gemm<nv_bfloat16>(
        (const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
        (nv_bfloat16*)output.data_ptr<at::BFloat16>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else {
    TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
  }
  return output;
}
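The WNA16 kernel dequantizes each weight as (q - zero_point) * scale with one scale (and optional packed zero point) per group_size elements along k, falling back to the symmetric midpoint (8 for 4-bit, 128 for 8-bit) when no zero points are supplied. A minimal PyTorch sketch of that convention, operating on already-unpacked integer values (names illustrative, not the kernel's API):

    from typing import Optional

    import torch

    def ref_wna16_dequant(q: torch.Tensor, scales: torch.Tensor,
                          qzeros: Optional[torch.Tensor], bit: int,
                          group_size: int) -> torch.Tensor:
        # q: unpacked integer weights for one output column, shape (size_k,)
        # scales/qzeros: one entry per group of `group_size` elements along k.
        group = torch.arange(q.numel(), device=q.device) // group_size
        if qzeros is None:
            zp = torch.full_like(q, 8 if bit == 4 else 128)
        else:
            zp = qzeros[group].to(q.dtype)
        return (q - zp).to(scales.dtype) * scales[group]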
moe/moe_wna16_utils.h ADDED
@@ -0,0 +1,200 @@

#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }

  static __host__ __device__ half inline int2num(const float x) {
    return __int2half_rn(x);
  }

  static __host__ __device__ float2 inline num22float2(const half2 x) {
    return __half22float2(x);
  }

  static __host__ __device__ half2 inline float22num2(const float2 x) {
    return __float22half2_rn(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }

  static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
    return __int2bfloat16_rn(x);
  }

  static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
    return __bfloat1622float2(x);
  }

  static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
    return __float22bfloat162_rn(x);
  }
#endif
};

template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

template <typename scalar_t2, int bit>
__device__ inline void dequant(int q, scalar_t2* res) {}

template <>
__device__ inline void dequant<half2, 4>(int q, half2* res) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;

  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
  q >>= 8;
  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);

  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
                   *reinterpret_cast<const half2*>(&SUB));
  res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
                   *reinterpret_cast<const half2*>(&MUL),
                   *reinterpret_cast<const half2*>(&ADD));
  res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
                   *reinterpret_cast<const half2*>(&SUB));
  res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
                   *reinterpret_cast<const half2*>(&MUL),
                   *reinterpret_cast<const half2*>(&ADD));
}

template <>
__device__ inline void dequant<half2, 8>(int q, half2* res) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;

  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);

  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC300C300;

  res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
}

template <>
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388608.f;
  fp32_intermediates[1] -= 8388608.f;
  fp32_intermediates[2] -= 8388608.f;
  fp32_intermediates[3] -= 8388608.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}
#endif
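The 8-bit fp16 path above relies on a bit trick: prmt builds each half as 0x64 in the high byte and the quantized byte v in the low byte, and 0x6400 is the fp16 encoding of 1024.0, so the assembled value is exactly 1024.0 + v and subtracting the magic number recovers v without any int-to-float conversion instruction. A small NumPy check (not part of the diff) of that identity:

    import numpy as np

    # 0x64 | v in the mantissa of an fp16 with exponent field 0x64xx encodes
    # exactly 1024.0 + v for any byte v, so (bits - 0x6400 as fp16) == v.
    for v in (0, 1, 127, 255):
        bits = np.array([0x6400 | v], dtype=np.uint16)
        assert float(bits.view(np.float16)[0]) - 1024.0 == float(v)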
tests/kernels/test_moe.py CHANGED
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 """Tests for the MOE layers.
 
 Run `pytest tests/kernels/test_moe.py`.
@@ -18,12 +19,15 @@ from moe.utils.marlin_utils_test import marlin_quantize, quantize_weights
 from .utils import compute_max_diff, opcheck, torch_moe
 
 
+from torch.nn import Parameter
+from torch.nn import functional as F
+
 def stack_and_dev(tensors: List[torch.Tensor]):
     dev = tensors[0].device
     return torch.stack(tensors, dim=0).to(dev)
 
-
 NUM_EXPERTS = [8, 64]
+EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 
 
@@ -32,25 +36,54 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
+
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
+
+    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
-    # iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
-    # torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, rtol=0)
 
 
 @pytest.mark.parametrize("m", [1, 32, 222])
@@ -58,21 +91,14 @@ def test_fused_moe(
 @pytest.mark.parametrize("k", [128, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
-def test_fused_moe_wn16(
-    m: int,
-    n: int,
-    k: int,
-    e: int,
-    topk: int,
-    dtype: torch.dtype,
-    group_size: int,
-    has_zp: bool,
-    weight_bits: int,
-):
+def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
+                        ep_size: int, dtype: torch.dtype, group_size: int,
+                        has_zp: bool, weight_bits: int):
     print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -88,40 +114,35 @@ def test_fused_moe_wn16(
 
     w1_ref = w1.clone()
     w2_ref = w2.clone()
-    w1_qweight = torch.empty(
-        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
-    )
-    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
-    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
-    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
-    w1_qzeros = torch.empty(
-        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
-    )
-    w2_qzeros = torch.empty(
-        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
-    )
+    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w2_qweight = torch.empty((e, k, n // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w1_scales = torch.empty((e, 2 * n, k // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w2_scales = torch.empty((e, k, n // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
+    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
 
     for i in range(e * 2):
         expert_id = i % e
         if i // e == 0:
-            w, w_ref, w_qweight, w_scales, w_qzeros = (
-                w1,
-                w1_ref,
-                w1_qweight,
-                w1_scales,
-                w1_qzeros,
-            )
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
         else:
-            w, w_ref, w_qweight, w_scales, w_qzeros = (
-                w2,
-                w2_ref,
-                w2_qweight,
-                w2_scales,
-                w2_qzeros,
-            )
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
         weight, qweight, scales, qzeros = quantize_weights(
-            w[expert_id].T, quant_type, group_size, has_zp, False
-        )
+            w[expert_id].T, quant_type, group_size, has_zp, False)
         weight = weight.T
         qweight = qweight.T.contiguous().to(torch.uint8)
         scales = scales.T
@@ -138,25 +159,45 @@ def test_fused_moe_wn16(
     if has_zp:
         w_qzeros[expert_id] = qzeros
 
-    triton_output = fused_moe(
-        a,
-        w1_qweight,
-        w2_qweight,
-        score,
-        topk,
-        renormalize=False,
-        use_int4_w4a16=weight_bits == 4,
-        use_int8_w8a16=weight_bits == 8,
-        w1_scale=w1_scales,
-        w2_scale=w2_scales,
-        w1_zp=w1_qzeros if has_zp else None,
-        w2_zp=w2_qzeros if has_zp else None,
-        block_shape=[0, group_size],
-    )
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1_ref = w1_ref[e_ids]
+        w2_ref = w2_ref[e_ids]
+        w1_qweight = w1_qweight[e_ids]
+        w2_qweight = w2_qweight[e_ids]
+        w1_scales = w1_scales[e_ids]
+        w2_scales = w2_scales[e_ids]
+        w1_qzeros = w1_qzeros[e_ids]
+        w2_qzeros = w2_qzeros[e_ids]
+    else:
+        e_map = None
+
+    triton_output = fused_moe(a,
+                              w1_qweight,
+                              w2_qweight,
+                              score,
+                              topk,
+                              renormalize=False,
+                              use_int4_w4a16=weight_bits == 4,
+                              use_int8_w8a16=weight_bits == 8,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              w1_scale=w1_scales,
+                              w2_scale=w2_scales,
+                              w1_zp=w1_qzeros if has_zp else None,
+                              w2_zp=w2_qzeros if has_zp else None,
+                              block_shape=[0, group_size])
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
+
 @pytest.mark.parametrize("m", [1, 33, 64, 222])
 @pytest.mark.parametrize("n", [128, 2048])
 @pytest.mark.parametrize("k", [128, 1024])
@@ -178,7 +219,7 @@ def test_fused_marlin_moe(
     num_bits: int,
     is_k_full: bool,
 ):
-    torch.manual_seed(7)
+    current_platform.seed_everything(7)
 
     # Filter act_order
     if act_order:
@@ -190,7 +231,8 @@ def test_fused_marlin_moe(
     if not is_k_full:
         return
 
-    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -205,8 +247,8 @@ def test_fused_marlin_moe(
     for i in range(w1.shape[0]):
         test_perm = torch.randperm(k)
         w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
-            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
-        )
+            w1[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
         w_ref1_l.append(w_ref1)
         qweight1_l.append(qweight1)
         scales1_l.append(scales1)
@@ -228,8 +270,8 @@ def test_fused_marlin_moe(
     for i in range(w2.shape[0]):
         test_perm = torch.randperm(n)
         w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
-            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
-        )
+            w2[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
         w_ref2_l.append(w_ref2)
         qweight2_l.append(qweight2)
         scales2_l.append(scales2)
@@ -273,79 +315,141 @@ def test_fused_marlin_moe(
 
     assert compute_max_diff(marlin_output, triton_output) < 4e-2
 
-    token_expert_indicies = torch.empty(m, topk, dtype=torch.int32, device=a.device)
+    token_expert_indicies = torch.empty(m,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=a.device)
 
-    opcheck(
-        ops.topk_softmax,
-        (
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            score.float(),
-        ),
-    )
+    opcheck(ops.topk_softmax, (
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        score.float(),
+    ))
 
     block_size_m = 4
 
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, e)
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
+                                                  e)
 
     max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
-    workspace = torch.zeros(
-        max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
-    )
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device="cuda",
+                            requires_grad=False)
 
-    zp = torch.empty((0, 0), dtype=dtype, device="cuda", requires_grad=False)
-    opcheck(
-        ops.marlin_gemm_moe,
-        (
-            a,
-            qweight1,
-            sorted_token_ids,
-            topk_weights,
-            topk_ids,
-            scales1,
-            zp,
-            g_idx1,
-            sort_indices1,
-            workspace,
-            quant_type.id,
-            m,
-            2 * n,
-            k,
-            True,
-            e,
-            topk,
-            block_size_m,
-            True,
-            False,
-        ),
-    )
+    zp = torch.empty((0, 0),
+                     dtype=dtype,
+                     device="cuda",
+                     requires_grad=False)
+    opcheck(ops.marlin_gemm_moe,
+            (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
+             scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
+             m, 2 * n, k, True, e, topk, block_size_m, True, False))
+
+
+@pytest.mark.skip("This test is here for the sake of debugging, "
+                  "don't run it in automated tests.")
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("num_bits", [4, 8])
+@pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_single_marlin_moe_multiply(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+    num_bits: int,
+    is_k_full: bool,
+):
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size == k:
+            return
+    else:
+        if not is_k_full:
+            return
+
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
+
+    w_ref_l = []
+    qweights_l = []
+    scales_l = []
+    g_idx_l = []
+    sort_indices_l = []
+
+    for i in range(w.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
+            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
+        w_ref_l.append(w_ref)
+        qweights_l.append(qweight)
+        scales_l.append(scales)
+        g_idx_l.append(g_idx)
+        sort_indices_l.append(sort_indices)
+
+    w_ref = stack_and_dev(w_ref_l)
+    qweight = stack_and_dev(qweights_l).contiguous()
+    scales = stack_and_dev(scales_l)
+    g_idx = stack_and_dev(g_idx_l)
+    sort_indices = stack_and_dev(sort_indices_l)
 
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    marlin_output = ops.single_marlin_moe(
+        a,
+        qweight,
+        scales,
+        score,
+        topk,
+        renormalize=False,
+        g_idx=g_idx,
+        sort_indices=sort_indices,
+        num_bits=num_bits,
+        is_k_full=is_k_full,
+    )
+
+    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
+
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2
+
 
 def test_moe_align_block_size_opcheck():
     num_experts = 4
     block_size = 4
-    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
+    topk_ids = torch.randint(0,
+                             num_experts, (3, 4),
+                             dtype=torch.int32,
+                             device='cuda')
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-    opcheck(
-        ops.moe_align_block_size,
-        (
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        ),
-    )
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    num_tokens_post_pad = torch.empty((1),
+                                      dtype=torch.int32,
+                                      device=topk_ids.device)
+
+    opcheck(ops.moe_align_block_size,
+            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
+             num_tokens_post_pad))
tests/kernels/utils.py CHANGED
@@ -38,7 +38,7 @@ class SiluAndMul(nn.Module):
         return F.silu(x[..., :d]) * x[..., d:]
 
 
-def torch_moe(a, w1, w2, score, topk):
+def torch_moe(a, w1, w2, score, topk, expert_map):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -46,15 +46,15 @@ def torch_moe(a, w1, w2, score, topk):
     topk_weight, topk_ids = torch.topk(score, topk)
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
-                0, 1
-            )
-    return (
-        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
-    ).sum(dim=1)
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+    return (out.view(B, -1, w2.shape[1]) *
+            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
 # Copied/modified from torch._refs.__init__.py
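The new expert_map argument (and the ep_size parametrization in the tests above) models expert parallelism: a rank only holds e // ep_size experts, and expert_map translates global expert ids into local slots, with -1 marking experts owned by another rank so their tokens are skipped locally. A small sketch of the mapping with illustrative values:

    import torch

    e, ep_size = 8, 4                       # 8 global experts, 4-way expert parallelism
    local_e = e // ep_size                  # this rank owns 2 experts
    e_ids = torch.tensor([5, 2])            # global ids of the experts held locally
    e_map = torch.full((e,), -1, dtype=torch.int32)
    e_map[e_ids] = torch.arange(local_e, dtype=torch.int32)
    # e_map == [-1, -1, 1, -1, -1, 0, -1, -1]: routing to global expert 5 hits
    # local slot 0, expert 2 hits local slot 1, everything else (-1) is dropped
    # on this rank.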
torch-ext/moe/__init__.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 
+from . import layers
 from ._ops import ops
 from .fp8_utils import per_token_group_quant_fp8, w8a8_block_fp8_matmul
 from .fused_marlin_moe import fused_marlin_moe
@@ -51,24 +52,6 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor):
     ops.moe_sum(input, output)
 
 
-def moe_align_block_size(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    experts_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_token_ids,
-        experts_ids,
-        num_tokens_post_pad,
-    )
-
-
 def topk_softmax(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
@@ -87,6 +70,7 @@ __all__ = [
     "fused_topk",
     "gptq_marlin_moe_repack",
     "grouped_topk",
+    "layers",
     "moe_align_block_size",
     "moe_sum",
     "per_token_group_quant_fp8",
torch-ext/moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 5
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 2
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 5
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 16,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 4
145
+ }
146
+ }
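
The tuned Triton configurations added below are stored one file per (E, N, device, dtype, block_shape) combination, with the parameters encoded directly in the filename. A hedged sketch of how such a filename could be assembled at runtime; the package's own lookup helper may differ in detail, and `config_file_name` is a hypothetical name:

import torch
from typing import Optional


def config_file_name(E: int, N: int, dtype: Optional[str] = None,
                     block_shape: Optional[list] = None) -> str:
    # Mirrors the naming scheme of the JSON files in torch-ext/moe/configs
    # (illustrative only).
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_part = f",dtype={dtype}" if dtype is not None else ""
    block_part = (f",block_shape={block_shape}".replace(" ", "")
                  if block_shape is not None else "")
    return f"E={E},N={N},device_name={device_name}{dtype_part}{block_part}.json"


# e.g. "E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json" on an A800 system
print(config_file_name(160, 192))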
torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 2
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 128,
16
+ "BLOCK_SIZE_K": 64,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 8,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 128,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 128,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 2
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 64,
60
+ "BLOCK_SIZE_K": 128,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 4,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 256,
72
+ "GROUP_SIZE_M": 16,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 256,
83
+ "GROUP_SIZE_M": 16,
84
+ "num_warps": 4,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 8,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 1
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 32,
106
+ "num_warps": 1,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 32,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 64,
125
+ "BLOCK_SIZE_N": 64,
126
+ "BLOCK_SIZE_K": 128,
127
+ "GROUP_SIZE_M": 32,
128
+ "num_warps": 8,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 1
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 8,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 1
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 64,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 8,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 64,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 128,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 16,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
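
Inside each config file, the top-level keys are token counts (M) and the values are the Triton tile parameters tuned for that batch size. A small sketch of a nearest-key lookup, which is how such tables are typically consumed; the kernel's actual selection rule may differ, and `pick_config` is a hypothetical helper:

import json


def pick_config(path: str, M: int) -> dict:
    # Pick the tuning entry whose batch-size key is closest to M
    # (illustrative only).
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    best_key = min(configs, key=lambda k: abs(k - M))
    return configs[best_key]


# Hypothetical usage with one of the files added in this commit:
cfg = pick_config(
    "torch-ext/moe/configs/"
    "E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json",
    M=48,
)
print(cfg["BLOCK_SIZE_M"], cfg["num_warps"])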
torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 2
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 128,
16
+ "BLOCK_SIZE_K": 64,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 8,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 128,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 128,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 2
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 64,
60
+ "BLOCK_SIZE_K": 128,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 4,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 256,
72
+ "GROUP_SIZE_M": 16,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 256,
83
+ "GROUP_SIZE_M": 16,
84
+ "num_warps": 4,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 8,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 1
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 32,
106
+ "num_warps": 1,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 32,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 64,
125
+ "BLOCK_SIZE_N": 64,
126
+ "BLOCK_SIZE_K": 128,
127
+ "GROUP_SIZE_M": 32,
128
+ "num_warps": 8,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 1
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 8,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 1
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 64,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 8,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 64,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 128,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 16,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 4
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 64,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 64,
79
+ "num_warps": 4,
80
+ "num_stages": 4
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 5
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 32,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 32,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 32,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 2
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 2
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 32,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 64,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 32,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 32,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 2
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 2
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 32,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 64,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 32,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 64,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 64,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 16,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 64,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 64,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 32,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 64,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 16,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 32,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 8,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0
10
+ },
11
+ "2": {
12
+ "BLOCK_SIZE_M": 16,
13
+ "BLOCK_SIZE_N": 128,
14
+ "BLOCK_SIZE_K": 256,
15
+ "GROUP_SIZE_M": 1,
16
+ "num_warps": 8,
17
+ "num_stages": 2,
18
+ "waves_per_eu": 0
19
+ },
20
+ "4": {
21
+ "BLOCK_SIZE_M": 16,
22
+ "BLOCK_SIZE_N": 128,
23
+ "BLOCK_SIZE_K": 256,
24
+ "GROUP_SIZE_M": 1,
25
+ "num_warps": 8,
26
+ "num_stages": 2,
27
+ "waves_per_eu": 0
28
+ },
29
+ "8": {
30
+ "BLOCK_SIZE_M": 16,
31
+ "BLOCK_SIZE_N": 128,
32
+ "BLOCK_SIZE_K": 128,
33
+ "GROUP_SIZE_M": 1,
34
+ "num_warps": 8,
35
+ "num_stages": 2,
36
+ "waves_per_eu": 0
37
+ },
38
+ "16": {
39
+ "BLOCK_SIZE_M": 16,
40
+ "BLOCK_SIZE_N": 128,
41
+ "BLOCK_SIZE_K": 128,
42
+ "GROUP_SIZE_M": 1,
43
+ "num_warps": 2,
44
+ "num_stages": 2,
45
+ "waves_per_eu": 0
46
+ },
47
+ "24": {
48
+ "BLOCK_SIZE_M": 16,
49
+ "BLOCK_SIZE_N": 128,
50
+ "BLOCK_SIZE_K": 128,
51
+ "GROUP_SIZE_M": 1,
52
+ "num_warps": 2,
53
+ "num_stages": 2,
54
+ "waves_per_eu": 0
55
+ },
56
+ "32": {
57
+ "BLOCK_SIZE_M": 16,
58
+ "BLOCK_SIZE_N": 128,
59
+ "BLOCK_SIZE_K": 128,
60
+ "GROUP_SIZE_M": 4,
61
+ "num_warps": 2,
62
+ "num_stages": 2,
63
+ "waves_per_eu": 0
64
+ },
65
+ "48": {
66
+ "BLOCK_SIZE_M": 16,
67
+ "BLOCK_SIZE_N": 128,
68
+ "BLOCK_SIZE_K": 128,
69
+ "GROUP_SIZE_M": 4,
70
+ "num_warps": 2,
71
+ "num_stages": 2,
72
+ "waves_per_eu": 0
73
+ },
74
+ "64": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 2,
80
+ "num_stages": 2,
81
+ "waves_per_eu": 0
82
+ },
83
+ "96": {
84
+ "BLOCK_SIZE_M": 16,
85
+ "BLOCK_SIZE_N": 128,
86
+ "BLOCK_SIZE_K": 128,
87
+ "GROUP_SIZE_M": 8,
88
+ "num_warps": 8,
89
+ "num_stages": 2,
90
+ "waves_per_eu": 0
91
+ },
92
+ "128": {
93
+ "BLOCK_SIZE_M": 16,
94
+ "BLOCK_SIZE_N": 128,
95
+ "BLOCK_SIZE_K": 128,
96
+ "GROUP_SIZE_M": 4,
97
+ "num_warps": 4,
98
+ "num_stages": 2,
99
+ "waves_per_eu": 0
100
+ },
101
+ "256": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 128,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 8,
106
+ "num_warps": 4,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0
109
+ },
110
+ "512": {
111
+ "BLOCK_SIZE_M": 32,
112
+ "BLOCK_SIZE_N": 128,
113
+ "BLOCK_SIZE_K": 128,
114
+ "GROUP_SIZE_M": 8,
115
+ "num_warps": 4,
116
+ "num_stages": 2,
117
+ "waves_per_eu": 0
118
+ },
119
+ "1024": {
120
+ "BLOCK_SIZE_M": 64,
121
+ "BLOCK_SIZE_N": 128,
122
+ "BLOCK_SIZE_K": 128,
123
+ "GROUP_SIZE_M": 8,
124
+ "num_warps": 2,
125
+ "num_stages": 2,
126
+ "waves_per_eu": 0
127
+ },
128
+ "1536": {
129
+ "BLOCK_SIZE_M": 64,
130
+ "BLOCK_SIZE_N": 128,
131
+ "BLOCK_SIZE_K": 128,
132
+ "GROUP_SIZE_M": 4,
133
+ "num_warps": 2,
134
+ "num_stages": 2,
135
+ "waves_per_eu": 0
136
+ },
137
+ "2048": {
138
+ "BLOCK_SIZE_M": 128,
139
+ "BLOCK_SIZE_N": 256,
140
+ "BLOCK_SIZE_K": 128,
141
+ "GROUP_SIZE_M": 8,
142
+ "num_warps": 4,
143
+ "num_stages": 2,
144
+ "waves_per_eu": 0
145
+ },
146
+ "3072": {
147
+ "BLOCK_SIZE_M": 128,
148
+ "BLOCK_SIZE_N": 256,
149
+ "BLOCK_SIZE_K": 128,
150
+ "GROUP_SIZE_M": 8,
151
+ "num_warps": 4,
152
+ "num_stages": 2,
153
+ "waves_per_eu": 0
154
+ },
155
+ "4096": {
156
+ "BLOCK_SIZE_M": 128,
157
+ "BLOCK_SIZE_N": 256,
158
+ "BLOCK_SIZE_K": 128,
159
+ "GROUP_SIZE_M": 4,
160
+ "num_warps": 4,
161
+ "num_stages": 2,
162
+ "waves_per_eu": 0
163
+ }
164
+ }
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 64,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 64,
26
+ "BLOCK_SIZE_N": 64,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 1
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 1
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 128,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 32,
59
+ "BLOCK_SIZE_N": 128,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 8,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 128,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 128,
83
+ "GROUP_SIZE_M": 1,
84
+ "num_warps": 2,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 2,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 32,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 2,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 32,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 16,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 128,
127
+ "GROUP_SIZE_M": 1,
128
+ "num_warps": 2,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 32,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 8,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 8,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 64,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 128,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 8,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 64,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 64,
26
+ "BLOCK_SIZE_N": 64,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 1
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 1
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 128,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 32,
59
+ "BLOCK_SIZE_N": 128,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 8,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 128,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 128,
83
+ "GROUP_SIZE_M": 1,
84
+ "num_warps": 2,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 2,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 32,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 2,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 32,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 16,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 128,
127
+ "GROUP_SIZE_M": 1,
128
+ "num_warps": 2,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 32,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 8,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 8,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 64,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 128,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 8,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 256,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 8,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 256,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 8,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 5
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 256,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 256,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 256,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 32,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 4
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 4
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 16,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 16,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 16,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 32,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 64,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 64,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 32,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 16,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 64,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 16,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 64,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 5
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 256,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 32,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 32,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 32,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 32,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
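Each of these tuned tables maps a token-count bucket (the top-level keys, 1 through 4096) to a set of Triton launch parameters. As a minimal sketch of how such a table can be consumed (the helper names and the nearest-bucket rule below are assumptions for illustration, not the kernel's own loader), the JSON is read, its keys are cast to integers, and the bucket closest to the actual number of tokens is selected:

```python
import json

def load_moe_config(path: str) -> dict[int, dict]:
    """Read a tuned-config JSON and key it by integer token count."""
    with open(path) as f:
        raw = json.load(f)
    return {int(m): cfg for m, cfg in raw.items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    """Pick the bucket whose key is closest to the actual token count."""
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Example: 300 tokens fall into the "256" bucket of the file above,
# i.e. BLOCK_SIZE_M=64, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128, GROUP_SIZE_M=32.
cfgs = load_moe_config(
    "torch-ext/moe/configs/"
    "E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json"
)
print(pick_config(cfgs, 300))
```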
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 8,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 32,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 8,
24
+ "num_stages": 2
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 8,
32
+ "num_stages": 2
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 256,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 8,
40
+ "num_stages": 2
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 8,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 8,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 16,
63
+ "num_warps": 8,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 256,
70
+ "GROUP_SIZE_M": 64,
71
+ "num_warps": 8,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 256,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 8,
80
+ "num_stages": 2
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 256,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 8,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 256,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 8,
96
+ "num_stages": 2
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 256,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 8,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 2
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 256,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 2
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 2
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 2
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 2
145
+ }
146
+ }
torch-ext/moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 64,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 64,
26
+ "BLOCK_SIZE_N": 64,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 1
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 1
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 128,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 32,
59
+ "BLOCK_SIZE_N": 128,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 8,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 128,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 128,
83
+ "GROUP_SIZE_M": 1,
84
+ "num_warps": 2,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 2,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 32,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 2,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 32,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 16,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 128,
127
+ "GROUP_SIZE_M": 1,
128
+ "num_warps": 2,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 32,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 8,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 8,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 64,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 128,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 8,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
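The ROCm-tuned tables such as the one above carry three launch fields beyond the common block sizes: waves_per_eu, matrix_instr_nonkdim, and kpack; the NVIDIA tables omit them. All of the files follow the same naming scheme. The hypothetical helper below (not part of the kernel) simply rebuilds a filename from the expert count E, the intermediate size N, the device name, and the optional dtype and block_shape suffixes:

```python
from typing import Optional

def moe_config_filename(
    E: int,
    N: int,
    device_name: str,
    dtype: Optional[str] = None,
    block_shape: Optional[list] = None,
) -> str:
    """Rebuild a config filename following the naming pattern of this directory."""
    name = f"E={E},N={N},device_name={device_name}"
    if dtype is not None:
        name += f",dtype={dtype}"
    if block_shape is not None:
        name += f",block_shape=[{block_shape[0]},{block_shape[1]}]"
    return name + ".json"

# Reproduces the file added above:
print(moe_config_filename(256, 512, "AMD_Instinct_MI325_OAM",
                          dtype="fp8_w8a8", block_shape=[128, 128]))
```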
torch-ext/moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 64,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 64,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 64,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 64,
111
+ "num_warps": 4,
112
+ "num_stages": 2
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 4,
120
+ "num_stages": 2
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 32,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 4,
128
+ "num_stages": 2
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 4,
144
+ "num_stages": 2
145
+ }
146
+ }
torch-ext/moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 128,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 8,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 1
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 32,
27
+ "BLOCK_SIZE_K": 64,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 2,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 1
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 32,
37
+ "BLOCK_SIZE_N": 256,
38
+ "BLOCK_SIZE_K": 64,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 8,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 32,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 32,
48
+ "BLOCK_SIZE_N": 128,
49
+ "BLOCK_SIZE_K": 64,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 2,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 32,
55
+ "kpack": 2
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 64,
60
+ "BLOCK_SIZE_K": 256,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 4,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 256,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 1
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 256,
83
+ "GROUP_SIZE_M": 4,
84
+ "num_warps": 4,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 256,
94
+ "GROUP_SIZE_M": 4,
95
+ "num_warps": 4,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 256,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 4,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 256,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 4,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 32,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 256,
127
+ "GROUP_SIZE_M": 4,
128
+ "num_warps": 4,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 32,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 1
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 128,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 64,
149
+ "GROUP_SIZE_M": 32,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 128,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 64,
160
+ "GROUP_SIZE_M": 32,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 4,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 4,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
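For orientation, GROUP_SIZE_M in these entries is the usual Triton grouped-ordering parameter: consecutive program ids are grouped along the M dimension so that tiles reusing the same weight columns stay close together for cache reuse. The sketch below reproduces the standard index math in plain Python (an illustration of the scheme, not code taken from this kernel):

```python
def grouped_pid(pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int):
    """Map a flat program id to a (pid_m, pid_n) tile using grouped ordering."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may be narrower than GROUP_SIZE_M.
    cur_group_size = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % cur_group_size
    pid_n = (pid % num_pid_in_group) // cur_group_size
    return pid_m, pid_n

# With GROUP_SIZE_M=2, the first four programs cover tiles
# (0,0), (1,0), (0,1), (1,1) instead of walking a whole row first.
print([grouped_pid(p, num_pid_m=4, num_pid_n=4, group_size_m=2) for p in range(4)])
```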
torch-ext/moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 2,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 2
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 16,
16
+ "BLOCK_SIZE_K": 256,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 2
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 16,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 4,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 16,
49
+ "BLOCK_SIZE_K": 64,
50
+ "GROUP_SIZE_M": 4,
51
+ "num_warps": 2,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 16,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 4,
62
+ "num_warps": 2,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 32,
71
+ "BLOCK_SIZE_K": 64,
72
+ "GROUP_SIZE_M": 8,
73
+ "num_warps": 2,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 32,
82
+ "BLOCK_SIZE_K": 64,
83
+ "GROUP_SIZE_M": 8,
84
+ "num_warps": 2,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 1
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 32,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 8,
95
+ "num_warps": 2,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 1
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 16,
104
+ "BLOCK_SIZE_K": 64,
105
+ "GROUP_SIZE_M": 16,
106
+ "num_warps": 1,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 1
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 16,
115
+ "BLOCK_SIZE_K": 64,
116
+ "GROUP_SIZE_M": 16,
117
+ "num_warps": 1,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 1
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 32,
125
+ "BLOCK_SIZE_N": 64,
126
+ "BLOCK_SIZE_K": 64,
127
+ "GROUP_SIZE_M": 8,
128
+ "num_warps": 4,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 1
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 128,
137
+ "BLOCK_SIZE_K": 64,
138
+ "GROUP_SIZE_M": 32,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 1
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 128,
147
+ "BLOCK_SIZE_N": 64,
148
+ "BLOCK_SIZE_K": 64,
149
+ "GROUP_SIZE_M": 16,
150
+ "num_warps": 4,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 128,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 64,
160
+ "GROUP_SIZE_M": 8,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 64,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 8,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 1
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 256,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 4,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 2
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 32,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 2
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 64,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 4,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 32,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 64,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 4,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 64,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 2,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 64,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 2,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 32,
82
+ "BLOCK_SIZE_K": 64,
83
+ "GROUP_SIZE_M": 8,
84
+ "num_warps": 1,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 1
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 32,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 8,
95
+ "num_warps": 2,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 1
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 32,
104
+ "BLOCK_SIZE_K": 64,
105
+ "GROUP_SIZE_M": 16,
106
+ "num_warps": 2,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 32,
115
+ "BLOCK_SIZE_K": 64,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 2,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 32,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 64,
127
+ "GROUP_SIZE_M": 8,
128
+ "num_warps": 4,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 1
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 64,
138
+ "GROUP_SIZE_M": 32,
139
+ "num_warps": 4,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 128,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 64,
149
+ "GROUP_SIZE_M": 4,
150
+ "num_warps": 4,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 128,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 64,
160
+ "GROUP_SIZE_M": 16,
161
+ "num_warps": 4,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 1
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 256,
169
+ "BLOCK_SIZE_N": 256,
170
+ "BLOCK_SIZE_K": 32,
171
+ "GROUP_SIZE_M": 16,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 1
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 256,
180
+ "BLOCK_SIZE_N": 256,
181
+ "BLOCK_SIZE_K": 32,
182
+ "GROUP_SIZE_M": 16,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 1
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 4,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 64,
16
+ "BLOCK_SIZE_K": 128,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 2
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 128,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 8,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 32,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 64,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 4,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 1
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 16,
49
+ "BLOCK_SIZE_K": 64,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 2,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 1
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 64,
60
+ "BLOCK_SIZE_K": 256,
61
+ "GROUP_SIZE_M": 4,
62
+ "num_warps": 4,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 256,
72
+ "GROUP_SIZE_M": 1,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 256,
83
+ "GROUP_SIZE_M": 4,
84
+ "num_warps": 4,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 2
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 256,
94
+ "GROUP_SIZE_M": 4,
95
+ "num_warps": 4,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 16,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 256,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 4,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 2
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 16,
114
+ "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 256,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 4,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 16,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 32,
125
+ "BLOCK_SIZE_N": 32,
126
+ "BLOCK_SIZE_K": 256,
127
+ "GROUP_SIZE_M": 4,
128
+ "num_warps": 4,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 64,
136
+ "BLOCK_SIZE_N": 64,
137
+ "BLOCK_SIZE_K": 128,
138
+ "GROUP_SIZE_M": 4,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 64,
147
+ "BLOCK_SIZE_N": 64,
148
+ "BLOCK_SIZE_K": 128,
149
+ "GROUP_SIZE_M": 1,
150
+ "num_warps": 4,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 128,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 64,
160
+ "GROUP_SIZE_M": 32,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 1,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 8,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 8,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 5
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 8,
16
+ "num_stages": 5
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 32,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 256,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 256,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 8,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 8,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 8,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 16,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 4
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 256,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 5
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 4
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 4
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 4
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 4
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 2
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 32,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 2
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 32,
47
+ "num_warps": 4,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 4
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 8,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 8,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 256,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 8,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 8,
48
+ "num_stages": 4
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 4
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 8,
64
+ "num_stages": 4
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 256,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 4,
16
+ "num_stages": 5
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 256,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 256,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 5
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 4
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 256,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 5
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 256,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 5
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 2
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 32,
15
+ "num_warps": 4,
16
+ "num_stages": 5
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 5
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 5
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 8,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 8,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 64,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 32,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 64,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 2
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 32,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 32,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 64,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 64,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 32,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 16,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 64,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 1,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 32,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 5
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 32,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 8,
32
+ "num_stages": 5
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 2
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 5
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 32,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 32,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 256,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 8,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 32,
119
+ "num_warps": 8,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 32,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 32,
15
+ "num_warps": 4,
16
+ "num_stages": 2
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 256,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 32,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 8,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 32,
37
+ "BLOCK_SIZE_K": 256,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 5
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 2
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 32,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 8,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 64,
63
+ "num_warps": 8,
64
+ "num_stages": 2
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 32,
68
+ "BLOCK_SIZE_N": 32,
69
+ "BLOCK_SIZE_K": 256,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 2
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 32,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 16,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 256,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 8,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 32,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 8,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 32,
100
+ "BLOCK_SIZE_N": 64,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 8,
112
+ "num_stages": 2
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 32,
119
+ "num_warps": 8,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 64,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 8,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 32,
39
+ "num_warps": 4,
40
+ "num_stages": 4
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 8,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 16,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 64,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 5
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 32,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 32,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 16,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 32,
79
+ "num_warps": 4,
80
+ "num_stages": 2
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 16,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 8,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 1,
143
+ "num_warps": 8,
144
+ "num_stages": 4
145
+ }
146
+ }
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0
10
+ },
11
+ "2": {
12
+ "BLOCK_SIZE_M": 16,
13
+ "BLOCK_SIZE_N": 16,
14
+ "BLOCK_SIZE_K": 256,
15
+ "GROUP_SIZE_M": 1,
16
+ "num_warps": 2,
17
+ "num_stages": 2,
18
+ "waves_per_eu": 0
19
+ },
20
+ "4": {
21
+ "BLOCK_SIZE_M": 16,
22
+ "BLOCK_SIZE_N": 64,
23
+ "BLOCK_SIZE_K": 256,
24
+ "GROUP_SIZE_M": 1,
25
+ "num_warps": 4,
26
+ "num_stages": 2,
27
+ "waves_per_eu": 0
28
+ },
29
+ "8": {
30
+ "BLOCK_SIZE_M": 16,
31
+ "BLOCK_SIZE_N": 32,
32
+ "BLOCK_SIZE_K": 256,
33
+ "GROUP_SIZE_M": 1,
34
+ "num_warps": 2,
35
+ "num_stages": 2,
36
+ "waves_per_eu": 0
37
+ },
38
+ "16": {
39
+ "BLOCK_SIZE_M": 16,
40
+ "BLOCK_SIZE_N": 64,
41
+ "BLOCK_SIZE_K": 256,
42
+ "GROUP_SIZE_M": 1,
43
+ "num_warps": 2,
44
+ "num_stages": 2,
45
+ "waves_per_eu": 0
46
+ },
47
+ "24": {
48
+ "BLOCK_SIZE_M": 16,
49
+ "BLOCK_SIZE_N": 64,
50
+ "BLOCK_SIZE_K": 256,
51
+ "GROUP_SIZE_M": 1,
52
+ "num_warps": 2,
53
+ "num_stages": 2,
54
+ "waves_per_eu": 0
55
+ },
56
+ "32": {
57
+ "BLOCK_SIZE_M": 16,
58
+ "BLOCK_SIZE_N": 32,
59
+ "BLOCK_SIZE_K": 256,
60
+ "GROUP_SIZE_M": 4,
61
+ "num_warps": 2,
62
+ "num_stages": 2,
63
+ "waves_per_eu": 0
64
+ },
65
+ "48": {
66
+ "BLOCK_SIZE_M": 16,
67
+ "BLOCK_SIZE_N": 64,
68
+ "BLOCK_SIZE_K": 128,
69
+ "GROUP_SIZE_M": 4,
70
+ "num_warps": 4,
71
+ "num_stages": 2,
72
+ "waves_per_eu": 0
73
+ },
74
+ "64": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 256,
78
+ "GROUP_SIZE_M": 4,
79
+ "num_warps": 2,
80
+ "num_stages": 2,
81
+ "waves_per_eu": 0
82
+ },
83
+ "96": {
84
+ "BLOCK_SIZE_M": 32,
85
+ "BLOCK_SIZE_N": 64,
86
+ "BLOCK_SIZE_K": 256,
87
+ "GROUP_SIZE_M": 1,
88
+ "num_warps": 2,
89
+ "num_stages": 2,
90
+ "waves_per_eu": 0
91
+ },
92
+ "128": {
93
+ "BLOCK_SIZE_M": 64,
94
+ "BLOCK_SIZE_N": 128,
95
+ "BLOCK_SIZE_K": 256,
96
+ "GROUP_SIZE_M": 4,
97
+ "num_warps": 8,
98
+ "num_stages": 2,
99
+ "waves_per_eu": 0
100
+ },
101
+ "256": {
102
+ "BLOCK_SIZE_M": 128,
103
+ "BLOCK_SIZE_N": 128,
104
+ "BLOCK_SIZE_K": 256,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 8,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0
109
+ },
110
+ "512": {
111
+ "BLOCK_SIZE_M": 256,
112
+ "BLOCK_SIZE_N": 128,
113
+ "BLOCK_SIZE_K": 128,
114
+ "GROUP_SIZE_M": 4,
115
+ "num_warps": 8,
116
+ "num_stages": 2,
117
+ "waves_per_eu": 0
118
+ },
119
+ "1024": {
120
+ "BLOCK_SIZE_M": 128,
121
+ "BLOCK_SIZE_N": 128,
122
+ "BLOCK_SIZE_K": 64,
123
+ "GROUP_SIZE_M": 1,
124
+ "num_warps": 4,
125
+ "num_stages": 2,
126
+ "waves_per_eu": 0
127
+ },
128
+ "1536": {
129
+ "BLOCK_SIZE_M": 128,
130
+ "BLOCK_SIZE_N": 256,
131
+ "BLOCK_SIZE_K": 128,
132
+ "GROUP_SIZE_M": 1,
133
+ "num_warps": 8,
134
+ "num_stages": 2,
135
+ "waves_per_eu": 0
136
+ },
137
+ "2048": {
138
+ "BLOCK_SIZE_M": 128,
139
+ "BLOCK_SIZE_N": 256,
140
+ "BLOCK_SIZE_K": 128,
141
+ "GROUP_SIZE_M": 1,
142
+ "num_warps": 8,
143
+ "num_stages": 2,
144
+ "waves_per_eu": 0
145
+ },
146
+ "3072": {
147
+ "BLOCK_SIZE_M": 128,
148
+ "BLOCK_SIZE_N": 256,
149
+ "BLOCK_SIZE_K": 128,
150
+ "GROUP_SIZE_M": 1,
151
+ "num_warps": 8,
152
+ "num_stages": 2,
153
+ "waves_per_eu": 0
154
+ },
155
+ "4096": {
156
+ "BLOCK_SIZE_M": 256,
157
+ "BLOCK_SIZE_N": 256,
158
+ "BLOCK_SIZE_K": 64,
159
+ "GROUP_SIZE_M": 1,
160
+ "num_warps": 8,
161
+ "num_stages": 2,
162
+ "waves_per_eu": 0
163
+ }
164
+ }
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json CHANGED
@@ -1,123 +1,123 @@
1
  {
2
  "1": {
3
  "BLOCK_SIZE_M": 16,
4
- "BLOCK_SIZE_N": 32,
5
  "BLOCK_SIZE_K": 256,
6
  "GROUP_SIZE_M": 1,
7
  "num_warps": 2,
8
- "num_stages": 0,
9
  "waves_per_eu": 0,
10
  "matrix_instr_nonkdim": 16,
11
- "kpack": 1
12
  },
13
  "2": {
14
  "BLOCK_SIZE_M": 16,
15
  "BLOCK_SIZE_N": 16,
16
- "BLOCK_SIZE_K": 128,
17
  "GROUP_SIZE_M": 1,
18
- "num_warps": 2,
19
- "num_stages": 0,
20
  "waves_per_eu": 0,
21
  "matrix_instr_nonkdim": 16,
22
  "kpack": 2
23
  },
24
  "4": {
25
  "BLOCK_SIZE_M": 16,
26
- "BLOCK_SIZE_N": 32,
27
- "BLOCK_SIZE_K": 256,
28
  "GROUP_SIZE_M": 1,
29
- "num_warps": 2,
30
- "num_stages": 0,
31
  "waves_per_eu": 0,
32
  "matrix_instr_nonkdim": 16,
33
  "kpack": 2
34
  },
35
  "8": {
36
  "BLOCK_SIZE_M": 16,
37
- "BLOCK_SIZE_N": 16,
38
- "BLOCK_SIZE_K": 256,
39
  "GROUP_SIZE_M": 1,
40
- "num_warps": 1,
41
- "num_stages": 0,
42
  "waves_per_eu": 0,
43
  "matrix_instr_nonkdim": 16,
44
  "kpack": 2
45
  },
46
  "16": {
47
  "BLOCK_SIZE_M": 16,
48
- "BLOCK_SIZE_N": 16,
49
- "BLOCK_SIZE_K": 256,
50
  "GROUP_SIZE_M": 1,
51
- "num_warps": 4,
52
- "num_stages": 0,
53
  "waves_per_eu": 0,
54
  "matrix_instr_nonkdim": 16,
55
  "kpack": 2
56
  },
57
  "24": {
58
  "BLOCK_SIZE_M": 16,
59
- "BLOCK_SIZE_N": 32,
60
- "BLOCK_SIZE_K": 64,
61
  "GROUP_SIZE_M": 1,
62
- "num_warps": 1,
63
- "num_stages": 0,
64
  "waves_per_eu": 0,
65
  "matrix_instr_nonkdim": 16,
66
  "kpack": 2
67
  },
68
  "32": {
69
  "BLOCK_SIZE_M": 16,
70
- "BLOCK_SIZE_N": 16,
71
- "BLOCK_SIZE_K": 128,
72
  "GROUP_SIZE_M": 4,
73
  "num_warps": 2,
74
- "num_stages": 0,
75
  "waves_per_eu": 0,
76
  "matrix_instr_nonkdim": 16,
77
- "kpack": 1
78
  },
79
  "48": {
80
  "BLOCK_SIZE_M": 16,
81
- "BLOCK_SIZE_N": 16,
82
  "BLOCK_SIZE_K": 128,
83
  "GROUP_SIZE_M": 4,
84
- "num_warps": 2,
85
- "num_stages": 0,
86
  "waves_per_eu": 0,
87
  "matrix_instr_nonkdim": 16,
88
- "kpack": 2
89
  },
90
  "64": {
91
  "BLOCK_SIZE_M": 32,
92
  "BLOCK_SIZE_N": 64,
93
  "BLOCK_SIZE_K": 128,
94
  "GROUP_SIZE_M": 4,
95
- "num_warps": 8,
96
- "num_stages": 0,
97
  "waves_per_eu": 0,
98
  "matrix_instr_nonkdim": 16,
99
  "kpack": 2
100
  },
101
  "96": {
102
  "BLOCK_SIZE_M": 32,
103
- "BLOCK_SIZE_N": 32,
104
- "BLOCK_SIZE_K": 128,
105
  "GROUP_SIZE_M": 4,
106
- "num_warps": 4,
107
- "num_stages": 0,
108
  "waves_per_eu": 0,
109
  "matrix_instr_nonkdim": 16,
110
- "kpack": 2
111
  },
112
  "128": {
113
  "BLOCK_SIZE_M": 64,
114
  "BLOCK_SIZE_N": 64,
115
- "BLOCK_SIZE_K": 64,
116
  "GROUP_SIZE_M": 4,
117
- "num_warps": 8,
118
- "num_stages": 0,
119
  "waves_per_eu": 0,
120
- "matrix_instr_nonkdim": 16,
121
  "kpack": 2
122
  },
123
  "256": {
@@ -126,10 +126,10 @@
126
  "BLOCK_SIZE_K": 64,
127
  "GROUP_SIZE_M": 4,
128
  "num_warps": 8,
129
- "num_stages": 0,
130
  "waves_per_eu": 0,
131
  "matrix_instr_nonkdim": 16,
132
- "kpack": 1
133
  },
134
  "512": {
135
  "BLOCK_SIZE_M": 128,
@@ -137,7 +137,7 @@
137
  "BLOCK_SIZE_K": 64,
138
  "GROUP_SIZE_M": 4,
139
  "num_warps": 8,
140
- "num_stages": 0,
141
  "waves_per_eu": 0,
142
  "matrix_instr_nonkdim": 16,
143
  "kpack": 2
@@ -148,9 +148,9 @@
148
  "BLOCK_SIZE_K": 64,
149
  "GROUP_SIZE_M": 1,
150
  "num_warps": 8,
151
- "num_stages": 0,
152
  "waves_per_eu": 0,
153
- "matrix_instr_nonkdim": 32,
154
  "kpack": 2
155
  },
156
  "1536": {
@@ -159,7 +159,7 @@
159
  "BLOCK_SIZE_K": 64,
160
  "GROUP_SIZE_M": 1,
161
  "num_warps": 8,
162
- "num_stages": 0,
163
  "waves_per_eu": 0,
164
  "matrix_instr_nonkdim": 16,
165
  "kpack": 2
@@ -170,7 +170,7 @@
170
  "BLOCK_SIZE_K": 64,
171
  "GROUP_SIZE_M": 1,
172
  "num_warps": 8,
173
- "num_stages": 0,
174
  "waves_per_eu": 0,
175
  "matrix_instr_nonkdim": 16,
176
  "kpack": 2
@@ -181,10 +181,10 @@
181
  "BLOCK_SIZE_K": 64,
182
  "GROUP_SIZE_M": 1,
183
  "num_warps": 8,
184
- "num_stages": 0,
185
  "waves_per_eu": 0,
186
  "matrix_instr_nonkdim": 16,
187
- "kpack": 1
188
  },
189
  "4096": {
190
  "BLOCK_SIZE_M": 128,
@@ -192,9 +192,9 @@
192
  "BLOCK_SIZE_K": 64,
193
  "GROUP_SIZE_M": 1,
194
  "num_warps": 8,
195
- "num_stages": 0,
196
  "waves_per_eu": 0,
197
  "matrix_instr_nonkdim": 16,
198
- "kpack": 1
199
  }
200
  }
 
1
  {
2
  "1": {
3
  "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
  "BLOCK_SIZE_K": 256,
6
  "GROUP_SIZE_M": 1,
7
  "num_warps": 2,
8
+ "num_stages": 2,
9
  "waves_per_eu": 0,
10
  "matrix_instr_nonkdim": 16,
11
+ "kpack": 2
12
  },
13
  "2": {
14
  "BLOCK_SIZE_M": 16,
15
  "BLOCK_SIZE_N": 16,
16
+ "BLOCK_SIZE_K": 256,
17
  "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
  "waves_per_eu": 0,
21
  "matrix_instr_nonkdim": 16,
22
  "kpack": 2
23
  },
24
  "4": {
25
  "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 16,
27
+ "BLOCK_SIZE_K": 128,
28
  "GROUP_SIZE_M": 1,
29
+ "num_warps": 1,
30
+ "num_stages": 2,
31
  "waves_per_eu": 0,
32
  "matrix_instr_nonkdim": 16,
33
  "kpack": 2
34
  },
35
  "8": {
36
  "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 64,
38
+ "BLOCK_SIZE_K": 64,
39
  "GROUP_SIZE_M": 1,
40
+ "num_warps": 2,
41
+ "num_stages": 2,
42
  "waves_per_eu": 0,
43
  "matrix_instr_nonkdim": 16,
44
  "kpack": 2
45
  },
46
  "16": {
47
  "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 64,
50
  "GROUP_SIZE_M": 1,
51
+ "num_warps": 2,
52
+ "num_stages": 2,
53
  "waves_per_eu": 0,
54
  "matrix_instr_nonkdim": 16,
55
  "kpack": 2
56
  },
57
  "24": {
58
  "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 16,
60
+ "BLOCK_SIZE_K": 256,
61
  "GROUP_SIZE_M": 1,
62
+ "num_warps": 2,
63
+ "num_stages": 2,
64
  "waves_per_eu": 0,
65
  "matrix_instr_nonkdim": 16,
66
  "kpack": 2
67
  },
68
  "32": {
69
  "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 32,
71
+ "BLOCK_SIZE_K": 256,
72
  "GROUP_SIZE_M": 4,
73
  "num_warps": 2,
74
+ "num_stages": 2,
75
  "waves_per_eu": 0,
76
  "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
  },
79
  "48": {
80
  "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
  "BLOCK_SIZE_K": 128,
83
  "GROUP_SIZE_M": 4,
84
+ "num_warps": 4,
85
+ "num_stages": 2,
86
  "waves_per_eu": 0,
87
  "matrix_instr_nonkdim": 16,
88
+ "kpack": 1
89
  },
90
  "64": {
91
  "BLOCK_SIZE_M": 32,
92
  "BLOCK_SIZE_N": 64,
93
  "BLOCK_SIZE_K": 128,
94
  "GROUP_SIZE_M": 4,
95
+ "num_warps": 4,
96
+ "num_stages": 2,
97
  "waves_per_eu": 0,
98
  "matrix_instr_nonkdim": 16,
99
  "kpack": 2
100
  },
101
  "96": {
102
  "BLOCK_SIZE_M": 32,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 256,
105
  "GROUP_SIZE_M": 4,
106
+ "num_warps": 8,
107
+ "num_stages": 2,
108
  "waves_per_eu": 0,
109
  "matrix_instr_nonkdim": 16,
110
+ "kpack": 1
111
  },
112
  "128": {
113
  "BLOCK_SIZE_M": 64,
114
  "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 128,
116
  "GROUP_SIZE_M": 4,
117
+ "num_warps": 4,
118
+ "num_stages": 2,
119
  "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 32,
121
  "kpack": 2
122
  },
123
  "256": {
 
126
  "BLOCK_SIZE_K": 64,
127
  "GROUP_SIZE_M": 4,
128
  "num_warps": 8,
129
+ "num_stages": 2,
130
  "waves_per_eu": 0,
131
  "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
  },
134
  "512": {
135
  "BLOCK_SIZE_M": 128,
 
137
  "BLOCK_SIZE_K": 64,
138
  "GROUP_SIZE_M": 4,
139
  "num_warps": 8,
140
+ "num_stages": 2,
141
  "waves_per_eu": 0,
142
  "matrix_instr_nonkdim": 16,
143
  "kpack": 2
 
148
  "BLOCK_SIZE_K": 64,
149
  "GROUP_SIZE_M": 1,
150
  "num_warps": 8,
151
+ "num_stages": 2,
152
  "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
  "kpack": 2
155
  },
156
  "1536": {
 
159
  "BLOCK_SIZE_K": 64,
160
  "GROUP_SIZE_M": 1,
161
  "num_warps": 8,
162
+ "num_stages": 2,
163
  "waves_per_eu": 0,
164
  "matrix_instr_nonkdim": 16,
165
  "kpack": 2
 
170
  "BLOCK_SIZE_K": 64,
171
  "GROUP_SIZE_M": 1,
172
  "num_warps": 8,
173
+ "num_stages": 2,
174
  "waves_per_eu": 0,
175
  "matrix_instr_nonkdim": 16,
176
  "kpack": 2
 
181
  "BLOCK_SIZE_K": 64,
182
  "GROUP_SIZE_M": 1,
183
  "num_warps": 8,
184
+ "num_stages": 2,
185
  "waves_per_eu": 0,
186
  "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
  },
189
  "4096": {
190
  "BLOCK_SIZE_M": 128,
 
192
  "BLOCK_SIZE_K": 64,
193
  "GROUP_SIZE_M": 1,
194
  "num_warps": 8,
195
+ "num_stages": 2,
196
  "waves_per_eu": 0,
197
  "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
  }
200
  }
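The AMD Instinct configs above carry three ROCm-specific fields in addition to the usual tile sizes: waves_per_eu, matrix_instr_nonkdim, and kpack. The updated MI300X file also replaces the old num_stages value of 0 with 2 throughout and adjusts several block sizes and kpack values. The sketch below, which is an assumption-based illustration rather than this repository's kernel code, shows how one such entry (the values of the "128" entry in the updated MI300X config) could be split into tile-size constants and compiler/launch hints when invoking a Triton kernel; the kernel name and grid in the final comment are hypothetical.

```python
# Illustrative only: splitting one tuning entry into tl.constexpr tile sizes
# and Triton compile/launch options (num_warps, num_stages, plus the
# ROCm-only waves_per_eu / matrix_instr_nonkdim / kpack hints).
entry = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 4,
    "num_warps": 4,
    "num_stages": 2,
    "waves_per_eu": 0,
    "matrix_instr_nonkdim": 32,
    "kpack": 2,
}

# Tile and group sizes are typically passed as tl.constexpr kernel arguments...
tile_kwargs = {
    k: v for k, v in entry.items()
    if k.startswith("BLOCK_SIZE") or k == "GROUP_SIZE_M"
}
# ...while everything else acts as a compiler/launch hint.
launch_kwargs = {k: v for k, v in entry.items() if k not in tile_kwargs}

# Hypothetical launch:
# fused_moe_kernel[grid](a_ptr, b_ptr, c_ptr, ..., **tile_kwargs, **launch_kwargs)
```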
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 2,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0
10
+ },
11
+ "2": {
12
+ "BLOCK_SIZE_M": 16,
13
+ "BLOCK_SIZE_N": 32,
14
+ "BLOCK_SIZE_K": 256,
15
+ "GROUP_SIZE_M": 1,
16
+ "num_warps": 4,
17
+ "num_stages": 2,
18
+ "waves_per_eu": 0
19
+ },
20
+ "4": {
21
+ "BLOCK_SIZE_M": 16,
22
+ "BLOCK_SIZE_N": 16,
23
+ "BLOCK_SIZE_K": 256,
24
+ "GROUP_SIZE_M": 1,
25
+ "num_warps": 4,
26
+ "num_stages": 2,
27
+ "waves_per_eu": 0
28
+ },
29
+ "8": {
30
+ "BLOCK_SIZE_M": 16,
31
+ "BLOCK_SIZE_N": 64,
32
+ "BLOCK_SIZE_K": 256,
33
+ "GROUP_SIZE_M": 1,
34
+ "num_warps": 2,
35
+ "num_stages": 2,
36
+ "waves_per_eu": 0
37
+ },
38
+ "16": {
39
+ "BLOCK_SIZE_M": 64,
40
+ "BLOCK_SIZE_N": 64,
41
+ "BLOCK_SIZE_K": 256,
42
+ "GROUP_SIZE_M": 1,
43
+ "num_warps": 4,
44
+ "num_stages": 2,
45
+ "waves_per_eu": 0
46
+ },
47
+ "24": {
48
+ "BLOCK_SIZE_M": 32,
49
+ "BLOCK_SIZE_N": 64,
50
+ "BLOCK_SIZE_K": 256,
51
+ "GROUP_SIZE_M": 1,
52
+ "num_warps": 2,
53
+ "num_stages": 2,
54
+ "waves_per_eu": 0
55
+ },
56
+ "32": {
57
+ "BLOCK_SIZE_M": 16,
58
+ "BLOCK_SIZE_N": 32,
59
+ "BLOCK_SIZE_K": 256,
60
+ "GROUP_SIZE_M": 4,
61
+ "num_warps": 2,
62
+ "num_stages": 2,
63
+ "waves_per_eu": 0
64
+ },
65
+ "48": {
66
+ "BLOCK_SIZE_M": 16,
67
+ "BLOCK_SIZE_N": 64,
68
+ "BLOCK_SIZE_K": 256,
69
+ "GROUP_SIZE_M": 1,
70
+ "num_warps": 4,
71
+ "num_stages": 2,
72
+ "waves_per_eu": 0
73
+ },
74
+ "64": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 16,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 4,
79
+ "num_warps": 2,
80
+ "num_stages": 2,
81
+ "waves_per_eu": 0
82
+ },
83
+ "96": {
84
+ "BLOCK_SIZE_M": 32,
85
+ "BLOCK_SIZE_N": 64,
86
+ "BLOCK_SIZE_K": 256,
87
+ "GROUP_SIZE_M": 1,
88
+ "num_warps": 2,
89
+ "num_stages": 2,
90
+ "waves_per_eu": 0
91
+ },
92
+ "128": {
93
+ "BLOCK_SIZE_M": 64,
94
+ "BLOCK_SIZE_N": 64,
95
+ "BLOCK_SIZE_K": 256,
96
+ "GROUP_SIZE_M": 4,
97
+ "num_warps": 4,
98
+ "num_stages": 2,
99
+ "waves_per_eu": 0
100
+ },
101
+ "256": {
102
+ "BLOCK_SIZE_M": 128,
103
+ "BLOCK_SIZE_N": 128,
104
+ "BLOCK_SIZE_K": 256,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 8,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0
109
+ },
110
+ "512": {
111
+ "BLOCK_SIZE_M": 256,
112
+ "BLOCK_SIZE_N": 128,
113
+ "BLOCK_SIZE_K": 128,
114
+ "GROUP_SIZE_M": 4,
115
+ "num_warps": 8,
116
+ "num_stages": 2,
117
+ "waves_per_eu": 0
118
+ },
119
+ "1024": {
120
+ "BLOCK_SIZE_M": 128,
121
+ "BLOCK_SIZE_N": 128,
122
+ "BLOCK_SIZE_K": 256,
123
+ "GROUP_SIZE_M": 1,
124
+ "num_warps": 8,
125
+ "num_stages": 2,
126
+ "waves_per_eu": 0
127
+ },
128
+ "1536": {
129
+ "BLOCK_SIZE_M": 256,
130
+ "BLOCK_SIZE_N": 256,
131
+ "BLOCK_SIZE_K": 64,
132
+ "GROUP_SIZE_M": 1,
133
+ "num_warps": 8,
134
+ "num_stages": 2,
135
+ "waves_per_eu": 0
136
+ },
137
+ "2048": {
138
+ "BLOCK_SIZE_M": 128,
139
+ "BLOCK_SIZE_N": 256,
140
+ "BLOCK_SIZE_K": 128,
141
+ "GROUP_SIZE_M": 1,
142
+ "num_warps": 8,
143
+ "num_stages": 2,
144
+ "waves_per_eu": 0
145
+ },
146
+ "3072": {
147
+ "BLOCK_SIZE_M": 128,
148
+ "BLOCK_SIZE_N": 256,
149
+ "BLOCK_SIZE_K": 128,
150
+ "GROUP_SIZE_M": 1,
151
+ "num_warps": 8,
152
+ "num_stages": 2,
153
+ "waves_per_eu": 0
154
+ },
155
+ "4096": {
156
+ "BLOCK_SIZE_M": 256,
157
+ "BLOCK_SIZE_N": 256,
158
+ "BLOCK_SIZE_K": 64,
159
+ "GROUP_SIZE_M": 1,
160
+ "num_warps": 8,
161
+ "num_stages": 2,
162
+ "waves_per_eu": 0
163
+ }
164
+ }
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 16,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 2,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0,
10
+ "matrix_instr_nonkdim": 16,
11
+ "kpack": 1
12
+ },
13
+ "2": {
14
+ "BLOCK_SIZE_M": 16,
15
+ "BLOCK_SIZE_N": 16,
16
+ "BLOCK_SIZE_K": 256,
17
+ "GROUP_SIZE_M": 1,
18
+ "num_warps": 4,
19
+ "num_stages": 2,
20
+ "waves_per_eu": 0,
21
+ "matrix_instr_nonkdim": 16,
22
+ "kpack": 2
23
+ },
24
+ "4": {
25
+ "BLOCK_SIZE_M": 16,
26
+ "BLOCK_SIZE_N": 16,
27
+ "BLOCK_SIZE_K": 128,
28
+ "GROUP_SIZE_M": 1,
29
+ "num_warps": 1,
30
+ "num_stages": 2,
31
+ "waves_per_eu": 0,
32
+ "matrix_instr_nonkdim": 16,
33
+ "kpack": 2
34
+ },
35
+ "8": {
36
+ "BLOCK_SIZE_M": 16,
37
+ "BLOCK_SIZE_N": 32,
38
+ "BLOCK_SIZE_K": 128,
39
+ "GROUP_SIZE_M": 1,
40
+ "num_warps": 2,
41
+ "num_stages": 2,
42
+ "waves_per_eu": 0,
43
+ "matrix_instr_nonkdim": 16,
44
+ "kpack": 2
45
+ },
46
+ "16": {
47
+ "BLOCK_SIZE_M": 16,
48
+ "BLOCK_SIZE_N": 64,
49
+ "BLOCK_SIZE_K": 64,
50
+ "GROUP_SIZE_M": 1,
51
+ "num_warps": 2,
52
+ "num_stages": 2,
53
+ "waves_per_eu": 0,
54
+ "matrix_instr_nonkdim": 16,
55
+ "kpack": 2
56
+ },
57
+ "24": {
58
+ "BLOCK_SIZE_M": 16,
59
+ "BLOCK_SIZE_N": 128,
60
+ "BLOCK_SIZE_K": 64,
61
+ "GROUP_SIZE_M": 1,
62
+ "num_warps": 4,
63
+ "num_stages": 2,
64
+ "waves_per_eu": 0,
65
+ "matrix_instr_nonkdim": 16,
66
+ "kpack": 2
67
+ },
68
+ "32": {
69
+ "BLOCK_SIZE_M": 16,
70
+ "BLOCK_SIZE_N": 64,
71
+ "BLOCK_SIZE_K": 128,
72
+ "GROUP_SIZE_M": 4,
73
+ "num_warps": 4,
74
+ "num_stages": 2,
75
+ "waves_per_eu": 0,
76
+ "matrix_instr_nonkdim": 16,
77
+ "kpack": 2
78
+ },
79
+ "48": {
80
+ "BLOCK_SIZE_M": 16,
81
+ "BLOCK_SIZE_N": 64,
82
+ "BLOCK_SIZE_K": 128,
83
+ "GROUP_SIZE_M": 4,
84
+ "num_warps": 1,
85
+ "num_stages": 2,
86
+ "waves_per_eu": 0,
87
+ "matrix_instr_nonkdim": 16,
88
+ "kpack": 1
89
+ },
90
+ "64": {
91
+ "BLOCK_SIZE_M": 32,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 4,
95
+ "num_warps": 8,
96
+ "num_stages": 2,
97
+ "waves_per_eu": 0,
98
+ "matrix_instr_nonkdim": 16,
99
+ "kpack": 2
100
+ },
101
+ "96": {
102
+ "BLOCK_SIZE_M": 32,
103
+ "BLOCK_SIZE_N": 64,
104
+ "BLOCK_SIZE_K": 256,
105
+ "GROUP_SIZE_M": 4,
106
+ "num_warps": 8,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0,
109
+ "matrix_instr_nonkdim": 16,
110
+ "kpack": 1
111
+ },
112
+ "128": {
113
+ "BLOCK_SIZE_M": 64,
114
+ "BLOCK_SIZE_N": 64,
115
+ "BLOCK_SIZE_K": 128,
116
+ "GROUP_SIZE_M": 4,
117
+ "num_warps": 4,
118
+ "num_stages": 2,
119
+ "waves_per_eu": 0,
120
+ "matrix_instr_nonkdim": 32,
121
+ "kpack": 2
122
+ },
123
+ "256": {
124
+ "BLOCK_SIZE_M": 128,
125
+ "BLOCK_SIZE_N": 128,
126
+ "BLOCK_SIZE_K": 64,
127
+ "GROUP_SIZE_M": 4,
128
+ "num_warps": 8,
129
+ "num_stages": 2,
130
+ "waves_per_eu": 0,
131
+ "matrix_instr_nonkdim": 16,
132
+ "kpack": 2
133
+ },
134
+ "512": {
135
+ "BLOCK_SIZE_M": 128,
136
+ "BLOCK_SIZE_N": 128,
137
+ "BLOCK_SIZE_K": 64,
138
+ "GROUP_SIZE_M": 1,
139
+ "num_warps": 8,
140
+ "num_stages": 2,
141
+ "waves_per_eu": 0,
142
+ "matrix_instr_nonkdim": 16,
143
+ "kpack": 2
144
+ },
145
+ "1024": {
146
+ "BLOCK_SIZE_M": 128,
147
+ "BLOCK_SIZE_N": 128,
148
+ "BLOCK_SIZE_K": 64,
149
+ "GROUP_SIZE_M": 1,
150
+ "num_warps": 8,
151
+ "num_stages": 2,
152
+ "waves_per_eu": 0,
153
+ "matrix_instr_nonkdim": 16,
154
+ "kpack": 2
155
+ },
156
+ "1536": {
157
+ "BLOCK_SIZE_M": 128,
158
+ "BLOCK_SIZE_N": 128,
159
+ "BLOCK_SIZE_K": 64,
160
+ "GROUP_SIZE_M": 1,
161
+ "num_warps": 8,
162
+ "num_stages": 2,
163
+ "waves_per_eu": 0,
164
+ "matrix_instr_nonkdim": 16,
165
+ "kpack": 2
166
+ },
167
+ "2048": {
168
+ "BLOCK_SIZE_M": 128,
169
+ "BLOCK_SIZE_N": 128,
170
+ "BLOCK_SIZE_K": 64,
171
+ "GROUP_SIZE_M": 1,
172
+ "num_warps": 8,
173
+ "num_stages": 2,
174
+ "waves_per_eu": 0,
175
+ "matrix_instr_nonkdim": 16,
176
+ "kpack": 2
177
+ },
178
+ "3072": {
179
+ "BLOCK_SIZE_M": 128,
180
+ "BLOCK_SIZE_N": 128,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 1,
183
+ "num_warps": 8,
184
+ "num_stages": 2,
185
+ "waves_per_eu": 0,
186
+ "matrix_instr_nonkdim": 16,
187
+ "kpack": 2
188
+ },
189
+ "4096": {
190
+ "BLOCK_SIZE_M": 128,
191
+ "BLOCK_SIZE_N": 128,
192
+ "BLOCK_SIZE_K": 64,
193
+ "GROUP_SIZE_M": 1,
194
+ "num_warps": 8,
195
+ "num_stages": 2,
196
+ "waves_per_eu": 0,
197
+ "matrix_instr_nonkdim": 16,
198
+ "kpack": 2
199
+ }
200
+ }
torch-ext/moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 64,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 64,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 64,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 256,
14
+ "GROUP_SIZE_M": 64,
15
+ "num_warps": 8,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 64,
20
+ "BLOCK_SIZE_N": 256,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 32,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 64,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 64,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 64,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 64,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 4
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 64,
52
+ "BLOCK_SIZE_N": 256,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 32,
55
+ "num_warps": 4,
56
+ "num_stages": 5
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 64,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 64,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 128,
92
+ "BLOCK_SIZE_N": 256,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 32,
95
+ "num_warps": 8,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 128,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 8,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 128,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 8,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 128,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 128,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 128,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
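All of the config files added in this commit follow the same naming scheme: the expert count E, the per-shard intermediate size N, the device name with spaces replaced by underscores, and an optional dtype suffix. The helper below is a hypothetical sketch that only mirrors the pattern visible in the paths above; it is not part of the repository's code.

```python
# Hypothetical helper reproducing the naming pattern of the files under
# torch-ext/moe/configs/ (E=..., N=..., device_name=..., optional dtype=...).
from typing import Optional


def moe_config_filename(
    num_experts: int,
    shard_intermediate_size: int,
    device_name: str,
    dtype: Optional[str] = None,
) -> str:
    name = (
        f"E={num_experts},N={shard_intermediate_size},"
        f"device_name={device_name.replace(' ', '_')}"
    )
    if dtype is not None:
        name += f",dtype={dtype}"
    return name + ".json"


# e.g. moe_config_filename(8, 14336, "NVIDIA H200", dtype="fp8_w8a8")
#   -> "E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json"
```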