import torch
from torch import nn
import cupy as cp
import math
from os import path


class Quantizer(nn.Module):
    """Codebook quantizer backed by CUDA kernels compiled through CuPy/NVRTC.

    The kernel sources are read from ``.cu`` files that sit next to this
    module.  Placeholders such as ``__NSQ__`` or ``__D__`` inside those files
    are textually replaced with values derived from ``codebook`` and
    ``config`` before the source is handed to ``cp.RawKernel``.

    Every kernel is launched with one block per sub-quantizer
    (``grid=(nsq,)``) and ``(2 ** b) * d * 2`` bytes of dynamic shared memory.

    NOTE(review): the ``* 2`` in the shared-memory size presumably means
    2-byte (float16) codebook entries -- confirm against the kernels.
    NOTE(review): tensors are passed as raw ``data_ptr()`` addresses, so the
    kernels see no strides; callers must provide the memory layout the
    kernels expect.  Launches also happen on CuPy's current device/stream --
    confirm that this matches the torch device/stream in multi-GPU or
    multi-stream setups.
    """

    # Block size used by every kernel except ``fused_mult``.
    _MAX_THREADS_PER_BLOCK = 1024

    def __init__(self, config, codebook):
        """Register the codebook and build all kernels.

        Args:
            config: Model configuration.  Reads ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads`` and
                ``rope_theta``; ``window_length`` is optional and defaults
                to 32.
            codebook: Tensor whose shape unpacks as ``(nsq, nc, d)``.
        """
        super().__init__()
        self.nsq, nc, self.d = codebook.shape
        # Number of bits per code; truncates if ``nc`` is not a power of two.
        self.b = int(math.log2(nc))
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # Number of query heads per key/value head.
        qpk = config.num_attention_heads // config.num_key_value_heads
        self.window_length = getattr(config, 'window_length', 32)
        self.register_buffer('codebook', codebook)

        # Placeholder sets grow from kernel to kernel; insertion order matches
        # the order in which the substitutions are applied.
        base_subs = {'__NSQ__': self.nsq, '__B__': self.b, '__D__': self.d}
        head_subs = {**base_subs, '__HEAD_DIM__': head_dim}
        qpk_subs = {**head_subs, '__QPK__': qpk}

        self._quantize = self._load_kernel(
            "quantize.cu", 'quantize', base_subs)
        self._dequantize = self._load_kernel(
            "dequantize.cu", 'dequantize', base_subs)
        self._dequantize_rope = self._load_kernel(
            "dequantize_rope.cu", 'dequantize_rope', head_subs)
        self._fused_rope_mult = self._load_kernel(
            "fused_rope_mult.cu", 'fused_rope_mult', head_subs)
        self._fused_rope_pos_mult = self._load_kernel(
            "fused_rope_pos_mult_mqa.cu", 'fused_rope_pos_mult',
            {**qpk_subs, '__ROPE_THETA__': config.rope_theta})
        self._fused_mult = self._load_kernel(
            "fused_mult_len.cu", 'fused_mult', qpk_subs)

    @staticmethod
    def _load_kernel(filename, entry_point, substitutions):
        """Read a kernel source file, fill its placeholders, and wrap it.

        Args:
            filename: Name of the ``.cu`` file located next to this module.
            entry_point: Name of the kernel function inside the source.
            substitutions: Mapping of placeholder text to the value whose
                ``str()`` replaces it; applied in insertion order.

        Returns:
            A ``cp.RawKernel`` using the ``nvrtc`` backend.
        """
        with open(path.join(path.dirname(__file__), filename), "r") as f:
            kernel_code = f.read()
        for placeholder, value in substitutions.items():
            kernel_code = kernel_code.replace(placeholder, str(value))
        return cp.RawKernel(kernel_code, entry_point, backend="nvrtc")

    def _launch(self, kernel, args, threads=_MAX_THREADS_PER_BLOCK):
        """Launch ``kernel`` with the launch configuration shared by all ops.

        Args:
            kernel: One of the ``cp.RawKernel`` objects built in ``__init__``.
            args: Positional kernel arguments (raw pointers and integers).
            threads: Threads per block.
        """
        kernel(
            grid=(self.nsq, ),
            block=(threads, ),
            shared_mem=(2 ** self.b) * self.d * 2,
            args=args,
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into codebook indices.

        Args:
            x: Input tensor; all leading dimensions are flattened into ``n``
                rows (``n = x.numel() // x.shape[-1]``).

        Returns:
            ``uint8`` tensor of shape ``(nsq, n)`` on ``x``'s device.
        """
        n = x.numel() // x.shape[-1]
        # The buffer is uninitialised; the kernel is expected to fill it.
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        self._launch(self._quantize, [
            self.codebook.data_ptr(),
            x.data_ptr(),
            codes.data_ptr(),
            n,
        ])
        return codes

    def dequantize(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes back into vectors.

        Args:
            codes: Code tensor whose first dimension is the sub-quantizer
                axis; the remaining dimensions are flattened into ``n``.

        Returns:
            ``float16`` tensor of shape ``(n, nsq * d)``.
        """
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16,
                        device=codes.device)
        self._launch(self._dequantize, [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            n,
        ])
        return x

    def dequantize_rope(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes with the ``dequantize_rope`` kernel.

        Args:
            codes: 3-D code tensor ``(_, batch_size, seq_len)``.

        Returns:
            ``float16`` tensor of shape ``(batch_size * seq_len, nsq * d)``.
        """
        _, batch_size, seq_len = codes.shape
        x = torch.zeros(batch_size * seq_len, self.nsq * self.d,
                        dtype=torch.float16, device=codes.device)
        self._launch(self._dequantize_rope, [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            batch_size,
            seq_len,
        ])
        return x

    def fused_rope_mult(self, codes: torch.Tensor,
                        queries: torch.Tensor) -> torch.Tensor:
        """Run the ``fused_rope_mult`` kernel on coded keys and queries.

        Args:
            codes: 3-D code tensor ``(_, batch_size, k_len)``.
            queries: 4-D tensor ``(_, n_heads, q_len, _)``.

        Returns:
            ``float16`` tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len,
                          dtype=torch.float16, device=codes.device)
        self._launch(self._fused_rope_mult, [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size,
            q_len,
            k_len,
        ])
        return out

    def fused_rope_pos_mult(self, codes: torch.Tensor, queries: torch.Tensor,
                            position_ids: torch.Tensor) -> torch.Tensor:
        """Run the ``fused_rope_pos_mult`` kernel with per-row offsets.

        Args:
            codes: 3-D code tensor ``(_, batch_size, k_len)``.
            queries: 4-D tensor ``(_, n_heads, q_len, _)``.
            position_ids: 2-D tensor; only its last column is used.

        Returns:
            ``float32`` tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        # Last position id of each row, shifted back by k_len - 1.
        position_offsets = position_ids[:, -1] - k_len + 1
        out = torch.zeros(batch_size, n_heads, q_len, k_len,
                          dtype=torch.float32, device=codes.device)
        self._launch(self._fused_rope_pos_mult, [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            position_offsets.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size,
            q_len,
            k_len,
        ])
        return out

    def fused_mult(self, codes: torch.Tensor, weights: torch.Tensor,
                   skip_last=0) -> torch.Tensor:
        """Run the ``fused_mult`` kernel on coded values and weights.

        Args:
            codes: Code tensor passed to the kernel by raw pointer.
            weights: 4-D tensor ``(batch_size, n_heads, q_len, k_len)``.
            skip_last: Subtracted from ``k_len`` to form the final kernel
                argument (``k_len - skip_last``).

        Returns:
            ``float16`` tensor of shape
            ``(batch_size, n_heads, q_len, head_dim)``.
        """
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(batch_size, n_heads, q_len, self.head_dim,
                          dtype=torch.float16, device=codes.device)
        # Unlike the other kernels, the block size here is capped by
        # batch_size rather than fixed at the maximum.
        self._launch(
            self._fused_mult,
            [
                self.codebook.data_ptr(),
                codes.data_ptr(),
                weights.data_ptr(),
                out.data_ptr(),
                batch_size,
                q_len,
                k_len,
                k_len - skip_last,
            ],
            threads=min(self._MAX_THREADS_PER_BLOCK, batch_size),
        )
        return out