import torch
from torch import nn
import cupy as cp
import math
from os import path
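# The Quantizer below appears to implement product quantization for the
# attention KV cache: each row vector is split into `nsq` sub-vectors of
# dimension `d`, and each sub-vector is replaced by the index of its nearest
# centroid in a per-subspace codebook with 2**b entries. At construction time
# the CUDA sources next to this file are templated (the __NSQ__ / __B__ /
# __D__ / __HEAD_DIM__ / __QPK__ / __ROPE_THETA__ placeholders are replaced
# with concrete values) and JIT-compiled through CuPy's NVRTC backend.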
class Quantizer(nn.Module):
    def __init__(self, config, codebook):
        super().__init__()
        self.nsq, nc, self.d = codebook.shape
        self.b = int(math.log2(nc))
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        qpk = config.num_attention_heads // config.num_key_value_heads
        self.window_length = getattr(config, 'window_length', 32)
        self.register_buffer('codebook', codebook)
        with open(path.join(path.dirname(__file__), "quantize.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d))
        self._quantize = cp.RawKernel(
            kernel_code,
            'quantize',
            backend="nvrtc"
        )
        with open(path.join(path.dirname(__file__), "dequantize.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d))
        self._dequantize = cp.RawKernel(
            kernel_code,
            'dequantize',
            backend="nvrtc"
        )
        with open(path.join(path.dirname(__file__), "dequantize_rope.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d)).replace('__HEAD_DIM__', str(head_dim))
        self._dequantize_rope = cp.RawKernel(
            kernel_code,
            'dequantize_rope',
            backend="nvrtc"
        )
        with open(path.join(path.dirname(__file__), "fused_rope_mult.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d)).replace('__HEAD_DIM__', str(head_dim))
        self._fused_rope_mult = cp.RawKernel(
            kernel_code,
            'fused_rope_mult',
            backend="nvrtc"
        )
        with open(path.join(path.dirname(__file__), "fused_rope_pos_mult_mqa.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d)).replace('__HEAD_DIM__', str(head_dim)).replace('__QPK__', str(qpk)).replace('__ROPE_THETA__', str(config.rope_theta))
        self._fused_rope_pos_mult = cp.RawKernel(
            kernel_code,
            'fused_rope_pos_mult',
            backend="nvrtc"
        )
        with open(path.join(path.dirname(__file__), "fused_mult_len.cu"), "r") as f:
            kernel_code = f.read().replace('__NSQ__', str(self.nsq)).replace('__B__', str(self.b)).replace('__D__', str(self.d)).replace('__HEAD_DIM__', str(head_dim)).replace('__QPK__', str(qpk))
        self._fused_mult = cp.RawKernel(
            kernel_code,
            'fused_mult',
            backend="nvrtc"
        )
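    # All launches below follow the same convention: one CUDA block per
    # subquantizer (grid = (nsq,)), up to 1024 threads per block, and a
    # shared-memory allocation of (2 ** b) * d * 2 bytes, which matches one
    # subquantizer's codebook slice in float16. Raw device pointers are passed
    # via data_ptr(), so all tensors must live on the CUDA device that the
    # kernels were compiled for.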
    def quantize(self, x):
        n = x.numel() // x.shape[-1]
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (1024, )
        self._quantize(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            x.data_ptr(),
            codes.data_ptr(),
            n
        ])
        return codes

    def dequantize(self, codes):
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (1024, )
        self._dequantize(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            n
        ])
        return x

    def dequantize_rope(self, codes):
        _, batch_size, seq_len = codes.shape
        n = batch_size * seq_len
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (1024, )
        self._dequantize_rope(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            batch_size, seq_len
        ])
        return x
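    # The fused_* methods presumably avoid materializing the dequantized
    # keys/values: judging by their signatures, the kernels decode the codes
    # on the fly (applying RoPE where the name says so) and directly produce
    # either attention scores against `queries` or, in fused_mult, the
    # weighted sum over values given attention `weights`.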
    def fused_rope_mult(self, codes, queries):
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (1024, )
        self._fused_rope_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len
        ])
        return out

    def fused_rope_pos_mult(self, codes, queries, position_ids):
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        position_offsets = position_ids[:, -1] - k_len + 1
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float32, device=codes.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (1024, )
        self._fused_rope_pos_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            position_offsets.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len
        ])
        return out

    def fused_mult(self, codes, weights, skip_last=0):
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(batch_size, n_heads, q_len, self.head_dim, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq, )
        threads_per_block = (min(1024, batch_size), )
        self._fused_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            weights.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len, k_len - skip_last
        ])
        return out
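
# --- Minimal smoke-test sketch (assumptions, not part of the module API) ---
# Assumes a CUDA device, the *.cu kernel sources next to this file, and that
# `quantize` expects float16 rows of width nsq * d; the config values and the
# random codebook below are hypothetical and only meant to exercise the code.
if __name__ == "__main__" and torch.cuda.is_available():
    class _DummyConfig:
        hidden_size = 2048
        num_attention_heads = 16
        num_key_value_heads = 16
        rope_theta = 10000.0

    # hypothetical codebook: 32 subquantizers, 256 centroids each, dimension 4
    codebook = torch.randn(32, 256, 4, dtype=torch.float16, device="cuda")
    quantizer = Quantizer(_DummyConfig(), codebook)

    x = torch.randn(8, 32 * 4, dtype=torch.float16, device="cuda")
    codes = quantizer.quantize(x)        # expected: (nsq, 8) uint8 codes
    x_hat = quantizer.dequantize(codes)  # expected: (8, nsq * d) float16
    print(codes.shape, x_hat.shape)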