Sync with vLLM and add `Llama4TextMoe` layer
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- build.toml +2 -0
- moe/moe_align_sum_kernels.cu +39 -17
- moe/moe_wna16.cu +346 -0
- moe/moe_wna16_utils.h +200 -0
- tests/kernels/test_moe.py +229 -125
- tests/kernels/utils.py +7 -7
- torch-ext/moe/__init__.py +2 -18
- torch-ext/moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
- torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
- torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
- torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
- torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
- torch-ext/moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
- torch-ext/moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- torch-ext/moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
- torch-ext/moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
- torch-ext/moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
- torch-ext/moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
- torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
- torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +52 -52
- torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
- torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- torch-ext/moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
build.toml
CHANGED
@@ -31,6 +31,8 @@ src = [
|
|
31 |
"cuda_compat.h",
|
32 |
"dispatch_utils.h",
|
33 |
"moe/moe_align_sum_kernels.cu",
|
|
|
|
|
34 |
"moe/topk_softmax_kernels.cu",
|
35 |
]
|
36 |
depends = ["torch"]
|
|
|
31 |
"cuda_compat.h",
|
32 |
"dispatch_utils.h",
|
33 |
"moe/moe_align_sum_kernels.cu",
|
34 |
+
"moe/moe_wna16.cu",
|
35 |
+
"moe/moe_wna16_utils.h",
|
36 |
"moe/topk_softmax_kernels.cu",
|
37 |
]
|
38 |
depends = ["torch"]
|
moe/moe_align_sum_kernels.cu
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
#include <c10/cuda/CUDAGuard.h>
|
4 |
|
5 |
#include <ATen/ATen.h>
|
6 |
-
#include <
|
7 |
|
8 |
#include "../cuda_compat.h"
|
9 |
#include "../dispatch_utils.h"
|
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
|
|
198 |
}
|
199 |
|
200 |
// taken from
|
201 |
-
// https://github.com/sgl-project/sglang/commit/
|
202 |
template <typename scalar_t>
|
203 |
__global__ void sgl_moe_align_block_size_kernel(
|
204 |
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
205 |
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
206 |
int32_t block_size, size_t numel, int32_t* cumsum) {
|
207 |
__shared__ int32_t shared_counts[32][8];
|
208 |
-
__shared__ int32_t local_offsets[256];
|
209 |
|
210 |
-
const int warp_id = threadIdx.x /
|
211 |
-
const int lane_id = threadIdx.x % WARP_SIZE;
|
212 |
const int experts_per_warp = 8;
|
213 |
const int my_expert_start = warp_id * experts_per_warp;
|
214 |
|
|
|
215 |
for (int i = 0; i < experts_per_warp; ++i) {
|
216 |
if (my_expert_start + i < num_experts) {
|
217 |
shared_counts[warp_id][i] = 0;
|
218 |
}
|
219 |
}
|
220 |
|
|
|
|
|
221 |
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
222 |
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
223 |
|
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
|
|
230 |
|
231 |
__syncthreads();
|
232 |
|
|
|
233 |
if (threadIdx.x == 0) {
|
234 |
cumsum[0] = 0;
|
235 |
for (int i = 1; i <= num_experts; ++i) {
|
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
|
|
246 |
|
247 |
__syncthreads();
|
248 |
|
|
|
249 |
if (threadIdx.x < num_experts) {
|
250 |
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
251 |
i += block_size) {
|
252 |
expert_ids[i / block_size] = threadIdx.x;
|
253 |
}
|
254 |
-
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
|
255 |
}
|
|
|
256 |
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
int32_t expert_id = topk_ids[i];
|
261 |
-
int32_t rank_post_pad = atomicAdd(&
|
262 |
sorted_token_ids[rank_post_pad] = i;
|
263 |
}
|
264 |
}
|
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|
377 |
torch::Tensor experts_ids,
|
378 |
torch::Tensor num_tokens_post_pad) {
|
379 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
|
|
|
|
|
380 |
VLLM_DISPATCH_INTEGRAL_TYPES(
|
381 |
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
382 |
-
// calc needed amount of shared mem for `
|
383 |
-
// tensors
|
384 |
auto options_int =
|
385 |
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
386 |
-
// torch::Tensor token_cnts_buffer =
|
387 |
-
// torch::empty({(num_experts + 1) * num_experts}, options_int);
|
388 |
torch::Tensor cumsum_buffer =
|
389 |
-
torch::
|
390 |
|
391 |
-
auto
|
392 |
-
|
|
|
393 |
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
394 |
experts_ids.data_ptr<int32_t>(),
|
395 |
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
396 |
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
});
|
398 |
}
|
399 |
|
|
|
3 |
#include <c10/cuda/CUDAGuard.h>
|
4 |
|
5 |
#include <ATen/ATen.h>
|
6 |
+
#include <ATen/cuda/Atomic.cuh>
|
7 |
|
8 |
#include "../cuda_compat.h"
|
9 |
#include "../dispatch_utils.h"
|
|
|
198 |
}
|
199 |
|
200 |
// taken from
|
201 |
+
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
|
202 |
template <typename scalar_t>
|
203 |
__global__ void sgl_moe_align_block_size_kernel(
|
204 |
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
205 |
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
206 |
int32_t block_size, size_t numel, int32_t* cumsum) {
|
207 |
__shared__ int32_t shared_counts[32][8];
|
|
|
208 |
|
209 |
+
const int warp_id = threadIdx.x / 32;
|
|
|
210 |
const int experts_per_warp = 8;
|
211 |
const int my_expert_start = warp_id * experts_per_warp;
|
212 |
|
213 |
+
// Initialize shared_counts for this warp's experts
|
214 |
for (int i = 0; i < experts_per_warp; ++i) {
|
215 |
if (my_expert_start + i < num_experts) {
|
216 |
shared_counts[warp_id][i] = 0;
|
217 |
}
|
218 |
}
|
219 |
|
220 |
+
__syncthreads();
|
221 |
+
|
222 |
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
223 |
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
224 |
|
|
|
231 |
|
232 |
__syncthreads();
|
233 |
|
234 |
+
// Single thread computes cumulative sum and total tokens
|
235 |
if (threadIdx.x == 0) {
|
236 |
cumsum[0] = 0;
|
237 |
for (int i = 1; i <= num_experts; ++i) {
|
|
|
248 |
|
249 |
__syncthreads();
|
250 |
|
251 |
+
// Assign expert IDs to blocks
|
252 |
if (threadIdx.x < num_experts) {
|
253 |
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
254 |
i += block_size) {
|
255 |
expert_ids[i / block_size] = threadIdx.x;
|
256 |
}
|
|
|
257 |
}
|
258 |
+
}
|
259 |
|
260 |
+
// taken from
|
261 |
+
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
|
262 |
+
template <typename scalar_t>
|
263 |
+
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
|
264 |
+
int32_t* sorted_token_ids,
|
265 |
+
int32_t* cumsum_buffer,
|
266 |
+
size_t numel) {
|
267 |
+
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
268 |
+
const size_t stride = blockDim.x * gridDim.x;
|
269 |
+
|
270 |
+
for (size_t i = tid; i < numel; i += stride) {
|
271 |
int32_t expert_id = topk_ids[i];
|
272 |
+
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
273 |
sorted_token_ids[rank_post_pad] = i;
|
274 |
}
|
275 |
}
|
|
|
388 |
torch::Tensor experts_ids,
|
389 |
torch::Tensor num_tokens_post_pad) {
|
390 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
391 |
+
TORCH_CHECK(num_experts == 256,
|
392 |
+
"sgl_moe_align_block_size kernel only supports deepseek v3.");
|
393 |
+
|
394 |
VLLM_DISPATCH_INTEGRAL_TYPES(
|
395 |
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
396 |
+
// calc needed amount of shared mem for `cumsum` tensors
|
|
|
397 |
auto options_int =
|
398 |
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
|
|
|
|
399 |
torch::Tensor cumsum_buffer =
|
400 |
+
torch::zeros({num_experts + 1}, options_int);
|
401 |
|
402 |
+
auto align_kernel =
|
403 |
+
vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
|
404 |
+
align_kernel<<<1, 1024, 0, stream>>>(
|
405 |
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
406 |
experts_ids.data_ptr<int32_t>(),
|
407 |
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
408 |
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
409 |
+
|
410 |
+
const int block_threads = 256;
|
411 |
+
const int num_blocks =
|
412 |
+
(topk_ids.numel() + block_threads - 1) / block_threads;
|
413 |
+
const int max_blocks = 65535;
|
414 |
+
const int actual_blocks = std::min(num_blocks, max_blocks);
|
415 |
+
auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
|
416 |
+
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
417 |
+
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
418 |
+
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
419 |
});
|
420 |
}
|
421 |
|
moe/moe_wna16.cu
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#include <torch/all.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
#include <ATen/cuda/CUDAContext.h>
|
5 |
+
#include <cuda_runtime.h>
|
6 |
+
|
7 |
+
#include <cuda_fp16.h>
|
8 |
+
#include <cuda_bf16.h>
|
9 |
+
#include "moe_wna16_utils.h"
|
10 |
+
|
11 |
+
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
12 |
+
|
13 |
+
template <typename scalar_t, int bit, int GROUPS>
|
14 |
+
__global__ void moe_wna16_gemm_kernel(
|
15 |
+
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
16 |
+
|
17 |
+
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
18 |
+
const uint32_t* __restrict__ qzeros,
|
19 |
+
|
20 |
+
const float* __restrict__ topk_weights,
|
21 |
+
const int32_t* __restrict__ sorted_token_ids,
|
22 |
+
const int32_t* __restrict__ expert_ids,
|
23 |
+
const int32_t* __restrict__ num_tokens_post_pad,
|
24 |
+
|
25 |
+
uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
|
26 |
+
uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
|
27 |
+
uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
|
28 |
+
bool mul_topk_weight) {
|
29 |
+
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
30 |
+
if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
31 |
+
return;
|
32 |
+
} else {
|
33 |
+
#endif
|
34 |
+
|
35 |
+
using Dtype = ScalarType<scalar_t>;
|
36 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
37 |
+
|
38 |
+
if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
|
39 |
+
|
40 |
+
const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
|
41 |
+
const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
|
42 |
+
|
43 |
+
const int32_t expert_id = expert_ids[blockIdx.x];
|
44 |
+
|
45 |
+
int32_t num_valid_tokens = 0;
|
46 |
+
extern __shared__ uint16_t block_input_tmp[];
|
47 |
+
scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
|
48 |
+
scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
|
49 |
+
|
50 |
+
// load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
|
51 |
+
for (int m = 0; m < BLOCK_SIZE_M; m++) {
|
52 |
+
const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
|
53 |
+
const int32_t token_index = sorted_token_ids[offset_m];
|
54 |
+
if (token_index / top_k >= size_m) break;
|
55 |
+
|
56 |
+
num_valid_tokens = m + 1;
|
57 |
+
if (blockIdx.z == 0 && offset_n < size_n)
|
58 |
+
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
59 |
+
|
60 |
+
if (expert_id != -1) {
|
61 |
+
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
62 |
+
for (int i = 0; i < k_per_thread; i++) {
|
63 |
+
int k = BLOCK_SIZE_N * i + threadIdx.x;
|
64 |
+
if (k >= BLOCK_SIZE_K) break;
|
65 |
+
if (offset_k + k >= size_k) break;
|
66 |
+
|
67 |
+
// load input to shared memory
|
68 |
+
// use a special layout to fit the layout of dequanted-weight
|
69 |
+
int origin_k;
|
70 |
+
if constexpr (bit == 4) {
|
71 |
+
// [0, 4, 1, 5, 2, 6, 3, 7]
|
72 |
+
int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
|
73 |
+
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
|
74 |
+
} else {
|
75 |
+
// [0, 2, 1, 3]
|
76 |
+
int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
|
77 |
+
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
|
78 |
+
}
|
79 |
+
|
80 |
+
origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
|
81 |
+
block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
|
82 |
+
}
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
if (expert_id == -1) return;
|
87 |
+
__syncthreads();
|
88 |
+
if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
|
89 |
+
|
90 |
+
float res[64]; // assume BLOCK_SIZE_M <= 64
|
91 |
+
scalar_t2 res2;
|
92 |
+
scalar_t2 scale_f2;
|
93 |
+
scalar_t2 qzero_f2;
|
94 |
+
|
95 |
+
// note that (size_n * size_k * expert_id) may greater than 2 ** 31
|
96 |
+
constexpr int8_t pack_factor = 32 / bit;
|
97 |
+
const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
|
98 |
+
const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
|
99 |
+
const scalar_t* expert_scales = scales + expert_offset / group_size;
|
100 |
+
const uint32_t* expert_qzeros =
|
101 |
+
qzeros + expert_offset / group_size / pack_factor;
|
102 |
+
|
103 |
+
// load 4*int32 one time: 4 int32 = 128 bit = 1 float4
|
104 |
+
// weight would be loaded in loop
|
105 |
+
uint32_t expert_qweight_tmp[4];
|
106 |
+
float4* expert_qweight_tmp_float4 =
|
107 |
+
reinterpret_cast<float4*>(expert_qweight_tmp);
|
108 |
+
|
109 |
+
// load all required scales one time
|
110 |
+
scalar_t expert_scales_groups[GROUPS];
|
111 |
+
int scales_offset_tmp =
|
112 |
+
(offset_n * size_k + offset_k) / group_size / GROUPS;
|
113 |
+
if constexpr (GROUPS == 1) {
|
114 |
+
*expert_scales_groups = expert_scales[scales_offset_tmp];
|
115 |
+
} else if constexpr (GROUPS == 2) {
|
116 |
+
float* expert_scales_groups_tmp =
|
117 |
+
reinterpret_cast<float*>(expert_scales_groups);
|
118 |
+
*expert_scales_groups_tmp =
|
119 |
+
reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
|
120 |
+
} else if constexpr (GROUPS == 4) {
|
121 |
+
float2* expert_scales_groups_tmp =
|
122 |
+
reinterpret_cast<float2*>(expert_scales_groups);
|
123 |
+
*expert_scales_groups_tmp =
|
124 |
+
reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
|
125 |
+
} else if constexpr (GROUPS == 8) {
|
126 |
+
float4* expert_scales_groups_tmp =
|
127 |
+
reinterpret_cast<float4*>(expert_scales_groups);
|
128 |
+
*expert_scales_groups_tmp =
|
129 |
+
reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
|
130 |
+
}
|
131 |
+
|
132 |
+
// load all required qzeros one time
|
133 |
+
uint8_t expert_qzeros_groups[GROUPS];
|
134 |
+
if (!has_zp) {
|
135 |
+
if constexpr (bit == 4) {
|
136 |
+
qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
|
137 |
+
} else {
|
138 |
+
qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
|
139 |
+
}
|
140 |
+
} else {
|
141 |
+
int qzeros_offset_tmp =
|
142 |
+
(offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
|
143 |
+
offset_k / group_size / GROUPS;
|
144 |
+
if constexpr (GROUPS == 1) {
|
145 |
+
uint8_t* expert_qzeros_groups_tmp =
|
146 |
+
reinterpret_cast<uint8_t*>(expert_qzeros_groups);
|
147 |
+
*expert_qzeros_groups_tmp =
|
148 |
+
reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
|
149 |
+
} else if constexpr (GROUPS == 2) {
|
150 |
+
uint16_t* expert_qzeros_groups_tmp =
|
151 |
+
reinterpret_cast<uint16_t*>(expert_qzeros_groups);
|
152 |
+
*expert_qzeros_groups_tmp =
|
153 |
+
reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
|
154 |
+
} else if constexpr (GROUPS == 4) {
|
155 |
+
uint32_t* expert_qzeros_groups_tmp =
|
156 |
+
reinterpret_cast<uint32_t*>(expert_qzeros_groups);
|
157 |
+
*expert_qzeros_groups_tmp =
|
158 |
+
reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
|
159 |
+
} else if constexpr (GROUPS == 8) {
|
160 |
+
uint64_t* expert_qzeros_groups_tmp =
|
161 |
+
reinterpret_cast<uint64_t*>(expert_qzeros_groups);
|
162 |
+
*expert_qzeros_groups_tmp =
|
163 |
+
reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
|
164 |
+
}
|
165 |
+
}
|
166 |
+
|
167 |
+
for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
|
168 |
+
int k = offset_k + tmp_k * pack_factor;
|
169 |
+
if (k >= size_k) break;
|
170 |
+
const int32_t weight_offset = offset_n * size_k + k;
|
171 |
+
|
172 |
+
if (tmp_k % 4 == 0) {
|
173 |
+
*expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
|
174 |
+
expert_qweight)[weight_offset / pack_factor / 4];
|
175 |
+
}
|
176 |
+
|
177 |
+
if (tmp_k % (group_size / pack_factor) == 0) {
|
178 |
+
scalar_t scale_f =
|
179 |
+
expert_scales_groups[tmp_k / (group_size / pack_factor)];
|
180 |
+
scale_f2 = Dtype::num2num2(scale_f);
|
181 |
+
|
182 |
+
if (has_zp) {
|
183 |
+
uint8_t qzero =
|
184 |
+
expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
|
185 |
+
if constexpr (bit == 4) {
|
186 |
+
qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
|
187 |
+
}
|
188 |
+
qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
|
189 |
+
}
|
190 |
+
}
|
191 |
+
|
192 |
+
scalar_t2 weight_half2[16 / bit];
|
193 |
+
dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
|
194 |
+
|
195 |
+
for (int m = 0; m < num_valid_tokens; m++) {
|
196 |
+
res2 = {};
|
197 |
+
|
198 |
+
#pragma unroll
|
199 |
+
for (int i = 0; i < 16 / bit; i++) {
|
200 |
+
int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
|
201 |
+
res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
|
202 |
+
block_input_half2[offset_input], res2);
|
203 |
+
}
|
204 |
+
|
205 |
+
if (tmp_k == 0) {
|
206 |
+
res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
207 |
+
} else {
|
208 |
+
res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
209 |
+
}
|
210 |
+
}
|
211 |
+
}
|
212 |
+
|
213 |
+
for (int m = 0; m < num_valid_tokens; ++m) {
|
214 |
+
const int32_t token_index =
|
215 |
+
sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
|
216 |
+
if (mul_topk_weight) {
|
217 |
+
res[m] *= topk_weights[token_index];
|
218 |
+
}
|
219 |
+
atomicAdd(&output[token_index * size_n + offset_n],
|
220 |
+
Dtype::float2num(res[m]));
|
221 |
+
}
|
222 |
+
|
223 |
+
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
224 |
+
}
|
225 |
+
#endif
|
226 |
+
}
|
227 |
+
|
228 |
+
template <typename scalar_t>
|
229 |
+
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
|
230 |
+
const uint32_t* b_qweight, const scalar_t* b_scales,
|
231 |
+
const uint32_t* b_qzeros, const float* topk_weights,
|
232 |
+
const int32_t* sorted_token_ids,
|
233 |
+
const int32_t* expert_ids,
|
234 |
+
const int32_t* num_tokens_post_pad, int num_experts,
|
235 |
+
int group_size, int num_token_blocks, int top_k,
|
236 |
+
int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
|
237 |
+
int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
|
238 |
+
bool has_zp, bool mul_topk_weight) {
|
239 |
+
dim3 blockDim, gridDim;
|
240 |
+
blockDim.x = BLOCK_SIZE_N;
|
241 |
+
blockDim.y = 1;
|
242 |
+
blockDim.z = 1;
|
243 |
+
gridDim.x = num_token_blocks;
|
244 |
+
gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
|
245 |
+
gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
|
246 |
+
|
247 |
+
auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
|
248 |
+
if (bit == 4) {
|
249 |
+
if (BLOCK_SIZE_K / group_size == 2) {
|
250 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
|
251 |
+
} else if (BLOCK_SIZE_K / group_size == 4) {
|
252 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
|
253 |
+
} else if (BLOCK_SIZE_K / group_size == 8) {
|
254 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
|
255 |
+
}
|
256 |
+
} else {
|
257 |
+
if (BLOCK_SIZE_K / group_size == 1) {
|
258 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
|
259 |
+
} else if (BLOCK_SIZE_K / group_size == 2) {
|
260 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
|
261 |
+
} else if (BLOCK_SIZE_K / group_size == 4) {
|
262 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
|
263 |
+
} else if (BLOCK_SIZE_K / group_size == 8) {
|
264 |
+
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
|
265 |
+
}
|
266 |
+
}
|
267 |
+
|
268 |
+
const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
|
269 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
270 |
+
kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
|
271 |
+
input, output, b_qweight, b_scales, b_qzeros, topk_weights,
|
272 |
+
sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
|
273 |
+
group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
274 |
+
BLOCK_SIZE_K, has_zp, mul_topk_weight);
|
275 |
+
}
|
276 |
+
|
277 |
+
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
278 |
+
torch::Tensor b_qweight, torch::Tensor b_scales,
|
279 |
+
std::optional<torch::Tensor> b_qzeros,
|
280 |
+
std::optional<torch::Tensor> topk_weights,
|
281 |
+
torch::Tensor sorted_token_ids,
|
282 |
+
torch::Tensor expert_ids,
|
283 |
+
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
284 |
+
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
285 |
+
int64_t BLOCK_SIZE_K, int64_t bit) {
|
286 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
287 |
+
auto options =
|
288 |
+
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
289 |
+
|
290 |
+
const int num_experts = b_qweight.size(0);
|
291 |
+
const int size_m = input.size(0);
|
292 |
+
const int size_n = b_qweight.size(1);
|
293 |
+
const int size_k = input.size(1);
|
294 |
+
const int group_size = size_k / b_scales.size(2);
|
295 |
+
|
296 |
+
int64_t EM = sorted_token_ids.size(0);
|
297 |
+
if (size_m <= BLOCK_SIZE_M) {
|
298 |
+
EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
|
299 |
+
}
|
300 |
+
const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
|
301 |
+
|
302 |
+
const uint32_t* b_qzeros_ptr;
|
303 |
+
if (b_qzeros.has_value())
|
304 |
+
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
305 |
+
const float* topk_weights_ptr;
|
306 |
+
if (topk_weights.has_value())
|
307 |
+
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
308 |
+
|
309 |
+
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
310 |
+
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
311 |
+
TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
|
312 |
+
"size_k must divisible by BLOCK_SIZE_K");
|
313 |
+
TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
|
314 |
+
"BLOCK_SIZE_K must divisible by group_size");
|
315 |
+
TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64");
|
316 |
+
TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
|
317 |
+
groups_per_block_row == 4 || groups_per_block_row == 8,
|
318 |
+
"BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
|
319 |
+
|
320 |
+
if (input.scalar_type() == at::ScalarType::Half) {
|
321 |
+
run_moe_wna16_gemm<half>(
|
322 |
+
(const half*)input.data_ptr<at::Half>(),
|
323 |
+
(half*)output.data_ptr<at::Half>(),
|
324 |
+
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
325 |
+
(const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
|
326 |
+
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
327 |
+
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
328 |
+
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
329 |
+
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
330 |
+
b_qzeros.has_value(), topk_weights.has_value());
|
331 |
+
} else if (input.scalar_type() == at::ScalarType::BFloat16) {
|
332 |
+
run_moe_wna16_gemm<nv_bfloat16>(
|
333 |
+
(const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
|
334 |
+
(nv_bfloat16*)output.data_ptr<at::BFloat16>(),
|
335 |
+
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
336 |
+
(const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
|
337 |
+
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
338 |
+
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
339 |
+
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
340 |
+
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
341 |
+
b_qzeros.has_value(), topk_weights.has_value());
|
342 |
+
} else {
|
343 |
+
TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
|
344 |
+
}
|
345 |
+
return output;
|
346 |
+
}
|
moe/moe_wna16_utils.h
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#include <cuda_fp16.h>
|
3 |
+
#include <cuda_bf16.h>
|
4 |
+
|
5 |
+
template <typename scalar_t>
|
6 |
+
class ScalarType {};
|
7 |
+
|
8 |
+
template <>
|
9 |
+
class ScalarType<half> {
|
10 |
+
public:
|
11 |
+
using scalar_t = half;
|
12 |
+
using scalar_t2 = half2;
|
13 |
+
|
14 |
+
static __device__ float inline num2float(const half x) {
|
15 |
+
return __half2float(x);
|
16 |
+
}
|
17 |
+
|
18 |
+
static __device__ half2 inline num2num2(const half x) {
|
19 |
+
return __half2half2(x);
|
20 |
+
}
|
21 |
+
|
22 |
+
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
23 |
+
return __halves2half2(x1, x2);
|
24 |
+
}
|
25 |
+
|
26 |
+
static __host__ __device__ half inline float2num(const float x) {
|
27 |
+
return __float2half(x);
|
28 |
+
}
|
29 |
+
|
30 |
+
static __host__ __device__ half inline int2num(const float x) {
|
31 |
+
return __int2half_rn(x);
|
32 |
+
}
|
33 |
+
|
34 |
+
static __host__ __device__ float2 inline num22float2(const half2 x) {
|
35 |
+
return __half22float2(x);
|
36 |
+
}
|
37 |
+
|
38 |
+
static __host__ __device__ half2 inline float22num2(const float2 x) {
|
39 |
+
return __float22half2_rn(x);
|
40 |
+
}
|
41 |
+
};
|
42 |
+
|
43 |
+
template <>
|
44 |
+
class ScalarType<nv_bfloat16> {
|
45 |
+
public:
|
46 |
+
using scalar_t = nv_bfloat16;
|
47 |
+
using scalar_t2 = nv_bfloat162;
|
48 |
+
|
49 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
50 |
+
static __device__ float inline num2float(const nv_bfloat16 x) {
|
51 |
+
return __bfloat162float(x);
|
52 |
+
}
|
53 |
+
|
54 |
+
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
55 |
+
return __bfloat162bfloat162(x);
|
56 |
+
}
|
57 |
+
|
58 |
+
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
|
59 |
+
const nv_bfloat16 x2) {
|
60 |
+
return __halves2bfloat162(x1, x2);
|
61 |
+
}
|
62 |
+
|
63 |
+
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
64 |
+
return __float2bfloat16(x);
|
65 |
+
}
|
66 |
+
|
67 |
+
static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
|
68 |
+
return __int2bfloat16_rn(x);
|
69 |
+
}
|
70 |
+
|
71 |
+
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
|
72 |
+
return __bfloat1622float2(x);
|
73 |
+
}
|
74 |
+
|
75 |
+
static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
|
76 |
+
return __float22bfloat162_rn(x);
|
77 |
+
}
|
78 |
+
#endif
|
79 |
+
};
|
80 |
+
|
81 |
+
template <int lut>
|
82 |
+
__device__ inline int lop3(int a, int b, int c) {
|
83 |
+
int res;
|
84 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
85 |
+
: "=r"(res)
|
86 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
87 |
+
return res;
|
88 |
+
}
|
89 |
+
|
90 |
+
template <int start_byte, int mask>
|
91 |
+
__device__ inline uint32_t prmt(uint32_t a) {
|
92 |
+
uint32_t res;
|
93 |
+
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
94 |
+
: "=r"(res)
|
95 |
+
: "r"(a), "n"(start_byte), "n"(mask));
|
96 |
+
return res;
|
97 |
+
}
|
98 |
+
|
99 |
+
template <typename scalar_t2, int bit>
|
100 |
+
__device__ inline void dequant(int q, scalar_t2* res) {}
|
101 |
+
|
102 |
+
template <>
|
103 |
+
__device__ inline void dequant<half2, 4>(int q, half2* res) {
|
104 |
+
const int LO = 0x000f000f;
|
105 |
+
const int HI = 0x00f000f0;
|
106 |
+
const int EX = 0x64006400;
|
107 |
+
const int SUB = 0x64006400;
|
108 |
+
const int MUL = 0x2c002c00;
|
109 |
+
const int ADD = 0xd400d400;
|
110 |
+
|
111 |
+
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
112 |
+
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
113 |
+
q >>= 8;
|
114 |
+
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
115 |
+
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
116 |
+
|
117 |
+
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
|
118 |
+
*reinterpret_cast<const half2*>(&SUB));
|
119 |
+
res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
|
120 |
+
*reinterpret_cast<const half2*>(&MUL),
|
121 |
+
*reinterpret_cast<const half2*>(&ADD));
|
122 |
+
res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
|
123 |
+
*reinterpret_cast<const half2*>(&SUB));
|
124 |
+
res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
|
125 |
+
*reinterpret_cast<const half2*>(&MUL),
|
126 |
+
*reinterpret_cast<const half2*>(&ADD));
|
127 |
+
}
|
128 |
+
|
129 |
+
template <>
|
130 |
+
__device__ inline void dequant<half2, 8>(int q, half2* res) {
|
131 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
132 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
133 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
134 |
+
|
135 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
136 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
137 |
+
|
138 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
139 |
+
|
140 |
+
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
141 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
142 |
+
res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
143 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
144 |
+
}
|
145 |
+
|
146 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
147 |
+
template <>
|
148 |
+
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
|
149 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
150 |
+
static constexpr uint32_t EX = 0x43004300;
|
151 |
+
|
152 |
+
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
153 |
+
q >>= 4;
|
154 |
+
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
155 |
+
q >>= 4;
|
156 |
+
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
157 |
+
q >>= 4;
|
158 |
+
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
159 |
+
|
160 |
+
static constexpr uint32_t MUL = 0x3F803F80;
|
161 |
+
static constexpr uint32_t ADD = 0xC300C300;
|
162 |
+
|
163 |
+
res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
|
164 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
165 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
166 |
+
res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
|
167 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
168 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
169 |
+
res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
|
170 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
171 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
172 |
+
res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
|
173 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
174 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
175 |
+
}
|
176 |
+
|
177 |
+
template <>
|
178 |
+
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
|
179 |
+
float fp32_intermediates[4];
|
180 |
+
uint32_t* fp32_intermediates_casted =
|
181 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
182 |
+
|
183 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
184 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
185 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
186 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
187 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
188 |
+
|
189 |
+
fp32_intermediates[0] -= 8388608.f;
|
190 |
+
fp32_intermediates[1] -= 8388608.f;
|
191 |
+
fp32_intermediates[2] -= 8388608.f;
|
192 |
+
fp32_intermediates[3] -= 8388608.f;
|
193 |
+
|
194 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
|
195 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
196 |
+
fp32_intermediates_casted[1], 0x7632);
|
197 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
198 |
+
fp32_intermediates_casted[3], 0x7632);
|
199 |
+
}
|
200 |
+
#endif
|
tests/kernels/test_moe.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
"""Tests for the MOE layers.
|
2 |
|
3 |
Run `pytest tests/kernels/test_moe.py`.
|
@@ -18,12 +19,15 @@ from moe.utils.marlin_utils_test import marlin_quantize, quantize_weights
|
|
18 |
from .utils import compute_max_diff, opcheck, torch_moe
|
19 |
|
20 |
|
|
|
|
|
|
|
21 |
def stack_and_dev(tensors: List[torch.Tensor]):
|
22 |
dev = tensors[0].device
|
23 |
return torch.stack(tensors, dim=0).to(dev)
|
24 |
|
25 |
-
|
26 |
NUM_EXPERTS = [8, 64]
|
|
|
27 |
TOP_KS = [2, 6]
|
28 |
|
29 |
|
@@ -32,25 +36,54 @@ TOP_KS = [2, 6]
|
|
32 |
@pytest.mark.parametrize("k", [128, 511, 1024])
|
33 |
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
34 |
@pytest.mark.parametrize("topk", TOP_KS)
|
|
|
35 |
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
|
|
36 |
def test_fused_moe(
|
37 |
m: int,
|
38 |
n: int,
|
39 |
k: int,
|
40 |
e: int,
|
41 |
topk: int,
|
|
|
42 |
dtype: torch.dtype,
|
|
|
43 |
):
|
44 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
45 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
46 |
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
47 |
|
48 |
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
52 |
-
# iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
|
53 |
-
# torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, rtol=0)
|
54 |
|
55 |
|
56 |
@pytest.mark.parametrize("m", [1, 32, 222])
|
@@ -58,21 +91,14 @@ def test_fused_moe(
|
|
58 |
@pytest.mark.parametrize("k", [128, 1024])
|
59 |
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
60 |
@pytest.mark.parametrize("topk", TOP_KS)
|
|
|
61 |
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
62 |
@pytest.mark.parametrize("group_size", [64, 128])
|
63 |
@pytest.mark.parametrize("has_zp", [True, False])
|
64 |
@pytest.mark.parametrize("weight_bits", [4, 8])
|
65 |
-
def test_fused_moe_wn16(
|
66 |
-
|
67 |
-
|
68 |
-
k: int,
|
69 |
-
e: int,
|
70 |
-
topk: int,
|
71 |
-
dtype: torch.dtype,
|
72 |
-
group_size: int,
|
73 |
-
has_zp: bool,
|
74 |
-
weight_bits: int,
|
75 |
-
):
|
76 |
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
|
77 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
78 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
@@ -88,40 +114,35 @@ def test_fused_moe_wn16(
|
|
88 |
|
89 |
w1_ref = w1.clone()
|
90 |
w2_ref = w2.clone()
|
91 |
-
w1_qweight = torch.empty(
|
92 |
-
|
93 |
-
|
94 |
-
w2_qweight = torch.empty((e, k, n // pack_factor),
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
for i in range(e * 2):
|
105 |
expert_id = i % e
|
106 |
if i // e == 0:
|
107 |
-
w, w_ref, w_qweight, w_scales, w_qzeros =
|
108 |
-
w1,
|
109 |
-
w1_ref,
|
110 |
-
w1_qweight,
|
111 |
-
w1_scales,
|
112 |
-
w1_qzeros,
|
113 |
-
)
|
114 |
else:
|
115 |
-
w, w_ref, w_qweight, w_scales, w_qzeros =
|
116 |
-
w2,
|
117 |
-
w2_ref,
|
118 |
-
w2_qweight,
|
119 |
-
w2_scales,
|
120 |
-
w2_qzeros,
|
121 |
-
)
|
122 |
weight, qweight, scales, qzeros = quantize_weights(
|
123 |
-
w[expert_id].T, quant_type, group_size, has_zp, False
|
124 |
-
)
|
125 |
weight = weight.T
|
126 |
qweight = qweight.T.contiguous().to(torch.uint8)
|
127 |
scales = scales.T
|
@@ -138,25 +159,45 @@ def test_fused_moe_wn16(
|
|
138 |
if has_zp:
|
139 |
w_qzeros[expert_id] = qzeros
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
158 |
|
159 |
|
|
|
160 |
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
161 |
@pytest.mark.parametrize("n", [128, 2048])
|
162 |
@pytest.mark.parametrize("k", [128, 1024])
|
@@ -178,7 +219,7 @@ def test_fused_marlin_moe(
|
|
178 |
num_bits: int,
|
179 |
is_k_full: bool,
|
180 |
):
|
181 |
-
|
182 |
|
183 |
# Filter act_order
|
184 |
if act_order:
|
@@ -190,7 +231,8 @@ def test_fused_marlin_moe(
|
|
190 |
if not is_k_full:
|
191 |
return
|
192 |
|
193 |
-
quant_type = scalar_types.uint4b8
|
|
|
194 |
dtype = torch.float16
|
195 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
196 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
@@ -205,8 +247,8 @@ def test_fused_marlin_moe(
|
|
205 |
for i in range(w1.shape[0]):
|
206 |
test_perm = torch.randperm(k)
|
207 |
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
208 |
-
w1[i].transpose(1, 0), quant_type, group_size, act_order,
|
209 |
-
|
210 |
w_ref1_l.append(w_ref1)
|
211 |
qweight1_l.append(qweight1)
|
212 |
scales1_l.append(scales1)
|
@@ -228,8 +270,8 @@ def test_fused_marlin_moe(
|
|
228 |
for i in range(w2.shape[0]):
|
229 |
test_perm = torch.randperm(n)
|
230 |
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
231 |
-
w2[i].transpose(1, 0), quant_type, group_size, act_order,
|
232 |
-
|
233 |
w_ref2_l.append(w_ref2)
|
234 |
qweight2_l.append(qweight2)
|
235 |
scales2_l.append(scales2)
|
@@ -273,79 +315,141 @@ def test_fused_marlin_moe(
|
|
273 |
|
274 |
assert compute_max_diff(marlin_output, triton_output) < 4e-2
|
275 |
|
276 |
-
token_expert_indicies = torch.empty(m,
|
|
|
|
|
|
|
277 |
|
278 |
-
opcheck(
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
score.float(),
|
285 |
-
),
|
286 |
-
)
|
287 |
|
288 |
block_size_m = 4
|
289 |
|
290 |
-
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
|
|
|
291 |
|
292 |
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
|
293 |
-
workspace = torch.zeros(
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
sort_indices1,
|
310 |
-
workspace,
|
311 |
-
quant_type.id,
|
312 |
-
m,
|
313 |
-
2 * n,
|
314 |
-
k,
|
315 |
-
True,
|
316 |
-
e,
|
317 |
-
topk,
|
318 |
-
block_size_m,
|
319 |
-
True,
|
320 |
-
False,
|
321 |
-
),
|
322 |
)
|
323 |
|
|
|
|
|
|
|
|
|
324 |
|
325 |
def test_moe_align_block_size_opcheck():
|
326 |
num_experts = 4
|
327 |
block_size = 4
|
328 |
-
topk_ids = torch.randint(0,
|
|
|
|
|
|
|
329 |
|
330 |
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
331 |
-
sorted_ids = torch.empty(
|
332 |
-
|
333 |
-
|
334 |
sorted_ids.fill_(topk_ids.numel())
|
335 |
max_num_m_blocks = max_num_tokens_padded // block_size
|
336 |
-
expert_ids = torch.empty(
|
337 |
-
|
338 |
-
|
339 |
-
num_tokens_post_pad = torch.empty((1),
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
topk_ids,
|
345 |
-
|
346 |
-
block_size,
|
347 |
-
sorted_ids,
|
348 |
-
expert_ids,
|
349 |
-
num_tokens_post_pad,
|
350 |
-
),
|
351 |
-
)
|
|
|
1 |
+
# SPDX-License-Identifier: Apache-2.0
|
2 |
"""Tests for the MOE layers.
|
3 |
|
4 |
Run `pytest tests/kernels/test_moe.py`.
|
|
|
19 |
from .utils import compute_max_diff, opcheck, torch_moe
|
20 |
|
21 |
|
22 |
+
from torch.nn import Parameter
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
def stack_and_dev(tensors: List[torch.Tensor]):
|
26 |
dev = tensors[0].device
|
27 |
return torch.stack(tensors, dim=0).to(dev)
|
28 |
|
|
|
29 |
NUM_EXPERTS = [8, 64]
|
30 |
+
EP_SIZE = [1, 4]
|
31 |
TOP_KS = [2, 6]
|
32 |
|
33 |
|
|
|
36 |
@pytest.mark.parametrize("k", [128, 511, 1024])
|
37 |
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
38 |
@pytest.mark.parametrize("topk", TOP_KS)
|
39 |
+
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
40 |
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
41 |
+
@pytest.mark.parametrize("padding", [True, False])
|
42 |
def test_fused_moe(
|
43 |
m: int,
|
44 |
n: int,
|
45 |
k: int,
|
46 |
e: int,
|
47 |
topk: int,
|
48 |
+
ep_size: int,
|
49 |
dtype: torch.dtype,
|
50 |
+
padding: bool,
|
51 |
):
|
52 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
53 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
54 |
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
55 |
|
56 |
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
57 |
+
|
58 |
+
if ep_size > 1:
|
59 |
+
local_e = e // ep_size
|
60 |
+
e_ids = torch.randint(0,
|
61 |
+
e, (local_e, ),
|
62 |
+
device="cuda",
|
63 |
+
dtype=torch.int32)
|
64 |
+
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
65 |
+
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
66 |
+
w1 = w1[e_ids]
|
67 |
+
w2 = w2[e_ids]
|
68 |
+
else:
|
69 |
+
e_map = None
|
70 |
+
|
71 |
+
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
|
72 |
+
if padding:
|
73 |
+
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
|
74 |
+
torch.cuda.empty_cache()
|
75 |
+
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
|
76 |
+
torch.cuda.empty_cache()
|
77 |
+
|
78 |
+
triton_output = fused_moe(a,
|
79 |
+
w1,
|
80 |
+
w2,
|
81 |
+
score,
|
82 |
+
topk,
|
83 |
+
global_num_experts=e,
|
84 |
+
expert_map=e_map,
|
85 |
+
renormalize=False)
|
86 |
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
|
|
|
|
87 |
|
88 |
|
89 |
@pytest.mark.parametrize("m", [1, 32, 222])
|
|
|
91 |
@pytest.mark.parametrize("k", [128, 1024])
|
92 |
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
93 |
@pytest.mark.parametrize("topk", TOP_KS)
|
94 |
+
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
95 |
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
96 |
@pytest.mark.parametrize("group_size", [64, 128])
|
97 |
@pytest.mark.parametrize("has_zp", [True, False])
|
98 |
@pytest.mark.parametrize("weight_bits", [4, 8])
|
99 |
+
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
100 |
+
ep_size: int, dtype: torch.dtype, group_size: int,
|
101 |
+
has_zp: bool, weight_bits: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
|
103 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
104 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
|
|
114 |
|
115 |
w1_ref = w1.clone()
|
116 |
w2_ref = w2.clone()
|
117 |
+
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
|
118 |
+
device="cuda",
|
119 |
+
dtype=torch.uint8)
|
120 |
+
w2_qweight = torch.empty((e, k, n // pack_factor),
|
121 |
+
device="cuda",
|
122 |
+
dtype=torch.uint8)
|
123 |
+
w1_scales = torch.empty((e, 2 * n, k // group_size),
|
124 |
+
device="cuda",
|
125 |
+
dtype=dtype)
|
126 |
+
w2_scales = torch.empty((e, k, n // group_size),
|
127 |
+
device="cuda",
|
128 |
+
dtype=dtype)
|
129 |
+
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
|
130 |
+
device="cuda",
|
131 |
+
dtype=torch.uint8)
|
132 |
+
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
|
133 |
+
device="cuda",
|
134 |
+
dtype=torch.uint8)
|
135 |
|
136 |
for i in range(e * 2):
|
137 |
expert_id = i % e
|
138 |
if i // e == 0:
|
139 |
+
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
140 |
+
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
|
|
|
|
|
|
|
|
|
|
|
141 |
else:
|
142 |
+
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
143 |
+
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
|
|
|
|
|
|
|
|
|
|
|
144 |
weight, qweight, scales, qzeros = quantize_weights(
|
145 |
+
w[expert_id].T, quant_type, group_size, has_zp, False)
|
|
|
146 |
weight = weight.T
|
147 |
qweight = qweight.T.contiguous().to(torch.uint8)
|
148 |
scales = scales.T
|
|
|
159 |
if has_zp:
|
160 |
w_qzeros[expert_id] = qzeros
|
161 |
|
162 |
+
if ep_size > 1:
|
163 |
+
local_e = e // ep_size
|
164 |
+
e_ids = torch.randint(0,
|
165 |
+
e, (local_e, ),
|
166 |
+
device="cuda",
|
167 |
+
dtype=torch.int32)
|
168 |
+
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
169 |
+
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
170 |
+
w1_ref = w1_ref[e_ids]
|
171 |
+
w2_ref = w2_ref[e_ids]
|
172 |
+
w1_qweight = w1_qweight[e_ids]
|
173 |
+
w2_qweight = w2_qweight[e_ids]
|
174 |
+
w1_scales = w1_scales[e_ids]
|
175 |
+
w2_scales = w2_scales[e_ids]
|
176 |
+
w1_qzeros = w1_qzeros[e_ids]
|
177 |
+
w2_qzeros = w2_qzeros[e_ids]
|
178 |
+
else:
|
179 |
+
e_map = None
|
180 |
+
|
181 |
+
triton_output = fused_moe(a,
|
182 |
+
w1_qweight,
|
183 |
+
w2_qweight,
|
184 |
+
score,
|
185 |
+
topk,
|
186 |
+
renormalize=False,
|
187 |
+
use_int4_w4a16=weight_bits == 4,
|
188 |
+
use_int8_w8a16=weight_bits == 8,
|
189 |
+
global_num_experts=e,
|
190 |
+
expert_map=e_map,
|
191 |
+
w1_scale=w1_scales,
|
192 |
+
w2_scale=w2_scales,
|
193 |
+
w1_zp=w1_qzeros if has_zp else None,
|
194 |
+
w2_zp=w2_qzeros if has_zp else None,
|
195 |
+
block_shape=[0, group_size])
|
196 |
+
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
|
197 |
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
198 |
|
199 |
|
200 |
+
|
201 |
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
202 |
@pytest.mark.parametrize("n", [128, 2048])
|
203 |
@pytest.mark.parametrize("k", [128, 1024])
|
|
|
219 |
num_bits: int,
|
220 |
is_k_full: bool,
|
221 |
):
|
222 |
+
current_platform.seed_everything(7)
|
223 |
|
224 |
# Filter act_order
|
225 |
if act_order:
|
|
|
231 |
if not is_k_full:
|
232 |
return
|
233 |
|
234 |
+
quant_type = (scalar_types.uint4b8
|
235 |
+
if num_bits == 4 else scalar_types.uint8b128)
|
236 |
dtype = torch.float16
|
237 |
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
238 |
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
|
|
247 |
for i in range(w1.shape[0]):
|
248 |
test_perm = torch.randperm(k)
|
249 |
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
250 |
+
w1[i].transpose(1, 0), quant_type, group_size, act_order,
|
251 |
+
test_perm)
|
252 |
w_ref1_l.append(w_ref1)
|
253 |
qweight1_l.append(qweight1)
|
254 |
scales1_l.append(scales1)
|
|
|
270 |
for i in range(w2.shape[0]):
|
271 |
test_perm = torch.randperm(n)
|
272 |
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
273 |
+
w2[i].transpose(1, 0), quant_type, group_size, act_order,
|
274 |
+
test_perm)
|
275 |
w_ref2_l.append(w_ref2)
|
276 |
qweight2_l.append(qweight2)
|
277 |
scales2_l.append(scales2)
|
|
|
315 |
|
316 |
assert compute_max_diff(marlin_output, triton_output) < 4e-2
|
317 |
|
318 |
+
token_expert_indicies = torch.empty(m,
|
319 |
+
topk,
|
320 |
+
dtype=torch.int32,
|
321 |
+
device=a.device)
|
322 |
|
323 |
+
opcheck(ops.topk_softmax, (
|
324 |
+
topk_weights,
|
325 |
+
topk_ids,
|
326 |
+
token_expert_indicies,
|
327 |
+
score.float(),
|
328 |
+
))
|
|
|
|
|
|
|
329 |
|
330 |
block_size_m = 4
|
331 |
|
332 |
+
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
|
333 |
+
e)
|
334 |
|
335 |
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
|
336 |
+
workspace = torch.zeros(max_workspace_size,
|
337 |
+
dtype=torch.int,
|
338 |
+
device="cuda",
|
339 |
+
requires_grad=False)
|
340 |
+
|
341 |
+
zp = torch.empty((0, 0),
|
342 |
+
dtype=dtype,
|
343 |
+
device="cuda",
|
344 |
+
requires_grad=False)
|
345 |
+
opcheck(ops.marlin_gemm_moe,
|
346 |
+
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
|
347 |
+
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
|
348 |
+
m, 2 * n, k, True, e, topk, block_size_m, True, False))
|
349 |
+
|
350 |
+
|
351 |
+
@pytest.mark.skip("This test is here for the sake of debugging, "
|
352 |
+
"don't run it in automated tests.")
|
353 |
+
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
354 |
+
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
355 |
+
@pytest.mark.parametrize("k", [128, 1024, 512])
|
356 |
+
@pytest.mark.parametrize("e", [8, 64])
|
357 |
+
@pytest.mark.parametrize("topk", [2, 6])
|
358 |
+
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
359 |
+
@pytest.mark.parametrize("act_order", [True, False])
|
360 |
+
@pytest.mark.parametrize("num_bits", [4, 8])
|
361 |
+
@pytest.mark.parametrize("is_k_full", [True, False])
|
362 |
+
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
363 |
+
def test_single_marlin_moe_multiply(
|
364 |
+
m: int,
|
365 |
+
n: int,
|
366 |
+
k: int,
|
367 |
+
e: int,
|
368 |
+
topk: int,
|
369 |
+
group_size: int,
|
370 |
+
act_order: bool,
|
371 |
+
num_bits: int,
|
372 |
+
is_k_full: bool,
|
373 |
+
):
|
374 |
+
|
375 |
+
# Filter act_order
|
376 |
+
if act_order:
|
377 |
+
if group_size == -1:
|
378 |
+
return
|
379 |
+
if group_size == k:
|
380 |
+
return
|
381 |
+
else:
|
382 |
+
if not is_k_full:
|
383 |
+
return
|
384 |
+
|
385 |
+
quant_type = (scalar_types.uint4b8
|
386 |
+
if num_bits == 4 else scalar_types.uint8b128)
|
387 |
+
dtype = torch.float16
|
388 |
+
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
389 |
+
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
390 |
+
|
391 |
+
w_ref_l = []
|
392 |
+
qweights_l = []
|
393 |
+
scales_l = []
|
394 |
+
g_idx_l = []
|
395 |
+
sort_indices_l = []
|
396 |
+
|
397 |
+
for i in range(w.shape[0]):
|
398 |
+
test_perm = torch.randperm(k)
|
399 |
+
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
400 |
+
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
|
401 |
+
w_ref_l.append(w_ref)
|
402 |
+
qweights_l.append(qweight)
|
403 |
+
scales_l.append(scales)
|
404 |
+
g_idx_l.append(g_idx)
|
405 |
+
sort_indices_l.append(sort_indices)
|
406 |
+
|
407 |
+
w_ref = stack_and_dev(w_ref_l)
|
408 |
+
qweight = stack_and_dev(qweights_l).contiguous()
|
409 |
+
scales = stack_and_dev(scales_l)
|
410 |
+
g_idx = stack_and_dev(g_idx_l)
|
411 |
+
sort_indices = stack_and_dev(sort_indices_l)
|
412 |
|
413 |
+
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
414 |
+
marlin_output = ops.single_marlin_moe(
|
415 |
+
a,
|
416 |
+
qweight,
|
417 |
+
scales,
|
418 |
+
score,
|
419 |
+
topk,
|
420 |
+
renormalize=False,
|
421 |
+
g_idx=g_idx,
|
422 |
+
sort_indices=sort_indices,
|
423 |
+
num_bits=num_bits,
|
424 |
+
is_k_full=is_k_full,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
)
|
426 |
|
427 |
+
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
428 |
+
|
429 |
+
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
430 |
+
|
431 |
|
432 |
def test_moe_align_block_size_opcheck():
|
433 |
num_experts = 4
|
434 |
block_size = 4
|
435 |
+
topk_ids = torch.randint(0,
|
436 |
+
num_experts, (3, 4),
|
437 |
+
dtype=torch.int32,
|
438 |
+
device='cuda')
|
439 |
|
440 |
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
441 |
+
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
442 |
+
dtype=torch.int32,
|
443 |
+
device=topk_ids.device)
|
444 |
sorted_ids.fill_(topk_ids.numel())
|
445 |
max_num_m_blocks = max_num_tokens_padded // block_size
|
446 |
+
expert_ids = torch.empty((max_num_m_blocks, ),
|
447 |
+
dtype=torch.int32,
|
448 |
+
device=topk_ids.device)
|
449 |
+
num_tokens_post_pad = torch.empty((1),
|
450 |
+
dtype=torch.int32,
|
451 |
+
device=topk_ids.device)
|
452 |
+
|
453 |
+
opcheck(ops.moe_align_block_size,
|
454 |
+
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
|
455 |
+
num_tokens_post_pad))
|
|
|
|
|
|
|
|
|
|
|
|
tests/kernels/utils.py
CHANGED
@@ -38,7 +38,7 @@ class SiluAndMul(nn.Module):
|
|
38 |
return F.silu(x[..., :d]) * x[..., d:]
|
39 |
|
40 |
|
41 |
-
def torch_moe(a, w1, w2, score, topk):
|
42 |
B, D = a.shape
|
43 |
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
44 |
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
@@ -46,15 +46,15 @@ def torch_moe(a, w1, w2, score, topk):
|
|
46 |
topk_weight, topk_ids = torch.topk(score, topk)
|
47 |
topk_weight = topk_weight.view(-1)
|
48 |
topk_ids = topk_ids.view(-1)
|
|
|
|
|
49 |
for i in range(w1.shape[0]):
|
50 |
mask = topk_ids == i
|
51 |
if mask.sum():
|
52 |
-
out[mask] = SiluAndMul()(
|
53 |
-
0, 1
|
54 |
-
|
55 |
-
|
56 |
-
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
57 |
-
).sum(dim=1)
|
58 |
|
59 |
|
60 |
# Copied/modified from torch._refs.__init__.py
|
|
|
38 |
return F.silu(x[..., :d]) * x[..., d:]
|
39 |
|
40 |
|
41 |
+
def torch_moe(a, w1, w2, score, topk, expert_map):
|
42 |
B, D = a.shape
|
43 |
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
44 |
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
|
46 |
topk_weight, topk_ids = torch.topk(score, topk)
|
47 |
topk_weight = topk_weight.view(-1)
|
48 |
topk_ids = topk_ids.view(-1)
|
49 |
+
if expert_map is not None:
|
50 |
+
topk_ids = expert_map[topk_ids]
|
51 |
for i in range(w1.shape[0]):
|
52 |
mask = topk_ids == i
|
53 |
if mask.sum():
|
54 |
+
out[mask] = SiluAndMul()(
|
55 |
+
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
56 |
+
return (out.view(B, -1, w2.shape[1]) *
|
57 |
+
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
|
|
|
|
58 |
|
59 |
|
60 |
# Copied/modified from torch._refs.__init__.py
|
torch-ext/moe/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import torch
|
2 |
|
|
|
3 |
from ._ops import ops
|
4 |
from .fp8_utils import per_token_group_quant_fp8, w8a8_block_fp8_matmul
|
5 |
from .fused_marlin_moe import fused_marlin_moe
|
@@ -51,24 +52,6 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor):
|
|
51 |
ops.moe_sum(input, output)
|
52 |
|
53 |
|
54 |
-
def moe_align_block_size(
|
55 |
-
topk_ids: torch.Tensor,
|
56 |
-
num_experts: int,
|
57 |
-
block_size: int,
|
58 |
-
sorted_token_ids: torch.Tensor,
|
59 |
-
experts_ids: torch.Tensor,
|
60 |
-
num_tokens_post_pad: torch.Tensor,
|
61 |
-
) -> None:
|
62 |
-
ops.moe_align_block_size(
|
63 |
-
topk_ids,
|
64 |
-
num_experts,
|
65 |
-
block_size,
|
66 |
-
sorted_token_ids,
|
67 |
-
experts_ids,
|
68 |
-
num_tokens_post_pad,
|
69 |
-
)
|
70 |
-
|
71 |
-
|
72 |
def topk_softmax(
|
73 |
topk_weights: torch.Tensor,
|
74 |
topk_ids: torch.Tensor,
|
@@ -87,6 +70,7 @@ __all__ = [
|
|
87 |
"fused_topk",
|
88 |
"gptq_marlin_moe_repack",
|
89 |
"grouped_topk",
|
|
|
90 |
"moe_align_block_size",
|
91 |
"moe_sum",
|
92 |
"per_token_group_quant_fp8",
|
|
|
1 |
import torch
|
2 |
|
3 |
+
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .fp8_utils import per_token_group_quant_fp8, w8a8_block_fp8_matmul
|
6 |
from .fused_marlin_moe import fused_marlin_moe
|
|
|
52 |
ops.moe_sum(input, output)
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def topk_softmax(
|
56 |
topk_weights: torch.Tensor,
|
57 |
topk_ids: torch.Tensor,
|
|
|
70 |
"fused_topk",
|
71 |
"gptq_marlin_moe_repack",
|
72 |
"grouped_topk",
|
73 |
+
"layers",
|
74 |
"moe_align_block_size",
|
75 |
"moe_sum",
|
76 |
"per_token_group_quant_fp8",
|
torch-ext/moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 64,
|
6 |
+
"GROUP_SIZE_M": 64,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 64,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 5
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 16,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 2
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 64,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 5
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 256,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 16,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 16,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 2
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 128,
|
16 |
+
"BLOCK_SIZE_K": 64,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 8,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 128,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 128,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 2
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 64,
|
60 |
+
"BLOCK_SIZE_K": 128,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 4,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 256,
|
72 |
+
"GROUP_SIZE_M": 16,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 256,
|
83 |
+
"GROUP_SIZE_M": 16,
|
84 |
+
"num_warps": 4,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 1
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 32,
|
106 |
+
"num_warps": 1,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 32,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 64,
|
125 |
+
"BLOCK_SIZE_N": 64,
|
126 |
+
"BLOCK_SIZE_K": 128,
|
127 |
+
"GROUP_SIZE_M": 32,
|
128 |
+
"num_warps": 8,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 1
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 8,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 1
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 64,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 8,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 64,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 128,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 16,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 2
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 128,
|
16 |
+
"BLOCK_SIZE_K": 64,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 8,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 128,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 128,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 2
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 64,
|
60 |
+
"BLOCK_SIZE_K": 128,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 4,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 256,
|
72 |
+
"GROUP_SIZE_M": 16,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 256,
|
83 |
+
"GROUP_SIZE_M": 16,
|
84 |
+
"num_warps": 4,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 1
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 32,
|
106 |
+
"num_warps": 1,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 32,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 64,
|
125 |
+
"BLOCK_SIZE_N": 64,
|
126 |
+
"BLOCK_SIZE_K": 128,
|
127 |
+
"GROUP_SIZE_M": 32,
|
128 |
+
"num_warps": 8,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 1
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 8,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 1
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 64,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 8,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 64,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 128,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 16,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 4
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 64,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 64,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 4
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 64,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 5
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 4
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 32,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 32,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 32,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 16,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 32,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 32,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 64,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 32,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 2
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 2
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 2
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 32,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 64,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 32,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 16,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 32,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 32,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 16,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 64,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 32,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 32,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 64,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 32,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 32,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 32,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 32,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 64,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 2
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 2
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 2
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 32,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 32,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 64,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 32,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 32,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 32,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 64,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 64,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 16,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 64,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 64,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 64,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 64,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 32,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 64,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 32,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 64,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 16,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 32,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 64,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 64,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 64,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 64,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 8,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0
|
10 |
+
},
|
11 |
+
"2": {
|
12 |
+
"BLOCK_SIZE_M": 16,
|
13 |
+
"BLOCK_SIZE_N": 128,
|
14 |
+
"BLOCK_SIZE_K": 256,
|
15 |
+
"GROUP_SIZE_M": 1,
|
16 |
+
"num_warps": 8,
|
17 |
+
"num_stages": 2,
|
18 |
+
"waves_per_eu": 0
|
19 |
+
},
|
20 |
+
"4": {
|
21 |
+
"BLOCK_SIZE_M": 16,
|
22 |
+
"BLOCK_SIZE_N": 128,
|
23 |
+
"BLOCK_SIZE_K": 256,
|
24 |
+
"GROUP_SIZE_M": 1,
|
25 |
+
"num_warps": 8,
|
26 |
+
"num_stages": 2,
|
27 |
+
"waves_per_eu": 0
|
28 |
+
},
|
29 |
+
"8": {
|
30 |
+
"BLOCK_SIZE_M": 16,
|
31 |
+
"BLOCK_SIZE_N": 128,
|
32 |
+
"BLOCK_SIZE_K": 128,
|
33 |
+
"GROUP_SIZE_M": 1,
|
34 |
+
"num_warps": 8,
|
35 |
+
"num_stages": 2,
|
36 |
+
"waves_per_eu": 0
|
37 |
+
},
|
38 |
+
"16": {
|
39 |
+
"BLOCK_SIZE_M": 16,
|
40 |
+
"BLOCK_SIZE_N": 128,
|
41 |
+
"BLOCK_SIZE_K": 128,
|
42 |
+
"GROUP_SIZE_M": 1,
|
43 |
+
"num_warps": 2,
|
44 |
+
"num_stages": 2,
|
45 |
+
"waves_per_eu": 0
|
46 |
+
},
|
47 |
+
"24": {
|
48 |
+
"BLOCK_SIZE_M": 16,
|
49 |
+
"BLOCK_SIZE_N": 128,
|
50 |
+
"BLOCK_SIZE_K": 128,
|
51 |
+
"GROUP_SIZE_M": 1,
|
52 |
+
"num_warps": 2,
|
53 |
+
"num_stages": 2,
|
54 |
+
"waves_per_eu": 0
|
55 |
+
},
|
56 |
+
"32": {
|
57 |
+
"BLOCK_SIZE_M": 16,
|
58 |
+
"BLOCK_SIZE_N": 128,
|
59 |
+
"BLOCK_SIZE_K": 128,
|
60 |
+
"GROUP_SIZE_M": 4,
|
61 |
+
"num_warps": 2,
|
62 |
+
"num_stages": 2,
|
63 |
+
"waves_per_eu": 0
|
64 |
+
},
|
65 |
+
"48": {
|
66 |
+
"BLOCK_SIZE_M": 16,
|
67 |
+
"BLOCK_SIZE_N": 128,
|
68 |
+
"BLOCK_SIZE_K": 128,
|
69 |
+
"GROUP_SIZE_M": 4,
|
70 |
+
"num_warps": 2,
|
71 |
+
"num_stages": 2,
|
72 |
+
"waves_per_eu": 0
|
73 |
+
},
|
74 |
+
"64": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 2,
|
80 |
+
"num_stages": 2,
|
81 |
+
"waves_per_eu": 0
|
82 |
+
},
|
83 |
+
"96": {
|
84 |
+
"BLOCK_SIZE_M": 16,
|
85 |
+
"BLOCK_SIZE_N": 128,
|
86 |
+
"BLOCK_SIZE_K": 128,
|
87 |
+
"GROUP_SIZE_M": 8,
|
88 |
+
"num_warps": 8,
|
89 |
+
"num_stages": 2,
|
90 |
+
"waves_per_eu": 0
|
91 |
+
},
|
92 |
+
"128": {
|
93 |
+
"BLOCK_SIZE_M": 16,
|
94 |
+
"BLOCK_SIZE_N": 128,
|
95 |
+
"BLOCK_SIZE_K": 128,
|
96 |
+
"GROUP_SIZE_M": 4,
|
97 |
+
"num_warps": 4,
|
98 |
+
"num_stages": 2,
|
99 |
+
"waves_per_eu": 0
|
100 |
+
},
|
101 |
+
"256": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 128,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 8,
|
106 |
+
"num_warps": 4,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0
|
109 |
+
},
|
110 |
+
"512": {
|
111 |
+
"BLOCK_SIZE_M": 32,
|
112 |
+
"BLOCK_SIZE_N": 128,
|
113 |
+
"BLOCK_SIZE_K": 128,
|
114 |
+
"GROUP_SIZE_M": 8,
|
115 |
+
"num_warps": 4,
|
116 |
+
"num_stages": 2,
|
117 |
+
"waves_per_eu": 0
|
118 |
+
},
|
119 |
+
"1024": {
|
120 |
+
"BLOCK_SIZE_M": 64,
|
121 |
+
"BLOCK_SIZE_N": 128,
|
122 |
+
"BLOCK_SIZE_K": 128,
|
123 |
+
"GROUP_SIZE_M": 8,
|
124 |
+
"num_warps": 2,
|
125 |
+
"num_stages": 2,
|
126 |
+
"waves_per_eu": 0
|
127 |
+
},
|
128 |
+
"1536": {
|
129 |
+
"BLOCK_SIZE_M": 64,
|
130 |
+
"BLOCK_SIZE_N": 128,
|
131 |
+
"BLOCK_SIZE_K": 128,
|
132 |
+
"GROUP_SIZE_M": 4,
|
133 |
+
"num_warps": 2,
|
134 |
+
"num_stages": 2,
|
135 |
+
"waves_per_eu": 0
|
136 |
+
},
|
137 |
+
"2048": {
|
138 |
+
"BLOCK_SIZE_M": 128,
|
139 |
+
"BLOCK_SIZE_N": 256,
|
140 |
+
"BLOCK_SIZE_K": 128,
|
141 |
+
"GROUP_SIZE_M": 8,
|
142 |
+
"num_warps": 4,
|
143 |
+
"num_stages": 2,
|
144 |
+
"waves_per_eu": 0
|
145 |
+
},
|
146 |
+
"3072": {
|
147 |
+
"BLOCK_SIZE_M": 128,
|
148 |
+
"BLOCK_SIZE_N": 256,
|
149 |
+
"BLOCK_SIZE_K": 128,
|
150 |
+
"GROUP_SIZE_M": 8,
|
151 |
+
"num_warps": 4,
|
152 |
+
"num_stages": 2,
|
153 |
+
"waves_per_eu": 0
|
154 |
+
},
|
155 |
+
"4096": {
|
156 |
+
"BLOCK_SIZE_M": 128,
|
157 |
+
"BLOCK_SIZE_N": 256,
|
158 |
+
"BLOCK_SIZE_K": 128,
|
159 |
+
"GROUP_SIZE_M": 4,
|
160 |
+
"num_warps": 4,
|
161 |
+
"num_stages": 2,
|
162 |
+
"waves_per_eu": 0
|
163 |
+
}
|
164 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 64,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 64,
|
26 |
+
"BLOCK_SIZE_N": 64,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 1
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 1
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 128,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 32,
|
59 |
+
"BLOCK_SIZE_N": 128,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 8,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 128,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 128,
|
83 |
+
"GROUP_SIZE_M": 1,
|
84 |
+
"num_warps": 2,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 2,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 32,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 2,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 32,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 16,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 128,
|
127 |
+
"GROUP_SIZE_M": 1,
|
128 |
+
"num_warps": 2,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 32,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 8,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 8,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 64,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 128,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 8,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 64,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 64,
|
26 |
+
"BLOCK_SIZE_N": 64,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 1
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 1
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 128,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 32,
|
59 |
+
"BLOCK_SIZE_N": 128,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 8,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 128,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 128,
|
83 |
+
"GROUP_SIZE_M": 1,
|
84 |
+
"num_warps": 2,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 2,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 32,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 2,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 32,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 16,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 128,
|
127 |
+
"GROUP_SIZE_M": 1,
|
128 |
+
"num_warps": 2,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 32,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 8,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 8,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 64,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 128,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 8,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 256,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 8,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 256,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 5
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 64,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 256,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 256,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 256,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 32,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 64,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 4
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 64,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 4
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 4
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 16,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 16,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 16,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 16,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 32,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 32,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 32,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 64,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 64,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 64,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 32,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 16,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 16,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 16,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 64,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 16,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 32,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 64,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 64,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 5
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 256,
|
37 |
+
"BLOCK_SIZE_K": 64,
|
38 |
+
"GROUP_SIZE_M": 32,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 32,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 32,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 32,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 32,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 64,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 64,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 64,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 64,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 8,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 32,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 8,
|
24 |
+
"num_stages": 2
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 256,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 2
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 256,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 8,
|
40 |
+
"num_stages": 2
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 256,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 8,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 256,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 8,
|
56 |
+
"num_stages": 2
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 256,
|
62 |
+
"GROUP_SIZE_M": 16,
|
63 |
+
"num_warps": 8,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 256,
|
70 |
+
"GROUP_SIZE_M": 64,
|
71 |
+
"num_warps": 8,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 256,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 8,
|
80 |
+
"num_stages": 2
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 256,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 8,
|
88 |
+
"num_stages": 2
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 256,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 2
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 256,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 2
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 256,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 2
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 64,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 2
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 64,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 2
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 2
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 64,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 64,
|
26 |
+
"BLOCK_SIZE_N": 64,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 1
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 1
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 128,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 32,
|
59 |
+
"BLOCK_SIZE_N": 128,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 8,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 128,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 128,
|
83 |
+
"GROUP_SIZE_M": 1,
|
84 |
+
"num_warps": 2,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 2,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 32,
|
104 |
+
"BLOCK_SIZE_K": 128,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 2,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 32,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 16,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 128,
|
127 |
+
"GROUP_SIZE_M": 1,
|
128 |
+
"num_warps": 2,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 32,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 8,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 8,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 64,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 128,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 8,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 64,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 64,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 64,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 64,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 64,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 64,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 64,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 64,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 64,
|
78 |
+
"GROUP_SIZE_M": 32,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 16,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 64,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 16,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 64,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 32,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 64,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 2
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 64,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 64,
|
119 |
+
"num_warps": 4,
|
120 |
+
"num_stages": 2
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 32,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 64,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 2
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 64,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 64,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 2
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 128,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 8,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 1
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 32,
|
27 |
+
"BLOCK_SIZE_K": 64,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 2,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 1
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 32,
|
37 |
+
"BLOCK_SIZE_N": 256,
|
38 |
+
"BLOCK_SIZE_K": 64,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 8,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 32,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 32,
|
48 |
+
"BLOCK_SIZE_N": 128,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 2,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 32,
|
55 |
+
"kpack": 2
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 64,
|
60 |
+
"BLOCK_SIZE_K": 256,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 4,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 256,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 1
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 256,
|
83 |
+
"GROUP_SIZE_M": 4,
|
84 |
+
"num_warps": 4,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 256,
|
94 |
+
"GROUP_SIZE_M": 4,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 4,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 256,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 4,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 32,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 256,
|
127 |
+
"GROUP_SIZE_M": 4,
|
128 |
+
"num_warps": 4,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 32,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 1
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 128,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 64,
|
149 |
+
"GROUP_SIZE_M": 32,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 128,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 64,
|
160 |
+
"GROUP_SIZE_M": 32,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 4,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 4,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 32,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 2,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 2
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 16,
|
16 |
+
"BLOCK_SIZE_K": 256,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 2
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 16,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 4,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 16,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
+
"GROUP_SIZE_M": 4,
|
51 |
+
"num_warps": 2,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 16,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 4,
|
62 |
+
"num_warps": 2,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 32,
|
71 |
+
"BLOCK_SIZE_K": 64,
|
72 |
+
"GROUP_SIZE_M": 8,
|
73 |
+
"num_warps": 2,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 32,
|
82 |
+
"BLOCK_SIZE_K": 64,
|
83 |
+
"GROUP_SIZE_M": 8,
|
84 |
+
"num_warps": 2,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 1
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 32,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 8,
|
95 |
+
"num_warps": 2,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 1
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 16,
|
104 |
+
"BLOCK_SIZE_K": 64,
|
105 |
+
"GROUP_SIZE_M": 16,
|
106 |
+
"num_warps": 1,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 1
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 16,
|
115 |
+
"BLOCK_SIZE_K": 64,
|
116 |
+
"GROUP_SIZE_M": 16,
|
117 |
+
"num_warps": 1,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 1
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 32,
|
125 |
+
"BLOCK_SIZE_N": 64,
|
126 |
+
"BLOCK_SIZE_K": 64,
|
127 |
+
"GROUP_SIZE_M": 8,
|
128 |
+
"num_warps": 4,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 1
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 128,
|
137 |
+
"BLOCK_SIZE_K": 64,
|
138 |
+
"GROUP_SIZE_M": 32,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 1
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 128,
|
147 |
+
"BLOCK_SIZE_N": 64,
|
148 |
+
"BLOCK_SIZE_K": 64,
|
149 |
+
"GROUP_SIZE_M": 16,
|
150 |
+
"num_warps": 4,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 128,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 64,
|
160 |
+
"GROUP_SIZE_M": 8,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 64,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 8,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 1
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 256,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 4,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 2
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 32,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 2
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 64,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 4,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 32,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 4,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 64,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 2,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 64,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 2,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 32,
|
82 |
+
"BLOCK_SIZE_K": 64,
|
83 |
+
"GROUP_SIZE_M": 8,
|
84 |
+
"num_warps": 1,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 1
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 32,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 8,
|
95 |
+
"num_warps": 2,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 1
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 32,
|
104 |
+
"BLOCK_SIZE_K": 64,
|
105 |
+
"GROUP_SIZE_M": 16,
|
106 |
+
"num_warps": 2,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 32,
|
115 |
+
"BLOCK_SIZE_K": 64,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 2,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 32,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 64,
|
127 |
+
"GROUP_SIZE_M": 8,
|
128 |
+
"num_warps": 4,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 1
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 64,
|
138 |
+
"GROUP_SIZE_M": 32,
|
139 |
+
"num_warps": 4,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 128,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 64,
|
149 |
+
"GROUP_SIZE_M": 4,
|
150 |
+
"num_warps": 4,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 128,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 64,
|
160 |
+
"GROUP_SIZE_M": 16,
|
161 |
+
"num_warps": 4,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 1
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 256,
|
169 |
+
"BLOCK_SIZE_N": 256,
|
170 |
+
"BLOCK_SIZE_K": 32,
|
171 |
+
"GROUP_SIZE_M": 16,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 1
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 256,
|
180 |
+
"BLOCK_SIZE_N": 256,
|
181 |
+
"BLOCK_SIZE_K": 32,
|
182 |
+
"GROUP_SIZE_M": 16,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 1
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 4,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 64,
|
16 |
+
"BLOCK_SIZE_K": 128,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 2
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 128,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 8,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 32,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 64,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 4,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 1
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 16,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 2,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 1
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 64,
|
60 |
+
"BLOCK_SIZE_K": 256,
|
61 |
+
"GROUP_SIZE_M": 4,
|
62 |
+
"num_warps": 4,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 256,
|
72 |
+
"GROUP_SIZE_M": 1,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 256,
|
83 |
+
"GROUP_SIZE_M": 4,
|
84 |
+
"num_warps": 4,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 2
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 16,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 256,
|
94 |
+
"GROUP_SIZE_M": 4,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 16,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 4,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 2
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 16,
|
114 |
+
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 256,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 4,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 16,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 32,
|
125 |
+
"BLOCK_SIZE_N": 32,
|
126 |
+
"BLOCK_SIZE_K": 256,
|
127 |
+
"GROUP_SIZE_M": 4,
|
128 |
+
"num_warps": 4,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 64,
|
136 |
+
"BLOCK_SIZE_N": 64,
|
137 |
+
"BLOCK_SIZE_K": 128,
|
138 |
+
"GROUP_SIZE_M": 4,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 64,
|
147 |
+
"BLOCK_SIZE_N": 64,
|
148 |
+
"BLOCK_SIZE_K": 128,
|
149 |
+
"GROUP_SIZE_M": 1,
|
150 |
+
"num_warps": 4,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 128,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 64,
|
160 |
+
"GROUP_SIZE_M": 32,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 1,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 8,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 8,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 32,
|
5 |
+
"BLOCK_SIZE_K": 64,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 5
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 64,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 8,
|
16 |
+
"num_stages": 5
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 32,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 256,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 256,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 8,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 16,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 64,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 64,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 3
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 64,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 4
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 256,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 5
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 4
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 4
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 4
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 4
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 4
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 16,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 64,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 2
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 32,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 2
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 256,
|
46 |
+
"GROUP_SIZE_M": 32,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 4
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 8,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 8,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 64,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 256,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 8,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 16,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 256,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 8,
|
48 |
+
"num_stages": 4
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 256,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 4
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 256,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 8,
|
64 |
+
"num_stages": 4
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 64,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 256,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 5
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 256,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 256,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 16,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 5
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 4
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 256,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 5
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 256,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 5
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 16,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 32,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 5
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 256,
|
30 |
+
"GROUP_SIZE_M": 1,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 5
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 256,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 5
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 256,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 8,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 1,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 64,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 64,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 64,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 32,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 64,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 64,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 64,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 64,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 2
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 64,
|
85 |
+
"BLOCK_SIZE_K": 64,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 16,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 32,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 1,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=320,device_name=NVIDIA_H200.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 32,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 64,
|
30 |
+
"GROUP_SIZE_M": 16,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 64,
|
38 |
+
"GROUP_SIZE_M": 1,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 64,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 64,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 64,
|
62 |
+
"GROUP_SIZE_M": 32,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 64,
|
70 |
+
"GROUP_SIZE_M": 16,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 64,
|
78 |
+
"GROUP_SIZE_M": 32,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 64,
|
86 |
+
"GROUP_SIZE_M": 64,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 3
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 1,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 64,
|
6 |
+
"GROUP_SIZE_M": 16,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 32,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 64,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 5
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 32,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 4
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 64,
|
30 |
+
"GROUP_SIZE_M": 1,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 5
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 2
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 1,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 5
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 32,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 32,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 256,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 1,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 3
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 4,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 32,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 4,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 32,
|
135 |
+
"num_warps": 4,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 4,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 32,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 3
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 32,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 32,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 2
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 16,
|
20 |
+
"BLOCK_SIZE_N": 64,
|
21 |
+
"BLOCK_SIZE_K": 256,
|
22 |
+
"GROUP_SIZE_M": 16,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 32,
|
28 |
+
"BLOCK_SIZE_N": 64,
|
29 |
+
"BLOCK_SIZE_K": 256,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 8,
|
32 |
+
"num_stages": 3
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 32,
|
37 |
+
"BLOCK_SIZE_K": 256,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 5
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 256,
|
46 |
+
"GROUP_SIZE_M": 64,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 2
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 32,
|
53 |
+
"BLOCK_SIZE_K": 256,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 8,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 256,
|
62 |
+
"GROUP_SIZE_M": 64,
|
63 |
+
"num_warps": 8,
|
64 |
+
"num_stages": 2
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 32,
|
68 |
+
"BLOCK_SIZE_N": 32,
|
69 |
+
"BLOCK_SIZE_K": 256,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 2
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 32,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 16,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 32,
|
84 |
+
"BLOCK_SIZE_N": 64,
|
85 |
+
"BLOCK_SIZE_K": 256,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 8,
|
88 |
+
"num_stages": 2
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 32,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 16,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 32,
|
100 |
+
"BLOCK_SIZE_N": 64,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 64,
|
108 |
+
"BLOCK_SIZE_N": 128,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 2
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 128,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 32,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 3
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 128,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 32,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 128,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 64,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 128,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 8,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 1,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 64,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 32,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 4
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 8,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 16,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 16,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 128,
|
14 |
+
"GROUP_SIZE_M": 16,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 1,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 64,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 128,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 1,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 128,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 128,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 32,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 64,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 64,
|
100 |
+
"BLOCK_SIZE_N": 128,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 4,
|
104 |
+
"num_stages": 2
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=64,N=640,device_name=NVIDIA_H200.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 128,
|
6 |
+
"GROUP_SIZE_M": 32,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 5
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 16,
|
12 |
+
"BLOCK_SIZE_N": 64,
|
13 |
+
"BLOCK_SIZE_K": 64,
|
14 |
+
"GROUP_SIZE_M": 1,
|
15 |
+
"num_warps": 4,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 128,
|
21 |
+
"BLOCK_SIZE_K": 64,
|
22 |
+
"GROUP_SIZE_M": 32,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 3
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 16,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 128,
|
30 |
+
"GROUP_SIZE_M": 32,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 16,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 32,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 16,
|
44 |
+
"BLOCK_SIZE_N": 64,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 16,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 3
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 16,
|
52 |
+
"BLOCK_SIZE_N": 64,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 16,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 3
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 16,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 16,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 16,
|
68 |
+
"BLOCK_SIZE_N": 64,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 16,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 32,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 2
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 64,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 64,
|
92 |
+
"BLOCK_SIZE_N": 128,
|
93 |
+
"BLOCK_SIZE_K": 64,
|
94 |
+
"GROUP_SIZE_M": 1,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 3
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 64,
|
102 |
+
"GROUP_SIZE_M": 16,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 64,
|
110 |
+
"GROUP_SIZE_M": 1,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 64,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 64,
|
126 |
+
"GROUP_SIZE_M": 1,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 4
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 64,
|
134 |
+
"GROUP_SIZE_M": 16,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 4
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 64,
|
142 |
+
"GROUP_SIZE_M": 1,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 4
|
145 |
+
}
|
146 |
+
}
|
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 64,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0
|
10 |
+
},
|
11 |
+
"2": {
|
12 |
+
"BLOCK_SIZE_M": 16,
|
13 |
+
"BLOCK_SIZE_N": 16,
|
14 |
+
"BLOCK_SIZE_K": 256,
|
15 |
+
"GROUP_SIZE_M": 1,
|
16 |
+
"num_warps": 2,
|
17 |
+
"num_stages": 2,
|
18 |
+
"waves_per_eu": 0
|
19 |
+
},
|
20 |
+
"4": {
|
21 |
+
"BLOCK_SIZE_M": 16,
|
22 |
+
"BLOCK_SIZE_N": 64,
|
23 |
+
"BLOCK_SIZE_K": 256,
|
24 |
+
"GROUP_SIZE_M": 1,
|
25 |
+
"num_warps": 4,
|
26 |
+
"num_stages": 2,
|
27 |
+
"waves_per_eu": 0
|
28 |
+
},
|
29 |
+
"8": {
|
30 |
+
"BLOCK_SIZE_M": 16,
|
31 |
+
"BLOCK_SIZE_N": 32,
|
32 |
+
"BLOCK_SIZE_K": 256,
|
33 |
+
"GROUP_SIZE_M": 1,
|
34 |
+
"num_warps": 2,
|
35 |
+
"num_stages": 2,
|
36 |
+
"waves_per_eu": 0
|
37 |
+
},
|
38 |
+
"16": {
|
39 |
+
"BLOCK_SIZE_M": 16,
|
40 |
+
"BLOCK_SIZE_N": 64,
|
41 |
+
"BLOCK_SIZE_K": 256,
|
42 |
+
"GROUP_SIZE_M": 1,
|
43 |
+
"num_warps": 2,
|
44 |
+
"num_stages": 2,
|
45 |
+
"waves_per_eu": 0
|
46 |
+
},
|
47 |
+
"24": {
|
48 |
+
"BLOCK_SIZE_M": 16,
|
49 |
+
"BLOCK_SIZE_N": 64,
|
50 |
+
"BLOCK_SIZE_K": 256,
|
51 |
+
"GROUP_SIZE_M": 1,
|
52 |
+
"num_warps": 2,
|
53 |
+
"num_stages": 2,
|
54 |
+
"waves_per_eu": 0
|
55 |
+
},
|
56 |
+
"32": {
|
57 |
+
"BLOCK_SIZE_M": 16,
|
58 |
+
"BLOCK_SIZE_N": 32,
|
59 |
+
"BLOCK_SIZE_K": 256,
|
60 |
+
"GROUP_SIZE_M": 4,
|
61 |
+
"num_warps": 2,
|
62 |
+
"num_stages": 2,
|
63 |
+
"waves_per_eu": 0
|
64 |
+
},
|
65 |
+
"48": {
|
66 |
+
"BLOCK_SIZE_M": 16,
|
67 |
+
"BLOCK_SIZE_N": 64,
|
68 |
+
"BLOCK_SIZE_K": 128,
|
69 |
+
"GROUP_SIZE_M": 4,
|
70 |
+
"num_warps": 4,
|
71 |
+
"num_stages": 2,
|
72 |
+
"waves_per_eu": 0
|
73 |
+
},
|
74 |
+
"64": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 256,
|
78 |
+
"GROUP_SIZE_M": 4,
|
79 |
+
"num_warps": 2,
|
80 |
+
"num_stages": 2,
|
81 |
+
"waves_per_eu": 0
|
82 |
+
},
|
83 |
+
"96": {
|
84 |
+
"BLOCK_SIZE_M": 32,
|
85 |
+
"BLOCK_SIZE_N": 64,
|
86 |
+
"BLOCK_SIZE_K": 256,
|
87 |
+
"GROUP_SIZE_M": 1,
|
88 |
+
"num_warps": 2,
|
89 |
+
"num_stages": 2,
|
90 |
+
"waves_per_eu": 0
|
91 |
+
},
|
92 |
+
"128": {
|
93 |
+
"BLOCK_SIZE_M": 64,
|
94 |
+
"BLOCK_SIZE_N": 128,
|
95 |
+
"BLOCK_SIZE_K": 256,
|
96 |
+
"GROUP_SIZE_M": 4,
|
97 |
+
"num_warps": 8,
|
98 |
+
"num_stages": 2,
|
99 |
+
"waves_per_eu": 0
|
100 |
+
},
|
101 |
+
"256": {
|
102 |
+
"BLOCK_SIZE_M": 128,
|
103 |
+
"BLOCK_SIZE_N": 128,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 8,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0
|
109 |
+
},
|
110 |
+
"512": {
|
111 |
+
"BLOCK_SIZE_M": 256,
|
112 |
+
"BLOCK_SIZE_N": 128,
|
113 |
+
"BLOCK_SIZE_K": 128,
|
114 |
+
"GROUP_SIZE_M": 4,
|
115 |
+
"num_warps": 8,
|
116 |
+
"num_stages": 2,
|
117 |
+
"waves_per_eu": 0
|
118 |
+
},
|
119 |
+
"1024": {
|
120 |
+
"BLOCK_SIZE_M": 128,
|
121 |
+
"BLOCK_SIZE_N": 128,
|
122 |
+
"BLOCK_SIZE_K": 64,
|
123 |
+
"GROUP_SIZE_M": 1,
|
124 |
+
"num_warps": 4,
|
125 |
+
"num_stages": 2,
|
126 |
+
"waves_per_eu": 0
|
127 |
+
},
|
128 |
+
"1536": {
|
129 |
+
"BLOCK_SIZE_M": 128,
|
130 |
+
"BLOCK_SIZE_N": 256,
|
131 |
+
"BLOCK_SIZE_K": 128,
|
132 |
+
"GROUP_SIZE_M": 1,
|
133 |
+
"num_warps": 8,
|
134 |
+
"num_stages": 2,
|
135 |
+
"waves_per_eu": 0
|
136 |
+
},
|
137 |
+
"2048": {
|
138 |
+
"BLOCK_SIZE_M": 128,
|
139 |
+
"BLOCK_SIZE_N": 256,
|
140 |
+
"BLOCK_SIZE_K": 128,
|
141 |
+
"GROUP_SIZE_M": 1,
|
142 |
+
"num_warps": 8,
|
143 |
+
"num_stages": 2,
|
144 |
+
"waves_per_eu": 0
|
145 |
+
},
|
146 |
+
"3072": {
|
147 |
+
"BLOCK_SIZE_M": 128,
|
148 |
+
"BLOCK_SIZE_N": 256,
|
149 |
+
"BLOCK_SIZE_K": 128,
|
150 |
+
"GROUP_SIZE_M": 1,
|
151 |
+
"num_warps": 8,
|
152 |
+
"num_stages": 2,
|
153 |
+
"waves_per_eu": 0
|
154 |
+
},
|
155 |
+
"4096": {
|
156 |
+
"BLOCK_SIZE_M": 256,
|
157 |
+
"BLOCK_SIZE_N": 256,
|
158 |
+
"BLOCK_SIZE_K": 64,
|
159 |
+
"GROUP_SIZE_M": 1,
|
160 |
+
"num_warps": 8,
|
161 |
+
"num_stages": 2,
|
162 |
+
"waves_per_eu": 0
|
163 |
+
}
|
164 |
+
}
|
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
CHANGED
@@ -1,123 +1,123 @@
|
|
1 |
{
|
2 |
"1": {
|
3 |
"BLOCK_SIZE_M": 16,
|
4 |
-
"BLOCK_SIZE_N":
|
5 |
"BLOCK_SIZE_K": 256,
|
6 |
"GROUP_SIZE_M": 1,
|
7 |
"num_warps": 2,
|
8 |
-
"num_stages":
|
9 |
"waves_per_eu": 0,
|
10 |
"matrix_instr_nonkdim": 16,
|
11 |
-
"kpack":
|
12 |
},
|
13 |
"2": {
|
14 |
"BLOCK_SIZE_M": 16,
|
15 |
"BLOCK_SIZE_N": 16,
|
16 |
-
"BLOCK_SIZE_K":
|
17 |
"GROUP_SIZE_M": 1,
|
18 |
-
"num_warps":
|
19 |
-
"num_stages":
|
20 |
"waves_per_eu": 0,
|
21 |
"matrix_instr_nonkdim": 16,
|
22 |
"kpack": 2
|
23 |
},
|
24 |
"4": {
|
25 |
"BLOCK_SIZE_M": 16,
|
26 |
-
"BLOCK_SIZE_N":
|
27 |
-
"BLOCK_SIZE_K":
|
28 |
"GROUP_SIZE_M": 1,
|
29 |
-
"num_warps":
|
30 |
-
"num_stages":
|
31 |
"waves_per_eu": 0,
|
32 |
"matrix_instr_nonkdim": 16,
|
33 |
"kpack": 2
|
34 |
},
|
35 |
"8": {
|
36 |
"BLOCK_SIZE_M": 16,
|
37 |
-
"BLOCK_SIZE_N":
|
38 |
-
"BLOCK_SIZE_K":
|
39 |
"GROUP_SIZE_M": 1,
|
40 |
-
"num_warps":
|
41 |
-
"num_stages":
|
42 |
"waves_per_eu": 0,
|
43 |
"matrix_instr_nonkdim": 16,
|
44 |
"kpack": 2
|
45 |
},
|
46 |
"16": {
|
47 |
"BLOCK_SIZE_M": 16,
|
48 |
-
"BLOCK_SIZE_N":
|
49 |
-
"BLOCK_SIZE_K":
|
50 |
"GROUP_SIZE_M": 1,
|
51 |
-
"num_warps":
|
52 |
-
"num_stages":
|
53 |
"waves_per_eu": 0,
|
54 |
"matrix_instr_nonkdim": 16,
|
55 |
"kpack": 2
|
56 |
},
|
57 |
"24": {
|
58 |
"BLOCK_SIZE_M": 16,
|
59 |
-
"BLOCK_SIZE_N":
|
60 |
-
"BLOCK_SIZE_K":
|
61 |
"GROUP_SIZE_M": 1,
|
62 |
-
"num_warps":
|
63 |
-
"num_stages":
|
64 |
"waves_per_eu": 0,
|
65 |
"matrix_instr_nonkdim": 16,
|
66 |
"kpack": 2
|
67 |
},
|
68 |
"32": {
|
69 |
"BLOCK_SIZE_M": 16,
|
70 |
-
"BLOCK_SIZE_N":
|
71 |
-
"BLOCK_SIZE_K":
|
72 |
"GROUP_SIZE_M": 4,
|
73 |
"num_warps": 2,
|
74 |
-
"num_stages":
|
75 |
"waves_per_eu": 0,
|
76 |
"matrix_instr_nonkdim": 16,
|
77 |
-
"kpack":
|
78 |
},
|
79 |
"48": {
|
80 |
"BLOCK_SIZE_M": 16,
|
81 |
-
"BLOCK_SIZE_N":
|
82 |
"BLOCK_SIZE_K": 128,
|
83 |
"GROUP_SIZE_M": 4,
|
84 |
-
"num_warps":
|
85 |
-
"num_stages":
|
86 |
"waves_per_eu": 0,
|
87 |
"matrix_instr_nonkdim": 16,
|
88 |
-
"kpack":
|
89 |
},
|
90 |
"64": {
|
91 |
"BLOCK_SIZE_M": 32,
|
92 |
"BLOCK_SIZE_N": 64,
|
93 |
"BLOCK_SIZE_K": 128,
|
94 |
"GROUP_SIZE_M": 4,
|
95 |
-
"num_warps":
|
96 |
-
"num_stages":
|
97 |
"waves_per_eu": 0,
|
98 |
"matrix_instr_nonkdim": 16,
|
99 |
"kpack": 2
|
100 |
},
|
101 |
"96": {
|
102 |
"BLOCK_SIZE_M": 32,
|
103 |
-
"BLOCK_SIZE_N":
|
104 |
-
"BLOCK_SIZE_K":
|
105 |
"GROUP_SIZE_M": 4,
|
106 |
-
"num_warps":
|
107 |
-
"num_stages":
|
108 |
"waves_per_eu": 0,
|
109 |
"matrix_instr_nonkdim": 16,
|
110 |
-
"kpack":
|
111 |
},
|
112 |
"128": {
|
113 |
"BLOCK_SIZE_M": 64,
|
114 |
"BLOCK_SIZE_N": 64,
|
115 |
-
"BLOCK_SIZE_K":
|
116 |
"GROUP_SIZE_M": 4,
|
117 |
-
"num_warps":
|
118 |
-
"num_stages":
|
119 |
"waves_per_eu": 0,
|
120 |
-
"matrix_instr_nonkdim":
|
121 |
"kpack": 2
|
122 |
},
|
123 |
"256": {
|
@@ -126,10 +126,10 @@
|
|
126 |
"BLOCK_SIZE_K": 64,
|
127 |
"GROUP_SIZE_M": 4,
|
128 |
"num_warps": 8,
|
129 |
-
"num_stages":
|
130 |
"waves_per_eu": 0,
|
131 |
"matrix_instr_nonkdim": 16,
|
132 |
-
"kpack":
|
133 |
},
|
134 |
"512": {
|
135 |
"BLOCK_SIZE_M": 128,
|
@@ -137,7 +137,7 @@
|
|
137 |
"BLOCK_SIZE_K": 64,
|
138 |
"GROUP_SIZE_M": 4,
|
139 |
"num_warps": 8,
|
140 |
-
"num_stages":
|
141 |
"waves_per_eu": 0,
|
142 |
"matrix_instr_nonkdim": 16,
|
143 |
"kpack": 2
|
@@ -148,9 +148,9 @@
|
|
148 |
"BLOCK_SIZE_K": 64,
|
149 |
"GROUP_SIZE_M": 1,
|
150 |
"num_warps": 8,
|
151 |
-
"num_stages":
|
152 |
"waves_per_eu": 0,
|
153 |
-
"matrix_instr_nonkdim":
|
154 |
"kpack": 2
|
155 |
},
|
156 |
"1536": {
|
@@ -159,7 +159,7 @@
|
|
159 |
"BLOCK_SIZE_K": 64,
|
160 |
"GROUP_SIZE_M": 1,
|
161 |
"num_warps": 8,
|
162 |
-
"num_stages":
|
163 |
"waves_per_eu": 0,
|
164 |
"matrix_instr_nonkdim": 16,
|
165 |
"kpack": 2
|
@@ -170,7 +170,7 @@
|
|
170 |
"BLOCK_SIZE_K": 64,
|
171 |
"GROUP_SIZE_M": 1,
|
172 |
"num_warps": 8,
|
173 |
-
"num_stages":
|
174 |
"waves_per_eu": 0,
|
175 |
"matrix_instr_nonkdim": 16,
|
176 |
"kpack": 2
|
@@ -181,10 +181,10 @@
|
|
181 |
"BLOCK_SIZE_K": 64,
|
182 |
"GROUP_SIZE_M": 1,
|
183 |
"num_warps": 8,
|
184 |
-
"num_stages":
|
185 |
"waves_per_eu": 0,
|
186 |
"matrix_instr_nonkdim": 16,
|
187 |
-
"kpack":
|
188 |
},
|
189 |
"4096": {
|
190 |
"BLOCK_SIZE_M": 128,
|
@@ -192,9 +192,9 @@
|
|
192 |
"BLOCK_SIZE_K": 64,
|
193 |
"GROUP_SIZE_M": 1,
|
194 |
"num_warps": 8,
|
195 |
-
"num_stages":
|
196 |
"waves_per_eu": 0,
|
197 |
"matrix_instr_nonkdim": 16,
|
198 |
-
"kpack":
|
199 |
}
|
200 |
}
|
|
|
1 |
{
|
2 |
"1": {
|
3 |
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
"BLOCK_SIZE_K": 256,
|
6 |
"GROUP_SIZE_M": 1,
|
7 |
"num_warps": 2,
|
8 |
+
"num_stages": 2,
|
9 |
"waves_per_eu": 0,
|
10 |
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 2
|
12 |
},
|
13 |
"2": {
|
14 |
"BLOCK_SIZE_M": 16,
|
15 |
"BLOCK_SIZE_N": 16,
|
16 |
+
"BLOCK_SIZE_K": 256,
|
17 |
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
"waves_per_eu": 0,
|
21 |
"matrix_instr_nonkdim": 16,
|
22 |
"kpack": 2
|
23 |
},
|
24 |
"4": {
|
25 |
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 16,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 1,
|
30 |
+
"num_stages": 2,
|
31 |
"waves_per_eu": 0,
|
32 |
"matrix_instr_nonkdim": 16,
|
33 |
"kpack": 2
|
34 |
},
|
35 |
"8": {
|
36 |
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 64,
|
38 |
+
"BLOCK_SIZE_K": 64,
|
39 |
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 2,
|
41 |
+
"num_stages": 2,
|
42 |
"waves_per_eu": 0,
|
43 |
"matrix_instr_nonkdim": 16,
|
44 |
"kpack": 2
|
45 |
},
|
46 |
"16": {
|
47 |
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 2,
|
52 |
+
"num_stages": 2,
|
53 |
"waves_per_eu": 0,
|
54 |
"matrix_instr_nonkdim": 16,
|
55 |
"kpack": 2
|
56 |
},
|
57 |
"24": {
|
58 |
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 16,
|
60 |
+
"BLOCK_SIZE_K": 256,
|
61 |
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 2,
|
63 |
+
"num_stages": 2,
|
64 |
"waves_per_eu": 0,
|
65 |
"matrix_instr_nonkdim": 16,
|
66 |
"kpack": 2
|
67 |
},
|
68 |
"32": {
|
69 |
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 32,
|
71 |
+
"BLOCK_SIZE_K": 256,
|
72 |
"GROUP_SIZE_M": 4,
|
73 |
"num_warps": 2,
|
74 |
+
"num_stages": 2,
|
75 |
"waves_per_eu": 0,
|
76 |
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
},
|
79 |
"48": {
|
80 |
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
"BLOCK_SIZE_K": 128,
|
83 |
"GROUP_SIZE_M": 4,
|
84 |
+
"num_warps": 4,
|
85 |
+
"num_stages": 2,
|
86 |
"waves_per_eu": 0,
|
87 |
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 1
|
89 |
},
|
90 |
"64": {
|
91 |
"BLOCK_SIZE_M": 32,
|
92 |
"BLOCK_SIZE_N": 64,
|
93 |
"BLOCK_SIZE_K": 128,
|
94 |
"GROUP_SIZE_M": 4,
|
95 |
+
"num_warps": 4,
|
96 |
+
"num_stages": 2,
|
97 |
"waves_per_eu": 0,
|
98 |
"matrix_instr_nonkdim": 16,
|
99 |
"kpack": 2
|
100 |
},
|
101 |
"96": {
|
102 |
"BLOCK_SIZE_M": 32,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 8,
|
107 |
+
"num_stages": 2,
|
108 |
"waves_per_eu": 0,
|
109 |
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 1
|
111 |
},
|
112 |
"128": {
|
113 |
"BLOCK_SIZE_M": 64,
|
114 |
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 4,
|
118 |
+
"num_stages": 2,
|
119 |
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 32,
|
121 |
"kpack": 2
|
122 |
},
|
123 |
"256": {
|
|
|
126 |
"BLOCK_SIZE_K": 64,
|
127 |
"GROUP_SIZE_M": 4,
|
128 |
"num_warps": 8,
|
129 |
+
"num_stages": 2,
|
130 |
"waves_per_eu": 0,
|
131 |
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
},
|
134 |
"512": {
|
135 |
"BLOCK_SIZE_M": 128,
|
|
|
137 |
"BLOCK_SIZE_K": 64,
|
138 |
"GROUP_SIZE_M": 4,
|
139 |
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
"waves_per_eu": 0,
|
142 |
"matrix_instr_nonkdim": 16,
|
143 |
"kpack": 2
|
|
|
148 |
"BLOCK_SIZE_K": 64,
|
149 |
"GROUP_SIZE_M": 1,
|
150 |
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
"kpack": 2
|
155 |
},
|
156 |
"1536": {
|
|
|
159 |
"BLOCK_SIZE_K": 64,
|
160 |
"GROUP_SIZE_M": 1,
|
161 |
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
"waves_per_eu": 0,
|
164 |
"matrix_instr_nonkdim": 16,
|
165 |
"kpack": 2
|
|
|
170 |
"BLOCK_SIZE_K": 64,
|
171 |
"GROUP_SIZE_M": 1,
|
172 |
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
"waves_per_eu": 0,
|
175 |
"matrix_instr_nonkdim": 16,
|
176 |
"kpack": 2
|
|
|
181 |
"BLOCK_SIZE_K": 64,
|
182 |
"GROUP_SIZE_M": 1,
|
183 |
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
"waves_per_eu": 0,
|
186 |
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
},
|
189 |
"4096": {
|
190 |
"BLOCK_SIZE_M": 128,
|
|
|
192 |
"BLOCK_SIZE_K": 64,
|
193 |
"GROUP_SIZE_M": 1,
|
194 |
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
"waves_per_eu": 0,
|
197 |
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
}
|
200 |
}
|
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 2,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0
|
10 |
+
},
|
11 |
+
"2": {
|
12 |
+
"BLOCK_SIZE_M": 16,
|
13 |
+
"BLOCK_SIZE_N": 32,
|
14 |
+
"BLOCK_SIZE_K": 256,
|
15 |
+
"GROUP_SIZE_M": 1,
|
16 |
+
"num_warps": 4,
|
17 |
+
"num_stages": 2,
|
18 |
+
"waves_per_eu": 0
|
19 |
+
},
|
20 |
+
"4": {
|
21 |
+
"BLOCK_SIZE_M": 16,
|
22 |
+
"BLOCK_SIZE_N": 16,
|
23 |
+
"BLOCK_SIZE_K": 256,
|
24 |
+
"GROUP_SIZE_M": 1,
|
25 |
+
"num_warps": 4,
|
26 |
+
"num_stages": 2,
|
27 |
+
"waves_per_eu": 0
|
28 |
+
},
|
29 |
+
"8": {
|
30 |
+
"BLOCK_SIZE_M": 16,
|
31 |
+
"BLOCK_SIZE_N": 64,
|
32 |
+
"BLOCK_SIZE_K": 256,
|
33 |
+
"GROUP_SIZE_M": 1,
|
34 |
+
"num_warps": 2,
|
35 |
+
"num_stages": 2,
|
36 |
+
"waves_per_eu": 0
|
37 |
+
},
|
38 |
+
"16": {
|
39 |
+
"BLOCK_SIZE_M": 64,
|
40 |
+
"BLOCK_SIZE_N": 64,
|
41 |
+
"BLOCK_SIZE_K": 256,
|
42 |
+
"GROUP_SIZE_M": 1,
|
43 |
+
"num_warps": 4,
|
44 |
+
"num_stages": 2,
|
45 |
+
"waves_per_eu": 0
|
46 |
+
},
|
47 |
+
"24": {
|
48 |
+
"BLOCK_SIZE_M": 32,
|
49 |
+
"BLOCK_SIZE_N": 64,
|
50 |
+
"BLOCK_SIZE_K": 256,
|
51 |
+
"GROUP_SIZE_M": 1,
|
52 |
+
"num_warps": 2,
|
53 |
+
"num_stages": 2,
|
54 |
+
"waves_per_eu": 0
|
55 |
+
},
|
56 |
+
"32": {
|
57 |
+
"BLOCK_SIZE_M": 16,
|
58 |
+
"BLOCK_SIZE_N": 32,
|
59 |
+
"BLOCK_SIZE_K": 256,
|
60 |
+
"GROUP_SIZE_M": 4,
|
61 |
+
"num_warps": 2,
|
62 |
+
"num_stages": 2,
|
63 |
+
"waves_per_eu": 0
|
64 |
+
},
|
65 |
+
"48": {
|
66 |
+
"BLOCK_SIZE_M": 16,
|
67 |
+
"BLOCK_SIZE_N": 64,
|
68 |
+
"BLOCK_SIZE_K": 256,
|
69 |
+
"GROUP_SIZE_M": 1,
|
70 |
+
"num_warps": 4,
|
71 |
+
"num_stages": 2,
|
72 |
+
"waves_per_eu": 0
|
73 |
+
},
|
74 |
+
"64": {
|
75 |
+
"BLOCK_SIZE_M": 32,
|
76 |
+
"BLOCK_SIZE_N": 16,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 4,
|
79 |
+
"num_warps": 2,
|
80 |
+
"num_stages": 2,
|
81 |
+
"waves_per_eu": 0
|
82 |
+
},
|
83 |
+
"96": {
|
84 |
+
"BLOCK_SIZE_M": 32,
|
85 |
+
"BLOCK_SIZE_N": 64,
|
86 |
+
"BLOCK_SIZE_K": 256,
|
87 |
+
"GROUP_SIZE_M": 1,
|
88 |
+
"num_warps": 2,
|
89 |
+
"num_stages": 2,
|
90 |
+
"waves_per_eu": 0
|
91 |
+
},
|
92 |
+
"128": {
|
93 |
+
"BLOCK_SIZE_M": 64,
|
94 |
+
"BLOCK_SIZE_N": 64,
|
95 |
+
"BLOCK_SIZE_K": 256,
|
96 |
+
"GROUP_SIZE_M": 4,
|
97 |
+
"num_warps": 4,
|
98 |
+
"num_stages": 2,
|
99 |
+
"waves_per_eu": 0
|
100 |
+
},
|
101 |
+
"256": {
|
102 |
+
"BLOCK_SIZE_M": 128,
|
103 |
+
"BLOCK_SIZE_N": 128,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 8,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0
|
109 |
+
},
|
110 |
+
"512": {
|
111 |
+
"BLOCK_SIZE_M": 256,
|
112 |
+
"BLOCK_SIZE_N": 128,
|
113 |
+
"BLOCK_SIZE_K": 128,
|
114 |
+
"GROUP_SIZE_M": 4,
|
115 |
+
"num_warps": 8,
|
116 |
+
"num_stages": 2,
|
117 |
+
"waves_per_eu": 0
|
118 |
+
},
|
119 |
+
"1024": {
|
120 |
+
"BLOCK_SIZE_M": 128,
|
121 |
+
"BLOCK_SIZE_N": 128,
|
122 |
+
"BLOCK_SIZE_K": 256,
|
123 |
+
"GROUP_SIZE_M": 1,
|
124 |
+
"num_warps": 8,
|
125 |
+
"num_stages": 2,
|
126 |
+
"waves_per_eu": 0
|
127 |
+
},
|
128 |
+
"1536": {
|
129 |
+
"BLOCK_SIZE_M": 256,
|
130 |
+
"BLOCK_SIZE_N": 256,
|
131 |
+
"BLOCK_SIZE_K": 64,
|
132 |
+
"GROUP_SIZE_M": 1,
|
133 |
+
"num_warps": 8,
|
134 |
+
"num_stages": 2,
|
135 |
+
"waves_per_eu": 0
|
136 |
+
},
|
137 |
+
"2048": {
|
138 |
+
"BLOCK_SIZE_M": 128,
|
139 |
+
"BLOCK_SIZE_N": 256,
|
140 |
+
"BLOCK_SIZE_K": 128,
|
141 |
+
"GROUP_SIZE_M": 1,
|
142 |
+
"num_warps": 8,
|
143 |
+
"num_stages": 2,
|
144 |
+
"waves_per_eu": 0
|
145 |
+
},
|
146 |
+
"3072": {
|
147 |
+
"BLOCK_SIZE_M": 128,
|
148 |
+
"BLOCK_SIZE_N": 256,
|
149 |
+
"BLOCK_SIZE_K": 128,
|
150 |
+
"GROUP_SIZE_M": 1,
|
151 |
+
"num_warps": 8,
|
152 |
+
"num_stages": 2,
|
153 |
+
"waves_per_eu": 0
|
154 |
+
},
|
155 |
+
"4096": {
|
156 |
+
"BLOCK_SIZE_M": 256,
|
157 |
+
"BLOCK_SIZE_N": 256,
|
158 |
+
"BLOCK_SIZE_K": 64,
|
159 |
+
"GROUP_SIZE_M": 1,
|
160 |
+
"num_warps": 8,
|
161 |
+
"num_stages": 2,
|
162 |
+
"waves_per_eu": 0
|
163 |
+
}
|
164 |
+
}
|
torch-ext/moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 16,
|
4 |
+
"BLOCK_SIZE_N": 16,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 1,
|
7 |
+
"num_warps": 2,
|
8 |
+
"num_stages": 2,
|
9 |
+
"waves_per_eu": 0,
|
10 |
+
"matrix_instr_nonkdim": 16,
|
11 |
+
"kpack": 1
|
12 |
+
},
|
13 |
+
"2": {
|
14 |
+
"BLOCK_SIZE_M": 16,
|
15 |
+
"BLOCK_SIZE_N": 16,
|
16 |
+
"BLOCK_SIZE_K": 256,
|
17 |
+
"GROUP_SIZE_M": 1,
|
18 |
+
"num_warps": 4,
|
19 |
+
"num_stages": 2,
|
20 |
+
"waves_per_eu": 0,
|
21 |
+
"matrix_instr_nonkdim": 16,
|
22 |
+
"kpack": 2
|
23 |
+
},
|
24 |
+
"4": {
|
25 |
+
"BLOCK_SIZE_M": 16,
|
26 |
+
"BLOCK_SIZE_N": 16,
|
27 |
+
"BLOCK_SIZE_K": 128,
|
28 |
+
"GROUP_SIZE_M": 1,
|
29 |
+
"num_warps": 1,
|
30 |
+
"num_stages": 2,
|
31 |
+
"waves_per_eu": 0,
|
32 |
+
"matrix_instr_nonkdim": 16,
|
33 |
+
"kpack": 2
|
34 |
+
},
|
35 |
+
"8": {
|
36 |
+
"BLOCK_SIZE_M": 16,
|
37 |
+
"BLOCK_SIZE_N": 32,
|
38 |
+
"BLOCK_SIZE_K": 128,
|
39 |
+
"GROUP_SIZE_M": 1,
|
40 |
+
"num_warps": 2,
|
41 |
+
"num_stages": 2,
|
42 |
+
"waves_per_eu": 0,
|
43 |
+
"matrix_instr_nonkdim": 16,
|
44 |
+
"kpack": 2
|
45 |
+
},
|
46 |
+
"16": {
|
47 |
+
"BLOCK_SIZE_M": 16,
|
48 |
+
"BLOCK_SIZE_N": 64,
|
49 |
+
"BLOCK_SIZE_K": 64,
|
50 |
+
"GROUP_SIZE_M": 1,
|
51 |
+
"num_warps": 2,
|
52 |
+
"num_stages": 2,
|
53 |
+
"waves_per_eu": 0,
|
54 |
+
"matrix_instr_nonkdim": 16,
|
55 |
+
"kpack": 2
|
56 |
+
},
|
57 |
+
"24": {
|
58 |
+
"BLOCK_SIZE_M": 16,
|
59 |
+
"BLOCK_SIZE_N": 128,
|
60 |
+
"BLOCK_SIZE_K": 64,
|
61 |
+
"GROUP_SIZE_M": 1,
|
62 |
+
"num_warps": 4,
|
63 |
+
"num_stages": 2,
|
64 |
+
"waves_per_eu": 0,
|
65 |
+
"matrix_instr_nonkdim": 16,
|
66 |
+
"kpack": 2
|
67 |
+
},
|
68 |
+
"32": {
|
69 |
+
"BLOCK_SIZE_M": 16,
|
70 |
+
"BLOCK_SIZE_N": 64,
|
71 |
+
"BLOCK_SIZE_K": 128,
|
72 |
+
"GROUP_SIZE_M": 4,
|
73 |
+
"num_warps": 4,
|
74 |
+
"num_stages": 2,
|
75 |
+
"waves_per_eu": 0,
|
76 |
+
"matrix_instr_nonkdim": 16,
|
77 |
+
"kpack": 2
|
78 |
+
},
|
79 |
+
"48": {
|
80 |
+
"BLOCK_SIZE_M": 16,
|
81 |
+
"BLOCK_SIZE_N": 64,
|
82 |
+
"BLOCK_SIZE_K": 128,
|
83 |
+
"GROUP_SIZE_M": 4,
|
84 |
+
"num_warps": 1,
|
85 |
+
"num_stages": 2,
|
86 |
+
"waves_per_eu": 0,
|
87 |
+
"matrix_instr_nonkdim": 16,
|
88 |
+
"kpack": 1
|
89 |
+
},
|
90 |
+
"64": {
|
91 |
+
"BLOCK_SIZE_M": 32,
|
92 |
+
"BLOCK_SIZE_N": 64,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 4,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 2,
|
97 |
+
"waves_per_eu": 0,
|
98 |
+
"matrix_instr_nonkdim": 16,
|
99 |
+
"kpack": 2
|
100 |
+
},
|
101 |
+
"96": {
|
102 |
+
"BLOCK_SIZE_M": 32,
|
103 |
+
"BLOCK_SIZE_N": 64,
|
104 |
+
"BLOCK_SIZE_K": 256,
|
105 |
+
"GROUP_SIZE_M": 4,
|
106 |
+
"num_warps": 8,
|
107 |
+
"num_stages": 2,
|
108 |
+
"waves_per_eu": 0,
|
109 |
+
"matrix_instr_nonkdim": 16,
|
110 |
+
"kpack": 1
|
111 |
+
},
|
112 |
+
"128": {
|
113 |
+
"BLOCK_SIZE_M": 64,
|
114 |
+
"BLOCK_SIZE_N": 64,
|
115 |
+
"BLOCK_SIZE_K": 128,
|
116 |
+
"GROUP_SIZE_M": 4,
|
117 |
+
"num_warps": 4,
|
118 |
+
"num_stages": 2,
|
119 |
+
"waves_per_eu": 0,
|
120 |
+
"matrix_instr_nonkdim": 32,
|
121 |
+
"kpack": 2
|
122 |
+
},
|
123 |
+
"256": {
|
124 |
+
"BLOCK_SIZE_M": 128,
|
125 |
+
"BLOCK_SIZE_N": 128,
|
126 |
+
"BLOCK_SIZE_K": 64,
|
127 |
+
"GROUP_SIZE_M": 4,
|
128 |
+
"num_warps": 8,
|
129 |
+
"num_stages": 2,
|
130 |
+
"waves_per_eu": 0,
|
131 |
+
"matrix_instr_nonkdim": 16,
|
132 |
+
"kpack": 2
|
133 |
+
},
|
134 |
+
"512": {
|
135 |
+
"BLOCK_SIZE_M": 128,
|
136 |
+
"BLOCK_SIZE_N": 128,
|
137 |
+
"BLOCK_SIZE_K": 64,
|
138 |
+
"GROUP_SIZE_M": 1,
|
139 |
+
"num_warps": 8,
|
140 |
+
"num_stages": 2,
|
141 |
+
"waves_per_eu": 0,
|
142 |
+
"matrix_instr_nonkdim": 16,
|
143 |
+
"kpack": 2
|
144 |
+
},
|
145 |
+
"1024": {
|
146 |
+
"BLOCK_SIZE_M": 128,
|
147 |
+
"BLOCK_SIZE_N": 128,
|
148 |
+
"BLOCK_SIZE_K": 64,
|
149 |
+
"GROUP_SIZE_M": 1,
|
150 |
+
"num_warps": 8,
|
151 |
+
"num_stages": 2,
|
152 |
+
"waves_per_eu": 0,
|
153 |
+
"matrix_instr_nonkdim": 16,
|
154 |
+
"kpack": 2
|
155 |
+
},
|
156 |
+
"1536": {
|
157 |
+
"BLOCK_SIZE_M": 128,
|
158 |
+
"BLOCK_SIZE_N": 128,
|
159 |
+
"BLOCK_SIZE_K": 64,
|
160 |
+
"GROUP_SIZE_M": 1,
|
161 |
+
"num_warps": 8,
|
162 |
+
"num_stages": 2,
|
163 |
+
"waves_per_eu": 0,
|
164 |
+
"matrix_instr_nonkdim": 16,
|
165 |
+
"kpack": 2
|
166 |
+
},
|
167 |
+
"2048": {
|
168 |
+
"BLOCK_SIZE_M": 128,
|
169 |
+
"BLOCK_SIZE_N": 128,
|
170 |
+
"BLOCK_SIZE_K": 64,
|
171 |
+
"GROUP_SIZE_M": 1,
|
172 |
+
"num_warps": 8,
|
173 |
+
"num_stages": 2,
|
174 |
+
"waves_per_eu": 0,
|
175 |
+
"matrix_instr_nonkdim": 16,
|
176 |
+
"kpack": 2
|
177 |
+
},
|
178 |
+
"3072": {
|
179 |
+
"BLOCK_SIZE_M": 128,
|
180 |
+
"BLOCK_SIZE_N": 128,
|
181 |
+
"BLOCK_SIZE_K": 64,
|
182 |
+
"GROUP_SIZE_M": 1,
|
183 |
+
"num_warps": 8,
|
184 |
+
"num_stages": 2,
|
185 |
+
"waves_per_eu": 0,
|
186 |
+
"matrix_instr_nonkdim": 16,
|
187 |
+
"kpack": 2
|
188 |
+
},
|
189 |
+
"4096": {
|
190 |
+
"BLOCK_SIZE_M": 128,
|
191 |
+
"BLOCK_SIZE_N": 128,
|
192 |
+
"BLOCK_SIZE_K": 64,
|
193 |
+
"GROUP_SIZE_M": 1,
|
194 |
+
"num_warps": 8,
|
195 |
+
"num_stages": 2,
|
196 |
+
"waves_per_eu": 0,
|
197 |
+
"matrix_instr_nonkdim": 16,
|
198 |
+
"kpack": 2
|
199 |
+
}
|
200 |
+
}
|
torch-ext/moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"1": {
|
3 |
+
"BLOCK_SIZE_M": 64,
|
4 |
+
"BLOCK_SIZE_N": 128,
|
5 |
+
"BLOCK_SIZE_K": 256,
|
6 |
+
"GROUP_SIZE_M": 64,
|
7 |
+
"num_warps": 4,
|
8 |
+
"num_stages": 4
|
9 |
+
},
|
10 |
+
"2": {
|
11 |
+
"BLOCK_SIZE_M": 64,
|
12 |
+
"BLOCK_SIZE_N": 128,
|
13 |
+
"BLOCK_SIZE_K": 256,
|
14 |
+
"GROUP_SIZE_M": 64,
|
15 |
+
"num_warps": 8,
|
16 |
+
"num_stages": 4
|
17 |
+
},
|
18 |
+
"4": {
|
19 |
+
"BLOCK_SIZE_M": 64,
|
20 |
+
"BLOCK_SIZE_N": 256,
|
21 |
+
"BLOCK_SIZE_K": 128,
|
22 |
+
"GROUP_SIZE_M": 32,
|
23 |
+
"num_warps": 4,
|
24 |
+
"num_stages": 5
|
25 |
+
},
|
26 |
+
"8": {
|
27 |
+
"BLOCK_SIZE_M": 64,
|
28 |
+
"BLOCK_SIZE_N": 128,
|
29 |
+
"BLOCK_SIZE_K": 256,
|
30 |
+
"GROUP_SIZE_M": 64,
|
31 |
+
"num_warps": 4,
|
32 |
+
"num_stages": 4
|
33 |
+
},
|
34 |
+
"16": {
|
35 |
+
"BLOCK_SIZE_M": 64,
|
36 |
+
"BLOCK_SIZE_N": 128,
|
37 |
+
"BLOCK_SIZE_K": 128,
|
38 |
+
"GROUP_SIZE_M": 16,
|
39 |
+
"num_warps": 4,
|
40 |
+
"num_stages": 3
|
41 |
+
},
|
42 |
+
"24": {
|
43 |
+
"BLOCK_SIZE_M": 64,
|
44 |
+
"BLOCK_SIZE_N": 128,
|
45 |
+
"BLOCK_SIZE_K": 128,
|
46 |
+
"GROUP_SIZE_M": 64,
|
47 |
+
"num_warps": 4,
|
48 |
+
"num_stages": 4
|
49 |
+
},
|
50 |
+
"32": {
|
51 |
+
"BLOCK_SIZE_M": 64,
|
52 |
+
"BLOCK_SIZE_N": 256,
|
53 |
+
"BLOCK_SIZE_K": 128,
|
54 |
+
"GROUP_SIZE_M": 32,
|
55 |
+
"num_warps": 4,
|
56 |
+
"num_stages": 5
|
57 |
+
},
|
58 |
+
"48": {
|
59 |
+
"BLOCK_SIZE_M": 64,
|
60 |
+
"BLOCK_SIZE_N": 64,
|
61 |
+
"BLOCK_SIZE_K": 128,
|
62 |
+
"GROUP_SIZE_M": 1,
|
63 |
+
"num_warps": 4,
|
64 |
+
"num_stages": 3
|
65 |
+
},
|
66 |
+
"64": {
|
67 |
+
"BLOCK_SIZE_M": 64,
|
68 |
+
"BLOCK_SIZE_N": 128,
|
69 |
+
"BLOCK_SIZE_K": 128,
|
70 |
+
"GROUP_SIZE_M": 1,
|
71 |
+
"num_warps": 4,
|
72 |
+
"num_stages": 3
|
73 |
+
},
|
74 |
+
"96": {
|
75 |
+
"BLOCK_SIZE_M": 64,
|
76 |
+
"BLOCK_SIZE_N": 64,
|
77 |
+
"BLOCK_SIZE_K": 128,
|
78 |
+
"GROUP_SIZE_M": 1,
|
79 |
+
"num_warps": 4,
|
80 |
+
"num_stages": 3
|
81 |
+
},
|
82 |
+
"128": {
|
83 |
+
"BLOCK_SIZE_M": 64,
|
84 |
+
"BLOCK_SIZE_N": 128,
|
85 |
+
"BLOCK_SIZE_K": 128,
|
86 |
+
"GROUP_SIZE_M": 1,
|
87 |
+
"num_warps": 4,
|
88 |
+
"num_stages": 3
|
89 |
+
},
|
90 |
+
"256": {
|
91 |
+
"BLOCK_SIZE_M": 128,
|
92 |
+
"BLOCK_SIZE_N": 256,
|
93 |
+
"BLOCK_SIZE_K": 128,
|
94 |
+
"GROUP_SIZE_M": 32,
|
95 |
+
"num_warps": 8,
|
96 |
+
"num_stages": 4
|
97 |
+
},
|
98 |
+
"512": {
|
99 |
+
"BLOCK_SIZE_M": 128,
|
100 |
+
"BLOCK_SIZE_N": 256,
|
101 |
+
"BLOCK_SIZE_K": 128,
|
102 |
+
"GROUP_SIZE_M": 32,
|
103 |
+
"num_warps": 8,
|
104 |
+
"num_stages": 4
|
105 |
+
},
|
106 |
+
"1024": {
|
107 |
+
"BLOCK_SIZE_M": 128,
|
108 |
+
"BLOCK_SIZE_N": 256,
|
109 |
+
"BLOCK_SIZE_K": 128,
|
110 |
+
"GROUP_SIZE_M": 16,
|
111 |
+
"num_warps": 8,
|
112 |
+
"num_stages": 4
|
113 |
+
},
|
114 |
+
"1536": {
|
115 |
+
"BLOCK_SIZE_M": 128,
|
116 |
+
"BLOCK_SIZE_N": 256,
|
117 |
+
"BLOCK_SIZE_K": 128,
|
118 |
+
"GROUP_SIZE_M": 16,
|
119 |
+
"num_warps": 8,
|
120 |
+
"num_stages": 4
|
121 |
+
},
|
122 |
+
"2048": {
|
123 |
+
"BLOCK_SIZE_M": 128,
|
124 |
+
"BLOCK_SIZE_N": 256,
|
125 |
+
"BLOCK_SIZE_K": 128,
|
126 |
+
"GROUP_SIZE_M": 32,
|
127 |
+
"num_warps": 8,
|
128 |
+
"num_stages": 3
|
129 |
+
},
|
130 |
+
"3072": {
|
131 |
+
"BLOCK_SIZE_M": 128,
|
132 |
+
"BLOCK_SIZE_N": 256,
|
133 |
+
"BLOCK_SIZE_K": 128,
|
134 |
+
"GROUP_SIZE_M": 32,
|
135 |
+
"num_warps": 8,
|
136 |
+
"num_stages": 3
|
137 |
+
},
|
138 |
+
"4096": {
|
139 |
+
"BLOCK_SIZE_M": 128,
|
140 |
+
"BLOCK_SIZE_N": 256,
|
141 |
+
"BLOCK_SIZE_K": 128,
|
142 |
+
"GROUP_SIZE_M": 16,
|
143 |
+
"num_warps": 8,
|
144 |
+
"num_stages": 3
|
145 |
+
}
|
146 |
+
}
|