#pragma once
// // Set default if not already defined
// #ifndef GROUPED_GEMM_CUTLASS
// #define GROUPED_GEMM_CUTLASS 0
// #endif
// #include <torch/extension.h>
#include <torch/torch.h>
namespace grouped_gemm {

/// @brief Declares the GroupedGemm entry point; no definition is provided in
///        this header.
///
/// NOTE(review): everything below about meaning is inferred from identifier
/// names only, since the implementation is not visible here. Confirm against
/// the definition before relying on it.
///
/// @param a            Tensor operand (presumably the left-hand matrices; TODO confirm).
/// @param b            Tensor operand (presumably the right-hand matrices; TODO confirm).
/// @param c            Tensor operand (presumably receives the result; TODO confirm).
/// @param batch_sizes  Tensor operand (presumably per-group sizes; TODO confirm
///                     expected dtype and device).
/// @param trans_a      Boolean flag (presumably "treat a as transposed"; TODO confirm).
/// @param trans_b      Boolean flag (presumably "treat b as transposed"; TODO confirm).
///
/// All tensor arguments are taken by value, and the function returns nothing.
void GroupedGemm(
    torch::Tensor a,
    torch::Tensor b,
    torch::Tensor c,
    torch::Tensor batch_sizes,
    bool trans_a,
    bool trans_b);

}  // namespace grouped_gemm
 | 
