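// fpA_intB GEMM PyTorch wrapper: FP16 activations x int8/int4 weight-only
// quantized weights, built on FasterTransformer's CUTLASS kernels with a
// TensorRT-LLM batched-GEMV fast path for small batch sizes.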
#include <torch/all.h>
#include "cub/cub.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
#include "fpA_intB_gemm_wrapper.h"
#include "fpA_intB_gemm.h"
#include "cutlass_preprocessors.h"
#include "cuda_utils.h"
#include "weightOnlyBatchedGemv/enabled.h"
#include "weightOnlyBatchedGemv/kernelLauncher.h"
#include "torch_utils.h"
#include <vector>
namespace ft = fastertransformer;
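
// Upper bound on the CUTLASS split-k workspace, in bytes: one 4-byte partial
// result per block, for the grid implied by the smallest tile config
// (32 x 128) and at most split_k_limit slices in the z dimension.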
int getWorkspaceSize(const int m, const int n, const int k)
{
    // These are the min tile sizes for each config, which would launch the maximum number of blocks
    const int max_grid_m = (m + 31) / 32;
    const int max_grid_n = (n + 127) / 128;
    const int split_k_limit = 7;
    // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
    return max_grid_m * max_grid_n * split_k_limit * 4;
}
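
// Symmetric per-column quantization of `weight` (last axis) to int8 or packed
// int4. Returns {processed_weight, scales}, or {unprocessed_weight,
// processed_weight, scales} when the raw row-major quantized tensor is also
// requested. Runs on the CPU; `weight` must be a contiguous CPU tensor.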
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");
    auto _st = weight.scalar_type();
    TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
    TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
    ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);

    const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
    const size_t num_rows = weight.size(-2);
    const size_t num_cols = weight.size(-1);
    const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
    const size_t bytes_per_out_col = num_cols * bits_in_type / 8;
    const size_t input_mat_size = num_rows * num_cols;
    const size_t quantized_mat_size = num_rows * bytes_per_out_col;

    std::vector<long int> quantized_weight_shape;
    std::vector<long int> scale_shape;
    if (weight.dim() == 2) {
        quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
        scale_shape = {long(num_cols)};
    }
    else if (weight.dim() == 3) {
        quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
        scale_shape = {long(num_experts), long(num_cols)};
    }
    else {
        TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
    }

    torch::Tensor unprocessed_quantized_weight =
        torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
    torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
    torch::Tensor scales =
        torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));

    int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
    int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());

    if (weight.scalar_type() == at::ScalarType::Float)
    {
        ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
                                             unprocessed_quantized_weight_ptr,
                                             reinterpret_cast<float *>(scales.data_ptr()),
                                             reinterpret_cast<const float *>(weight.data_ptr()),
                                             {num_rows, num_cols},
                                             ft_quant_type);
    }
    else if (weight.scalar_type() == at::ScalarType::Half)
    {
        ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
                                           unprocessed_quantized_weight_ptr,
                                           reinterpret_cast<half *>(scales.data_ptr()),
                                           reinterpret_cast<const half *>(weight.data_ptr()),
                                           {num_rows, num_cols},
                                           ft_quant_type);
    }
    else
    {
        TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
    }

    if (return_unprocessed_quantized_tensor)
    {
        return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
    }
    return std::vector<torch::Tensor>{processed_quantized_weight, scales};
}
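
// Reorders a row-major quantized CPU weight tensor into the interleaved
// layout that the fpA_intB CUTLASS kernels expect for the detected SM
// architecture.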
torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
                                      bool is_int4)
{
    // guarantee the weight is a CPU tensor
    CHECK_CPU(origin_weight);

    torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
    int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
    const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
    size_t rows = origin_weight.size(-2);
    size_t cols = origin_weight.size(-1);
    int arch = ft::getSMVersion();
    ft::preprocess_weights(preprocessed_quantized_weight_ptr,
                           row_major_quantized_weight_ptr,
                           rows,
                           cols,
                           is_int4,
                           arch);
    return preprocessed_quantized_weight;
}
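
// output = input (FP16) @ dequant(weight, scale). For small m
// (m <= SMALL_M_FAST_PATH) dispatch to TensorRT-LLM's weight-only batched
// GEMV kernels; otherwise fall back to the CUTLASS fpA_intB GEMM.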
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale)
{
    c10::cuda::CUDAGuard device_guard(input.device());
    // TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());

    // Flatten any leading batch dimensions into m; k is the reduction dim.
    const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
    const int k = input.size(-1);
    const int n = weight.size(-1);
    auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
    torch::Tensor output = input.dim() == 2
                               ? torch::empty({m, n}, options)
                               : torch::empty({input.size(0), input.size(1), n}, options);

    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());

    // const int max_size = std::max(n, k);
    // size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
    // void *ptr = nullptr;
    // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;

    // Skinny activations (small m) are better served by the batched GEMV kernels.
    const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
    // const bool use_cuda_kernel = false;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (use_cuda_kernel)
    {
        tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
        tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
        tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr,
                                                       reinterpret_cast<const uint8_t *>(scale.data_ptr()),
                                                       nullptr,
                                                       reinterpret_cast<half *>(input.data_ptr()),
                                                       nullptr,
                                                       nullptr,
                                                       reinterpret_cast<half *>(output.data_ptr()),
                                                       m, n, k, 0,
                                                       weight_only_quant_type,
                                                       tensorrt_llm::kernels::WeightOnlyType::PerChannel,
                                                       tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity,
                                                       weight_only_act_type};
        tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
    }
    else
    {
        ft::gemm_fp16_int(input_ptr,
                          weight_ptr,
                          scale_ptr,
                          output_ptr,
                          m, n, k,
                          nullptr, // no workspace
                          0,
                          stream);
    }
    return output;
}
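
// Variant that writes into a caller-provided output tensor with explicit
// m/n/k and always takes the CUTLASS fpA_intB GEMM path.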
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k)
{
    c10::cuda::CUDAGuard device_guard(input.device());
    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    ft::gemm_fp16_int(input_ptr,
                      weight_ptr,
                      scale_ptr,
                      output_ptr,
                      m, n, k,
                      nullptr, // no workspace
                      0,
                      stream);
    return output;
}
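
// Example flow (illustrative sketch, not part of the exported API; variable
// names here are hypothetical): quantize and preprocess once on the CPU, then
// run the GEMM on the GPU.
//
//   auto packed  = symmetric_quantize_last_axis_of_tensor(w_fp16_cpu, torch::kInt8, false);
//   auto qweight = packed[0].to(torch::kCUDA); // preprocessed int8 weight
//   auto scales  = packed[1].to(torch::kCUDA); // per-column FP16 scales
//   auto out     = w8_a16_gemm_forward_cuda(x_fp16_cuda, qweight, scales);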