Aston-xMAD's picture
init commit
b37c16f verified
import torch
from torch import nn
import cupy as cp
import math
from os import path
class Quantizer(nn.Module):
    """Codebook quantizer whose encode/decode/attention ops run as CuPy CUDA kernels.

    The codebook has shape ``(nsq, nc, d)``: ``nsq`` sub-codebooks, each holding
    ``nc == 2 ** b`` entries of width ``d``. Codes are stored as one ``uint8`` per
    sub-codebook per vector, so ``nc`` must be a power of two no larger than 256.

    Kernel sources are read from ``.cu`` files next to this module and specialised
    by plain-text substitution of the ``__NSQ__``, ``__B__``, ``__D__``,
    ``__HEAD_DIM__``, ``__QPK__`` and ``__ROPE_THETA__`` placeholders.
    """

    def __init__(self, config, codebook: torch.Tensor):
        """Validate the codebook and build the specialised CUDA kernels.

        Args:
            config: model config; reads ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``rope_theta`` and optionally
                ``window_length`` (defaults to 32 when absent).
            codebook: tensor of shape ``(nsq, nc, d)``, registered as the
                ``codebook`` buffer. NOTE(review): the shared-memory size assumes
                2-byte elements (presumably float16) — confirm.

        Raises:
            ValueError: if ``nc`` is not a power of two in ``[1, 256]``.
        """
        super().__init__()
        self.nsq, nc, self.d = codebook.shape
        # Codes are uint8 (see ``quantize``) and every launch sizes shared memory
        # from 2 ** b rows, so any other ``nc`` would be silently truncated or
        # unaddressable.
        if nc < 1 or nc > 256 or nc & (nc - 1):
            raise ValueError(
                "codebook must have a power-of-two number of entries between 1 and "
                f"256 (uint8 codes), got {nc}"
            )
        self.b = nc.bit_length() - 1
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # Number of query heads per key/value head.
        qpk = config.num_attention_heads // config.num_key_value_heads
        self.window_length = getattr(config, 'window_length', 32)
        # Kernels read the codebook through a raw pointer with no strides.
        self.register_buffer('codebook', codebook.contiguous())

        base = {'__NSQ__': self.nsq, '__B__': self.b, '__D__': self.d}
        with_head = {**base, '__HEAD_DIM__': head_dim}
        self._quantize = self._load_kernel('quantize.cu', 'quantize', base)
        self._dequantize = self._load_kernel('dequantize.cu', 'dequantize', base)
        self._dequantize_rope = self._load_kernel(
            'dequantize_rope.cu', 'dequantize_rope', with_head
        )
        self._fused_rope_mult = self._load_kernel(
            'fused_rope_mult.cu', 'fused_rope_mult', with_head
        )
        self._fused_rope_pos_mult = self._load_kernel(
            'fused_rope_pos_mult_mqa.cu',
            'fused_rope_pos_mult',
            {**with_head, '__QPK__': qpk, '__ROPE_THETA__': config.rope_theta},
        )
        self._fused_mult = self._load_kernel(
            'fused_mult_len.cu', 'fused_mult', {**with_head, '__QPK__': qpk}
        )

    @staticmethod
    def _load_kernel(filename, kernel_name, replacements):
        """Read ``filename`` beside this module, substitute placeholders, wrap as a RawKernel.

        Placeholders are replaced in the insertion order of ``replacements``;
        values are converted with ``str``.
        """
        with open(path.join(path.dirname(__file__), filename), 'r', encoding='utf-8') as f:
            kernel_code = f.read()
        for placeholder, value in replacements.items():
            kernel_code = kernel_code.replace(placeholder, str(value))
        return cp.RawKernel(kernel_code, kernel_name, backend='nvrtc')

    def _launch(self, kernel, args, threads=1024):
        """Launch ``kernel`` with one block per sub-codebook and ``threads`` threads each.

        Dynamic shared memory is ``2 ** b * d * 2`` bytes (one 2-byte table of the
        sub-codebook entries).

        NOTE(review): CuPy launches on its own current stream, not PyTorch's; if
        torch work runs on a non-default stream, explicit synchronization is
        needed — confirm. Per the CuPy RawKernel docs, Python ``int`` arguments
        are passed as 64-bit integers; confirm the ``.cu`` signatures match.
        """
        kernel(
            grid=(self.nsq,),
            block=(threads,),
            shared_mem=(2 ** self.b) * self.d * 2,
            args=args,
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the rows of ``x`` into codebook indices.

        ``x`` is treated as ``n = x.numel() // x.shape[-1]`` rows. NOTE(review):
        rows are presumably ``nsq * d`` float16 values — confirm against quantize.cu.

        Returns:
            ``uint8`` tensor of shape ``(nsq, n)``.
        """
        # Only the data pointer (no strides) reaches the kernel.
        x = x.contiguous()
        n = x.numel() // x.shape[-1]
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        self._launch(
            self._quantize,
            [self.codebook.data_ptr(), x.data_ptr(), codes.data_ptr(), n],
        )
        return codes

    def dequantize(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode ``codes`` back into vectors.

        ``n`` is ``codes.numel() // codes.shape[0]``.

        Returns:
            ``float16`` tensor of shape ``(n, nsq * d)``.
        """
        codes = codes.contiguous()
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        self._launch(
            self._dequantize,
            [self.codebook.data_ptr(), codes.data_ptr(), x.data_ptr(), n],
        )
        return x

    def dequantize_rope(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode ``codes`` of shape ``(_, batch_size, seq_len)`` via the dequantize_rope kernel.

        NOTE(review): presumably the kernel also applies a rotary embedding —
        confirm against dequantize_rope.cu.

        Returns:
            ``float16`` tensor of shape ``(batch_size * seq_len, nsq * d)``.
        """
        codes = codes.contiguous()
        _, batch_size, seq_len = codes.shape
        n = batch_size * seq_len
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        self._launch(
            self._dequantize_rope,
            [self.codebook.data_ptr(), codes.data_ptr(), x.data_ptr(), batch_size, seq_len],
        )
        return x

    def fused_rope_mult(self, codes: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        """Multiply ``queries`` with the keys encoded in ``codes`` in one kernel.

        ``codes`` is unpacked as ``(_, batch_size, k_len)`` and ``queries`` as
        ``(_, n_heads, q_len, _)``. NOTE(review): presumably RoPE-rotated
        query-key scores — confirm against fused_rope_mult.cu.

        Returns:
            ``float16`` tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
        """
        # Queries are often a transposed (non-contiguous) view; the kernel sees no strides.
        codes = codes.contiguous()
        queries = queries.contiguous()
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(
            batch_size, n_heads, q_len, k_len, dtype=torch.float16, device=codes.device
        )
        self._launch(
            self._fused_rope_mult,
            [
                self.codebook.data_ptr(),
                codes.data_ptr(),
                queries.data_ptr(),
                out.data_ptr(),
                batch_size, q_len, k_len,
            ],
        )
        return out

    def fused_rope_pos_mult(
        self, codes: torch.Tensor, queries: torch.Tensor, position_ids: torch.Tensor
    ) -> torch.Tensor:
        """Like ``fused_rope_mult`` but with a per-batch-row position offset.

        The offset passed to the kernel is ``position_ids[:, -1] - k_len + 1``:
        assuming the ``k_len`` keys occupy consecutive positions ending at
        ``position_ids[:, -1]``, that is the position of the first key
        (NOTE(review): confirm against fused_rope_pos_mult_mqa.cu).

        Returns:
            ``float32`` tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
        """
        codes = codes.contiguous()
        queries = queries.contiguous()
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        # Arithmetic yields a fresh, contiguous tensor.
        position_offsets = position_ids[:, -1] - k_len + 1
        out = torch.zeros(
            batch_size, n_heads, q_len, k_len, dtype=torch.float32, device=codes.device
        )
        self._launch(
            self._fused_rope_pos_mult,
            [
                self.codebook.data_ptr(),
                codes.data_ptr(),
                position_offsets.data_ptr(),
                queries.data_ptr(),
                out.data_ptr(),
                batch_size, q_len, k_len,
            ],
        )
        return out

    def fused_mult(
        self, codes: torch.Tensor, weights: torch.Tensor, skip_last: int = 0
    ) -> torch.Tensor:
        """Combine ``weights`` of shape ``(batch, n_heads, q_len, k_len)`` with decoded ``codes``.

        The kernel receives both ``k_len`` and ``k_len - skip_last``.
        NOTE(review): presumably ``skip_last`` trailing key positions are excluded
        from the codes read — confirm against fused_mult_len.cu.

        Returns:
            ``float16`` tensor of shape ``(batch, n_heads, q_len, head_dim)``.
        """
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(
            batch_size, n_heads, q_len, self.head_dim, dtype=torch.float16, device=codes.device
        )
        if batch_size == 0:
            # min(1024, 0) would request a zero-thread block, which CUDA rejects.
            return out
        codes = codes.contiguous()
        weights = weights.contiguous()
        self._launch(
            self._fused_mult,
            [
                self.codebook.data_ptr(),
                codes.data_ptr(),
                weights.data_ptr(),
                out.data_ptr(),
                batch_size, q_len, k_len, k_len - skip_last,
            ],
            threads=min(1024, batch_size),
        )
        return out