|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"

#include "utils/activation_types.h"

#include <cuda_runtime_api.h>
#include <string>
|
|
|
namespace fastertransformer { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<typename T, typename WeightType> |
|
class CutlassFpAIntBGemmRunner { |
|
public: |
|
CutlassFpAIntBGemmRunner(); |
|
~CutlassFpAIntBGemmRunner(); |
|
|
|
void gemm(const T* A, |
|
const WeightType* B, |
|
const T* weight_scales, |
|
T* C, |
|
int m, |
|
int n, |
|
int k, |
|
char* workspace_ptr, |
|
const size_t workspace_bytes, |
|
cudaStream_t stream); |
|
|
|
void gemm_bias_act(const T* A, |
|
const WeightType* B, |
|
const T* weight_scales, |
|
const T* biases, |
|
T* C, |
|
int m, |
|
int n, |
|
int k, |
|
int bias_stride, |
|
ActivationType activation_type, |
|
char* workspace_ptr, |
|
const size_t workspace_bytes, |
|
cudaStream_t stream); |
|
|
|
void gemm_bias_act_residual(const T *A, const WeightType *B, |
|
const T *weight_scales, const T *biases, |
|
const T *residual, T *C, int m, int n, int k, |
|
const std::string& activation, const std::string& binary_op, |
|
const std::string& unary_op, |
|
char *workspace_ptr, |
|
const size_t workspace_bytes, |
|
cudaStream_t stream); |
|
|
|
|
|
int getWorkspaceSize(const int m, const int n, const int k); |
|
|
|
private: |
|
template<typename EpilogueTag> |
|
void dispatch_to_arch(const T* A, |
|
const WeightType* B, |
|
const T* weight_scales, |
|
const T* biases, |
|
T* C, |
|
int m, |
|
int n, |
|
int k, |
|
int bias_stride, |
|
CutlassGemmConfig gemm_config, |
|
char* workspace_ptr, |
|
const size_t workspace_bytes, |
|
cudaStream_t stream, |
|
int* occupancy = nullptr); |
|
|
|
template<typename EpilogueTag> |
|
void run_gemm(const T* A, |
|
const WeightType* B, |
|
const T* weight_scales, |
|
const T* biases, |
|
T* C, |
|
int m, |
|
int n, |
|
int k, |
|
int bias_stride, |
|
char* workspace_ptr, |
|
const size_t workspace_bytes, |
|
cudaStream_t stream); |
|
|
|
private: |
|
static constexpr int split_k_limit = 7; |
|
|
|
int sm_; |
|
int multi_processor_count_; |
|
}; |
|
|
|
} |
|
|