File size: 8,336 Bytes
b4cad21
 
5c6fb68
b4cad21
 
 
3c8bb73
 
b4cad21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8bb73
 
5c6fb68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8bb73
 
5c6fb68
 
 
 
 
c31b5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165b25c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8bb73
165b25c
 
3c8bb73
 
165b25c
 
 
 
 
c31b5ce
165b25c
 
c31b5ce
 
 
 
 
b4cad21
 
3c8bb73
 
b4cad21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include <torch/library.h>

#include "core/registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Declares the schemas of this extension's quantization ops and, for most
  // of them, binds the CUDA kernel right away. In the schema strings `Tensor!`
  // marks an argument the op writes in place and `Tensor?` an optional one.
  // The Marlin-family ops at the end are only declared (ops.def) here; their
  // kernels are attached by the TORCH_LIBRARY_IMPL_EXPAND blocks further down.

#ifndef USE_ROCM
  // ---- CUTLASS GEMMs (non-ROCm builds only) --------------------------------

  // w8a8 GEMM with symmetric per-tensor or per-row/column quantization and an
  // optional bias; the product is written into `out`.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // w8a8 GEMM with asymmetric per-tensor or per-row/column quantization
  // (takes `azp_adj` plus an optional `azp`) and an optional bias; the
  // product is written into `out`.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor azp_adj,"
      "                  Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

  // Capability query: whether CUTLASS scaled_mm supports FP8 on a CUDA device
  // of the given compute capability. It has no tensor arguments, so it is
  // registered without a dispatch key.
  ops.def(
      "cutlass_scaled_mm_supports_fp8(int cuda_device_capability) "
      "-> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
#endif  // USE_ROCM

  // ---- FP8 quantization ----------------------------------------------------

  // Quantize `input` to FP8 into `result` using a caller-provided `scale`.
  ops.def(
      "static_scaled_fp8_quant(Tensor! result, Tensor input, "
      "Tensor scale) -> ()");
  ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

  // Quantize to FP8 with a dynamically computed per-tensor scale; that scale
  // is written back into `scale`.
  ops.def(
      "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, "
      "Tensor! scale) -> ()");
  ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

  // Quantize to FP8 with dynamically computed per-token scales written into
  // `scale`. `scale_ub` is optional (by its name, presumably an upper bound
  // on the computed scales -- confirm in the kernel).
  ops.def(
      "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
      "Tensor! scale, Tensor? scale_ub) -> ()");
  ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
           &dynamic_per_token_scaled_fp8_quant);

  // ---- INT8 quantization ---------------------------------------------------

  // Quantize to INT8 using a caller-provided `scale` and optional `azp`.
  ops.def(
      "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

  // Quantize to INT8 while computing the scale, writing it into `scale`;
  // `azp` is optional and, when passed, is also written in place (`Tensor!?`).
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
           &dynamic_scaled_int8_quant);

#ifndef USE_ROCM
  // ---- Marlin family (schemas only; kernels bound per dispatch key below) --

  // fp8_marlin: GEMM for FP8 weight-only quantization.
  ops.def(
      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor");

  // Repack AWQ-format quantized weights for awq_marlin.
  ops.def(
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits) -> Tensor");

  // gptq_marlin: quantized GEMM for GPTQ weights.
  // NOTE(review): `workspace` is a plain `Tensor` here (and in
  // gptq_marlin_24_gemm), while marlin_gemm, fp8_marlin_gemm and
  // marlin_qqq_gemm annotate it `Tensor!`; confirm whether this kernel
  // writes to it and align the annotation if so.
  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
      "int b_q_type, SymInt size_m, SymInt size_n, SymInt size_k, "
      "bool is_k_full, bool has_zp, bool use_fp32_reduce, "
      "bool is_zp_float) -> Tensor");

  // Repack GPTQ-format quantized weights (with permutation `perm`) for
  // gptq_marlin.
  ops.def(
      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");

  // Dense Marlin quantized GEMM for GPTQ weights.
  ops.def(
      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor");

  // Sparse Marlin_24 quantized GEMM for GPTQ weights.
  ops.def(
      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
      "Tensor b_scales, Tensor workspace, int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");

  // Marlin GEMM for QQQ quantization.
  ops.def(
      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
      "Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor! workspace, "
      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
#endif  // USE_ROCM
}

#ifndef USE_ROCM

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
  // CUDA-dispatch kernels for the Marlin-family ops whose schemas are only
  // declared (ops.def) in the TORCH_LIBRARY_EXPAND block. Each registration is
  // independent; they are listed in the same order as the declarations.
  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
  ops.impl("awq_marlin_repack", &awq_marlin_repack);
  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
  ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
  ops.impl("marlin_gemm", &marlin_gemm);
  ops.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  ops.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
  // Meta-dispatch implementations for the two repack ops only.
  // NOTE(review): the *_meta functions presumably compute just the output
  // shape so the ops can run on meta/fake tensors during tracing -- confirm
  // in their definitions.
  ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
  ops.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}

#endif

// Expands the REGISTER_EXTENSION macro from "core/registration.h" for this
// extension's name. NOTE(review): presumably this emits the Python module
// init entry point so the built shared library can be imported from Python;
// confirm against core/registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)