# quant.py: int8 weight-only quantization helpers (gpt2-quantzed-gguf)
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
class ReplacedLinearLayer(nn.Module):
    """Drop-in replacement for nn.Linear / transformers Conv1D that stores an
    int8 weight matrix plus one float32 dequantization scale per output channel."""

    def __init__(self, input_dim, output_dim, if_conv=True):
        super().__init__()
        # int8 weights in nn.Linear layout: (output_dim, input_dim)
        self.register_buffer('weights', torch.zeros([output_dim, input_dim], dtype=torch.int8))
        # per-output-channel absmax scales; these must be float, not int8
        self.register_buffer('scales', torch.zeros(output_dim, dtype=torch.float32))
        self.bias = None
        self.if_conv = if_conv
    def forward(self, x):
        # Dequantize on the fly: cast the int8 weights up to the activation
        # dtype, do the matmul, then rescale each output channel.
        fp32_weights = self.weights.to(x.dtype)
        x = F.linear(x, fp32_weights) * self.scales
        if self.bias is not None:
            x = x + self.bias
        return x
    def do_quantization(self, W):
        # transformers Conv1D stores its weight as (input_dim, output_dim),
        # so transpose it into nn.Linear's (output_dim, input_dim) layout first.
        if self.if_conv:
            W32 = W.clone().squeeze().T
        else:
            W32 = W.clone()
        # Symmetric per-row absmax quantization: pick a scale so the
        # largest-magnitude entry in each row maps to +/-127.
        scales = (torch.max(W32.abs(), dim=-1)[0] / 127).to(torch.float32)
        self.scales = scales
        self.weights = torch.round(W32 / scales[:, None]).to(torch.int8)

def perform_quantization(module, regex='.*'):
    """Replace every nn.Linear / transformers Conv1D submodule whose qualified
    name matches `regex` with an int8 ReplacedLinearLayer."""
    pattern = re.compile(regex)
    for name, node in module.named_modules():
        for name2, child in node.named_children():
            is_conv = isinstance(child, transformers.pytorch_utils.Conv1D)
            if (isinstance(child, nn.Linear) or is_conv) and pattern.match(f'{name}.{name2}'):
                fp32_weight, fp32_bias = child.weight, child.bias
                # nn.Linear weights are (out, in); Conv1D weights are (in, out).
                if is_conv:
                    input_dim, output_dim = child.weight.shape
                else:
                    output_dim, input_dim = child.weight.shape
                quant_module = ReplacedLinearLayer(input_dim, output_dim, if_conv=is_conv)
                quant_module.do_quantization(fp32_weight)
                if fp32_bias is not None:
                    quant_module.bias = fp32_bias
                setattr(node, name2, quant_module)
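
if __name__ == '__main__':
    # Usage sketch, not part of the original script: a minimal round-trip
    # check of ReplacedLinearLayer, then int8-quantizing GPT-2's projection
    # layers. Assumes network access to download the 'gpt2' checkpoint.

    # Round-trip check: the int8 layer should match an fp32 nn.Linear to
    # within absmax quantization error (roughly 1% relative).
    torch.manual_seed(0)
    ref = nn.Linear(64, 32)
    q = ReplacedLinearLayer(64, 32, if_conv=False)
    q.do_quantization(ref.weight.data)
    q.bias = ref.bias
    x = torch.randn(4, 64)
    print('max abs error vs fp32:', (q(x) - ref(x)).abs().max().item())

    # GPT-2's attention/MLP projections are Conv1D modules named
    # c_attn / c_proj / c_fc; the regex targets those and leaves lm_head
    # and the embeddings in fp32.
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    perform_quantization(model, regex=r'.*\.(c_attn|c_proj|c_fc)')

    inputs = tokenizer('Hello, my name is', return_tensors='pt')
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0]))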