#include <torch/all.h>

#include "cub/cub.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>  // for at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>

#include "fpA_intB_gemm_wrapper.h"
#include "fpA_intB_gemm.h"
#include "cutlass_preprocessors.h"
#include "cuda_utils.h"
#include "weightOnlyBatchedGemv/enabled.h"
#include "weightOnlyBatchedGemv/kernelLauncher.h"
#include "torch_utils.h"

#include <vector>

namespace ft = fastertransformer;
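
// Returns an upper bound, in bytes, on the split-k workspace needed by the
// CUTLASS fpA_intB GEMM below. As a sanity check of the arithmetic (with
// illustrative values, not taken from this repo): m = 512, n = 4096 gives
// max_grid_m = 16, max_grid_n = 32, and 16 * 32 * 7 * 4 = 14336 bytes.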
int getWorkspaceSize(const int m, const int n, const int k)
{
    // These are the min tile sizes for each config, which would launch the
    // maximum number of blocks.
    const int max_grid_m = (m + 31) / 32;
    const int max_grid_n = (n + 127) / 128;
    const int split_k_limit = 7;
    // We need 4 bytes per block in the worst case, and up to split_k_limit
    // slices are launched in the z dimension.
    return max_grid_m * max_grid_n * split_k_limit * 4;
}
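
// Caller-side sketch (illustrative, not part of this file): the workspace
// would typically be allocated as a raw byte tensor on the GPU and handed to
// ft::gemm_fp16_int, e.g.
//
//   int64_t ws_bytes = getWorkspaceSize(m, n, k);
//   torch::Tensor ws = torch::empty({ws_bytes},
//                                   torch::dtype(torch::kUInt8).device(torch::kCUDA));
//
// The wrappers below pass a null workspace instead, which presumably restricts
// the runner to configurations that do not need split-k.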

std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");

    auto _st = weight.scalar_type();
    TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
    TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2,
                "Must be int4 or int8 quantization");
    ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);

    // A 3-D weight is interpreted as [num_experts, num_rows, num_cols].
    const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
    const size_t num_rows    = weight.size(-2);
    const size_t num_cols    = weight.size(-1);

    const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
    // int4 packs two values per output byte.
    const size_t bytes_per_out_col = num_cols * bits_in_type / 8;

    std::vector<long int> quantized_weight_shape;
    std::vector<long int> scale_shape;
    if (weight.dim() == 2) {
        quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
        scale_shape            = {long(num_cols)};
    }
    else {  // dim == 3, validated above
        quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
        scale_shape            = {long(num_experts), long(num_cols)};
    }

    torch::Tensor unprocessed_quantized_weight =
        torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));

    torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);

    torch::Tensor scales =
        torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));

    int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
    int8_t *processed_quantized_weight_ptr   = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());

    // Pass the full (experts, rows, cols) shape so that 3-D expert weights are
    // quantized in their entirety, not just the first expert.
    if (weight.scalar_type() == at::ScalarType::Float) {
        ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
                                             unprocessed_quantized_weight_ptr,
                                             reinterpret_cast<float *>(scales.data_ptr()),
                                             reinterpret_cast<const float *>(weight.data_ptr()),
                                             {num_experts, num_rows, num_cols},
                                             ft_quant_type);
    }
    else if (weight.scalar_type() == at::ScalarType::Half) {
        ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
                                           unprocessed_quantized_weight_ptr,
                                           reinterpret_cast<half *>(scales.data_ptr()),
                                           reinterpret_cast<const half *>(weight.data_ptr()),
                                           {num_experts, num_rows, num_cols},
                                           ft_quant_type);
    }
    else {
        TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
    }

    if (return_unprocessed_quantized_tensor) {
        return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
    }

    return std::vector<torch::Tensor>{processed_quantized_weight, scales};
}
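
// Usage sketch (hypothetical shapes, not part of this file): int8-quantize a
// 2-D fp16 weight of shape [k, n]; the scales come back with one entry per
// output column.
//
//   torch::Tensor w  = torch::randn({4096, 11008}, torch::kFloat16);
//   auto out         = symmetric_quantize_last_axis_of_tensor(w, torch::kInt8, false);
//   torch::Tensor qw = out[0];  // int8, [4096, 11008], already cutlass-preprocessed
//   torch::Tensor sc = out[1];  // fp16, [11008]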

// Rearranges an int8/int4 row-major quantized weight into the interleaved
// layout expected by the CUTLASS fpA_intB kernels for the current SM version.
torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
                                      bool is_int4)
{
    // Despite the name, the layout transform itself runs on the CPU.
    CHECK_CPU(origin_weight);
    CHECK_CONTIGUOUS(origin_weight);

    torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
    int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
    const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
    size_t rows = origin_weight.size(-2);
    size_t cols = origin_weight.size(-1);
    int arch = ft::getSMVersion();
    ft::preprocess_weights(preprocessed_quantized_weight_ptr,
                           row_major_quantized_weight_ptr,
                           rows,
                           cols,
                           is_int4,
                           arch);
    return preprocessed_quantized_weight;
}
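
// fp16-activation x int8-weight GEMM with per-channel scales. Small batches
// (m <= SMALL_M_FAST_PATH, assumed to be defined in fpA_intB_gemm_wrapper.h)
// take the TensorRT-LLM batched-GEMV fast path; everything else goes through
// the CUTLASS fpA_intB GEMM.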
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale)
{
    c10::cuda::CUDAGuard device_guard(input.device());

    // Flatten a 3-D [batch, seq, k] activation into a 2-D [m, k] GEMM problem.
    const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
    const int k = input.size(-1);
    const int n = weight.size(-1);
    auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
    torch::Tensor output = input.dim() == 2
                               ? torch::empty({m, n}, options)
                               : torch::empty({input.size(0), input.size(1), n}, options);
    const ft::half *input_ptr  = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t  *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr  = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half       *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());

    // For small m a batched GEMV kernel is faster than the CUTLASS GEMM.
    const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (use_cuda_kernel) {
        tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type =
            tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
        tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type =
            tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
        tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr,
                                                       reinterpret_cast<const uint8_t *>(scale.data_ptr()),
                                                       nullptr,  // no zero points (symmetric quantization)
                                                       reinterpret_cast<half *>(input.data_ptr()),
                                                       nullptr,  // no activation scale
                                                       nullptr,  // no bias
                                                       reinterpret_cast<half *>(output.data_ptr()),
                                                       m,
                                                       n,
                                                       k,
                                                       0,  // group size (unused for per-channel scales)
                                                       weight_only_quant_type,
                                                       tensorrt_llm::kernels::WeightOnlyType::PerChannel,
                                                       tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity,
                                                       weight_only_act_type};
        tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
    }
    else {
        ft::gemm_fp16_int(input_ptr,
                          weight_ptr,
                          scale_ptr,
                          output_ptr,
                          m, n, k,
                          nullptr,  // no workspace provided
                          0,
                          stream);
    }
    return output;
}
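
// Variant of the GEMM above that writes into a caller-provided output tensor
// and takes explicit problem dimensions; it always uses the CUTLASS path.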
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k)
{
    c10::cuda::CUDAGuard device_guard(input.device());

    const ft::half *input_ptr  = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t  *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr  = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half       *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    ft::gemm_fp16_int(input_ptr,
                      weight_ptr,
                      scale_ptr,
                      output_ptr,
                      m, n, k,
                      nullptr,  // no workspace provided
                      0,
                      stream);
    return output;
}
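
// End-to-end sketch (illustrative shapes, not part of this file): quantize on
// the CPU, upload, then run the weight-only GEMM.
//
//   auto qs = symmetric_quantize_last_axis_of_tensor(w_fp16_cpu, torch::kInt8, false);
//   torch::Tensor qweight = qs[0].cuda();  // preprocessed int8 weight
//   torch::Tensor scales  = qs[1].cuda();  // per-output-channel fp16 scales
//   torch::Tensor x = torch::randn({8, 4096},
//                                  torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   torch::Tensor y = w8_a16_gemm_forward_cuda(x, qweight, scales);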