# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
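"""Generate the kernel_*.cu files in this directory: one file of explicit
Marlin kernel template instantiations per (weight type, activation dtype)
combination."""
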
import glob
import itertools
import os

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

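# Jinja2 template for a single explicit instantiation of the Marlin kernel
# template; every rendered line forces one configuration to be compiled.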
TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# int8 with zero point (vllm::kU8) is also supported, but we skip it here
# to keep the wheel size down.
SCALAR_TYPES = [
    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
    "vllm::kFE2M1f"
]
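# each tuple is (thread_k, thread_n, threads); thread_k and thread_n are
# divided by 16 below to give the per-block tile counts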
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]

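# m_blocks == 0.5 selects the m_block_size_8 variant (8-row blocks); when
# rendering, thread_m_blocks is clamped to at least 1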
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   =  0 : act-order case
#   = -1 : channelwise quantization
#   >  0 : group_size = 16 * group_blocks
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        os.remove(filename)


def generate_new_kernels():
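    # one generated .cu file per (weight quantization type, activation dtype)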
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # the act-order case is only supported for gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                    "vllm::kU4B8", "vllm::kU8B128"
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks <= 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue
            # nvfp4 only supports group_size == 16
            if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
                continue
            # other quantization methods don't support group_size == 16
            if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

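            # CUDA-side scalar type corresponding to the dtype tag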
            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            is_zp_float_list = [False]
            # HQQ (is_zp_float = true) kernels are only generated for 4-bit
            # quantization with fp16 and group_blocks == 4
            if (dtype == "fp16" and scalar_type == "vllm::kU4"
                    and group_blocks == 4):
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                template_str = jinja2.Template(TEMPLATE).render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
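        # strip the leading "vllm::" from the scalar type for the filename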
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)


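# Regenerate in place: remove previously generated kernels first so that
# configurations dropped from the tables above don't leave stale .cu files.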
if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()