Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import cupy as cp | |
import math | |
from os import path | |
class Quantizer(nn.Module):
    """Multi-codebook quantizer whose encode/decode and fused attention ops run as CUDA kernels.

    ``codebook`` is unpacked as ``(nsq, nc, d)``: ``nsq`` sub-codebooks, each
    with ``nc`` entries that are ``d`` values wide. ``quantize`` emits one
    ``uint8`` code per sub-codebook per row, so ``nc`` must be a power of two
    no larger than 256.

    At construction, every kernel is wrapped in a ``cp.RawKernel`` (NVRTC
    backend) built from a ``.cu`` file next to this module. Before wrapping,
    the placeholder tokens ``__NSQ__``, ``__B__``, ``__D__``, ``__HEAD_DIM__``,
    ``__QPK__`` and ``__ROPE_THETA__`` are replaced with their values. Every
    launch uses a grid of ``nsq`` blocks and the same dynamic shared-memory size.

    NOTE(review): the shared-memory size is ``2**b * d * 2`` bytes, which
    assumes a float16 codebook. Confirm that the buffer is always float16.
    """

    def __init__(self, config, codebook):
        super().__init__()
        self.nsq, nc, self.d = codebook.shape
        self.b = self._codebook_bits(nc)
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # Attention heads per key/value head (the grouped-query ratio).
        qpk = config.num_attention_heads // config.num_key_value_heads
        # Read from config (default 32); not used inside this class itself.
        self.window_length = getattr(config, 'window_length', 32)
        self.register_buffer('codebook', codebook)
        # Dynamic shared memory per launch: 2**b entries x d values x 2 bytes.
        self._shared_mem = (2 ** self.b) * self.d * 2

        # Placeholder -> value maps; dict insertion order is the substitution order.
        common = {'__NSQ__': self.nsq, '__B__': self.b, '__D__': self.d}
        with_head = {**common, '__HEAD_DIM__': head_dim}
        self._quantize = self._load_kernel("quantize.cu", 'quantize', common)
        self._dequantize = self._load_kernel("dequantize.cu", 'dequantize', common)
        self._dequantize_rope = self._load_kernel(
            "dequantize_rope.cu", 'dequantize_rope', with_head)
        self._fused_rope_mult = self._load_kernel(
            "fused_rope_mult.cu", 'fused_rope_mult', with_head)
        self._fused_rope_pos_mult = self._load_kernel(
            "fused_rope_pos_mult_mqa.cu", 'fused_rope_pos_mult',
            {**with_head, '__QPK__': qpk, '__ROPE_THETA__': config.rope_theta},
        )
        self._fused_mult = self._load_kernel(
            "fused_mult_len.cu", 'fused_mult', {**with_head, '__QPK__': qpk},
        )

    @staticmethod
    def _codebook_bits(nc):
        """Return ``log2(nc)`` for a codebook with ``nc`` entries per sub-codebook.

        Raises:
            ValueError: if ``nc`` is not a power of two in ``[1, 256]``. Codes are
                stored as uint8 and shared memory is sized for exactly ``2**b``
                entries, so other sizes cannot be represented.
        """
        if nc < 1 or nc > 256 or nc & (nc - 1):
            raise ValueError(
                f"codebook size must be a power of two between 1 and 256, got {nc}"
            )
        return nc.bit_length() - 1

    @staticmethod
    def _load_kernel(filename, kernel_name, substitutions):
        """Wrap ``kernel_name`` from ``filename`` (next to this module) in an NVRTC RawKernel.

        Each key of ``substitutions`` is replaced in the source text by
        ``str(value)``, in the dict's iteration order, before wrapping.
        """
        source_path = path.join(path.dirname(__file__), filename)
        with open(source_path, "r", encoding="utf-8") as f:
            code = f.read()
        for placeholder, value in substitutions.items():
            code = code.replace(placeholder, str(value))
        return cp.RawKernel(code, kernel_name, backend="nvrtc")

    def _launch(self, kernel, device, block, args):
        """Launch ``kernel`` with grid ``(nsq,)`` on ``device``.

        The launch runs inside ``device``'s CuPy context and on PyTorch's current
        stream for that device. It is therefore stream-ordered with the torch ops
        that produce the inputs and consume the outputs, even when the caller uses
        a non-default stream or a GPU other than CuPy's current one.

        Raises:
            ValueError: if ``device`` is not a CUDA device (the kernels receive
                raw pointers).
        """
        if device.type != 'cuda':
            raise ValueError(f"Quantizer kernels need CUDA tensors, got device {device}")
        torch_stream = torch.cuda.current_stream(device).cuda_stream
        with cp.cuda.Device(device.index), cp.cuda.ExternalStream(torch_stream):
            kernel(grid=(self.nsq, ), block=block, shared_mem=self._shared_mem, args=args)

    def quantize(self, x):
        """Encode the rows of ``x`` into uint8 codes, one per sub-codebook.

        Args:
            x: tensor whose last dimension must be ``nsq * d`` (the width that
                ``dequantize`` produces). All leading dimensions are flattened
                into ``n`` rows. Presumably float16 (TODO confirm against
                quantize.cu).

        Returns:
            uint8 tensor of shape ``(nsq, n)`` on ``x.device``.

        Raises:
            ValueError: if ``x.shape[-1] != nsq * d``, or if ``x`` is non-empty
                and not on a CUDA device.
        """
        if x.shape[-1] != self.nsq * self.d:
            raise ValueError(
                f"expected last dimension {self.nsq * self.d}, got {x.shape[-1]}"
            )
        n = x.numel() // x.shape[-1]
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        if n == 0:
            return codes
        # Only a raw pointer and a row count reach the kernel (no strides).
        x = x.contiguous()
        self._launch(self._quantize, x.device, (1024, ), [
            self.codebook.data_ptr(),
            x.data_ptr(),
            codes.data_ptr(),
            n,
        ])
        return codes

    def dequantize(self, codes):
        """Decode ``codes`` (first dimension ``nsq``) back into dense rows.

        Returns:
            float16 tensor of shape ``(n, nsq * d)``, where
            ``n = codes.numel() // codes.shape[0]``.
        """
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        if n == 0:
            return x
        codes = codes.contiguous()
        self._launch(self._dequantize, codes.device, (1024, ), [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            n,
        ])
        return x

    def dequantize_rope(self, codes):
        """Decode 3-D ``codes`` of shape ``(nsq, batch_size, seq_len)`` with the rope kernel.

        Judging by the kernel name, this presumably also applies rotary position
        embeddings during decoding; confirm against dequantize_rope.cu.

        Returns:
            float16 tensor of shape ``(batch_size * seq_len, nsq * d)``.
        """
        _, batch_size, seq_len = codes.shape
        n = batch_size * seq_len
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        if n == 0:
            return x
        codes = codes.contiguous()
        self._launch(self._dequantize_rope, codes.device, (1024, ), [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            batch_size, seq_len,
        ])
        return x

    def fused_rope_mult(self, codes, queries):
        """Fused op over quantized keys and ``queries``, producing per-head score maps.

        Args:
            codes: codes of shape ``(nsq, batch_size, k_len)``.
            queries: tensor of shape ``(batch, n_heads, q_len, head_dim)``.
                Presumably float16 (TODO confirm).

        Returns:
            float16 tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
            Presumably rotary-embedded key/query products (kernel not visible here).
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float16, device=codes.device)
        if out.numel() == 0:
            return out
        # Query states are often transposed views; the kernel sees no strides.
        codes = codes.contiguous()
        queries = queries.contiguous()
        self._launch(self._fused_rope_mult, codes.device, (1024, ), [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len,
        ])
        return out

    def fused_rope_pos_mult(self, codes, queries, position_ids):
        """Like ``fused_rope_mult`` but with explicit positions and a float32 result.

        Args:
            codes: codes of shape ``(nsq, batch_size, k_len)``.
            queries: tensor of shape ``(batch, n_heads, q_len, head_dim)``.
            position_ids: 2-D tensor. The kernel receives one offset per row,
                ``position_ids[:, -1] - k_len + 1``. A single row is broadcast
                to ``batch_size`` rows.

        Returns:
            float32 tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float32, device=codes.device)
        if out.numel() == 0:
            return out
        position_offsets = position_ids[:, -1] - k_len + 1
        # Shared (1, q_len) position_ids would leave a 1-element offsets
        # buffer for a batched launch; give every batch row its own entry.
        if position_offsets.shape[0] == 1 and batch_size > 1:
            position_offsets = position_offsets.expand(batch_size)
        position_offsets = position_offsets.contiguous()
        codes = codes.contiguous()
        queries = queries.contiguous()
        self._launch(self._fused_rope_pos_mult, codes.device, (1024, ), [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            position_offsets.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len,
        ])
        return out

    def fused_mult(self, codes, weights, skip_last=0):
        """Fused op over attention ``weights`` and quantized values.

        Args:
            codes: quantized values (presumably ``(nsq, batch_size, k_len)``;
                TODO confirm).
            weights: tensor of shape ``(batch_size, n_heads, q_len, k_len)``.
            skip_last: the kernel receives ``k_len - skip_last`` as an extra
                length argument. Presumably that many trailing key positions
                are excluded (TODO confirm against fused_mult_len.cu).

        Returns:
            float16 tensor of shape ``(batch_size, n_heads, q_len, head_dim)``.
        """
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(batch_size, n_heads, q_len, self.head_dim, dtype=torch.float16, device=codes.device)
        # An empty output (notably batch_size == 0) would otherwise launch a
        # zero-thread block, which CUDA rejects as an invalid configuration.
        if out.numel() == 0:
            return out
        codes = codes.contiguous()
        weights = weights.contiguous()
        self._launch(self._fused_mult, codes.device, (min(1024, batch_size), ), [
            self.codebook.data_ptr(),
            codes.data_ptr(),
            weights.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len, k_len - skip_last,
        ])
        return out