#include <torch/all.h>
#include "cub/cub.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
#include "fpA_intB_gemm_wrapper.h"
#include "fpA_intB_gemm.h"
#include "cutlass_preprocessors.h"
#include "cuda_utils.h"
#include "weightOnlyBatchedGemv/enabled.h"
#include "weightOnlyBatchedGemv/kernelLauncher.h"
#include "torch_utils.h"

#include <vector>

namespace ft = fastertransformer;

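// Upper bound, in bytes, on the split-k reduction workspace the CUTLASS
// fpA_intB kernels may need for an m x n x k GEMM.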
int getWorkspaceSize(const int m, const int n, const int k)
{
    // Assume the smallest tile sizes among the kernel configs, which launch the maximum number of blocks.
    const int max_grid_m = (m + 31) / 32;
    const int max_grid_n = (n + 127) / 128;
    const int split_k_limit = 7;
    // Each block needs at most 4 bytes, and split-k launches up to split_k_limit slices in the z dimension.
    return max_grid_m * max_grid_n * split_k_limit * 4;
}

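// Symmetrically quantizes the last axis of a 2-D [rows, cols] or 3-D
// [experts, rows, cols] CPU weight tensor to int8 or packed int4, and
// preprocesses the result for the fpA_intB GEMM. Returns {processed_weight,
// scales}, with the raw row-major quantized weight prepended when
// return_unprocessed_quantized_tensor is set.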
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");

    auto _st = weight.scalar_type();
    TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
    TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
    ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);

    const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
    const size_t num_rows    = weight.size(-2);
    const size_t num_cols    = weight.size(-1);

    const size_t bits_in_type      = ft::get_bits_in_quant_type(ft_quant_type);
    const size_t bytes_per_out_col = num_cols * bits_in_type / 8;

    std::vector<long int> quantized_weight_shape;
    std::vector<long int> scale_shape;
    if (weight.dim() == 2) {
        quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
        scale_shape            = {long(num_cols)};
    }
    else if (weight.dim() == 3) {
        quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
        scale_shape            = {long(num_experts), long(num_cols)};
    }
    else {
        TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
    }

    torch::Tensor unprocessed_quantized_weight =
        torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));

    torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);

    torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));

    int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
    int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());

    // Pass the full logical shape; ft::symmetric_quantize accepts both 2-D
    // [rows, cols] and 3-D [experts, rows, cols] shapes, so the multi-expert
    // case quantizes every expert rather than just the first.
    const std::vector<size_t> weight_shape = weight.dim() == 2 ?
        std::vector<size_t>{num_rows, num_cols} :
        std::vector<size_t>{num_experts, num_rows, num_cols};

    if (weight.scalar_type() == at::ScalarType::Float)
    {
        ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
                                             unprocessed_quantized_weight_ptr,
                                             reinterpret_cast<float *>(scales.data_ptr()),
                                             reinterpret_cast<const float *>(weight.data_ptr()),
                                             weight_shape,
                                             ft_quant_type);
    }
    else if (weight.scalar_type() == at::ScalarType::Half)
    {
        ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
                                           unprocessed_quantized_weight_ptr,
                                           reinterpret_cast<half *>(scales.data_ptr()),
                                           reinterpret_cast<const half *>(weight.data_ptr()),
                                           weight_shape,
                                           ft_quant_type);
    }
    else
    {
        TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
    }

    if (return_unprocessed_quantized_tensor)
    {
        return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
    }

    return std::vector<torch::Tensor>{processed_quantized_weight, scales};
}

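// Permutes an already-quantized row-major weight (int8, or int4 packed two
// per byte) into the interleaved layout the fpA_intB CUTLASS kernels expect
// on the current GPU architecture.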
torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
                                      bool is_int4)
{
    // Preprocessing runs on the host, so the weight must be a contiguous CPU tensor.
    CHECK_CPU(origin_weight);
    CHECK_CONTIGUOUS(origin_weight);

    torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
    int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
    const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
    size_t rows = origin_weight.size(-2);
    size_t cols = origin_weight.size(-1);
    int arch = ft::getSMVersion();
    ft::preprocess_weights(preprocessed_quantized_weight_ptr,
                           row_major_quantized_weight_ptr,
                           rows,
                           cols,
                           is_int4,
                           arch);
    return preprocessed_quantized_weight;
}

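// FP16-activation, int8-weight GEMM with per-channel scales:
// output = input [m, k] x dequant(weight [k, n]). Small batches
// (m <= SMALL_M_FAST_PATH) use the batched-GEMV CUDA kernel; larger ones
// use the CUTLASS fpA_intB GEMM.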
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale)
{
    c10::cuda::CUDAGuard device_guard(input.device());
    TORCH_CHECK(input.dim() == 2 || input.dim() == 3, "Invalid input dim: ", input.dim());
    const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
    const int k = input.size(-1);
    const int n = weight.size(-1);
    auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
    torch::Tensor output = input.dim() == 2 ? torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options);
    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    // No auxiliary workspace is allocated: the CUTLASS path below is invoked
    // with a null workspace, and the small-m GEMV path needs none.
    const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (use_cuda_kernel) {
        // Batched-GEMV fast path for small m.
        tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
        tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
        tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast<const uint8_t *>(scale.data_ptr()), nullptr,
            reinterpret_cast<half *>(input.data_ptr()), nullptr, nullptr, reinterpret_cast<half *>(output.data_ptr()), m, n, k, 0, weight_only_quant_type,
            tensorrt_llm::kernels::WeightOnlyType::PerChannel,
            tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
        tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
    }
    else {
        // CUTLASS mixed-dtype GEMM path.
        ft::gemm_fp16_int(
            input_ptr,
            weight_ptr,
            scale_ptr,
            output_ptr,
            m, n, k,
            nullptr,   // no workspace
            0,         // workspace size
            stream);
    }
    return output;
}


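// Variant of w8_a16_gemm_forward_cuda that writes into a caller-allocated
// output tensor with caller-supplied m, n, k and always takes the CUTLASS path.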
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k)
{
    c10::cuda::CUDAGuard device_guard(input.device());

    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    ft::gemm_fp16_int(
        input_ptr,
        weight_ptr,
        scale_ptr,
        output_ptr,
        m, n, k,
        nullptr,   // no workspace
        0,         // workspace size
        stream);
    return output;
}