import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer


class ReplacedLinearLayer(nn.Module):
    """Drop-in replacement for nn.Linear / transformers Conv1D that stores int8
    weights together with per-output-channel float32 dequantization scales."""

    def __init__(self, input_dim, output_dim, if_conv=True):
        super().__init__()

        # int8 weight matrix in nn.Linear layout (out_features, in_features) and
        # one absmax scale per output channel; both are filled by do_quantization().
        self.register_buffer('weights', torch.zeros([output_dim, input_dim], dtype=torch.int8))
        self.register_buffer('scales', torch.zeros(output_dim, dtype=torch.float32))

        self.bias = None
        self.if_conv = if_conv  # transformers Conv1D stores its weight transposed

    def forward(self, x):
        # Dequantize on the fly: cast the int8 weights to the activation dtype,
        # run the matmul, then rescale each output channel.
        w = self.weights.to(x.dtype)
        x = F.linear(x, w) * self.scales.to(x.dtype)
        if self.bias is not None:
            x = x + self.bias
        return x

    @torch.no_grad()
    def do_quantization(self, W):
        # transformers Conv1D keeps its weight as (in_features, out_features);
        # transpose it into the nn.Linear layout (out_features, in_features).
        if self.if_conv:
            W32 = W.clone().squeeze().T
        else:
            W32 = W.clone()

        # Symmetric absmax quantization: one scale per output channel, chosen so the
        # largest |w| in each row maps to 127. clamp() guards against all-zero rows.
        scales = (torch.max(W32.abs(), dim=-1)[0] / 127).to(torch.float32).clamp(min=1e-8)
        self.scales = scales
        self.weights = torch.round(W32 / scales[:, None]).clamp(-128, 127).to(torch.int8)


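# Illustrative helper, not part of the original code: a minimal sanity check that
# quantizes a plain nn.Linear and compares it with the full-precision layer on
# random input. The function name, sizes, and printed message are assumptions.
def _check_quantization_error(in_dim=64, out_dim=32, batch=4):
    torch.manual_seed(0)
    ref = nn.Linear(in_dim, out_dim)

    quant = ReplacedLinearLayer(in_dim, out_dim, if_conv=False)
    quant.do_quantization(ref.weight)
    quant.bias = ref.bias

    x = torch.randn(batch, in_dim)
    with torch.no_grad():
        err = (quant(x) - ref(x)).abs().max().item()
    print(f'max abs error after int8 weight quantization: {err:.6f}')
    return err

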
def perform_quantization(module, regex='.*'):
    """Replace every nn.Linear / transformers Conv1D whose qualified name matches
    `regex` with a ReplacedLinearLayer holding its int8-quantized weights."""
    pattern = re.compile(regex)
    for name, node in module.named_modules():
        for child_name, child in node.named_children():
            is_conv = isinstance(child, transformers.pytorch_utils.Conv1D)
            if (isinstance(child, nn.Linear) or is_conv) and pattern.match(f'{name}.{child_name}'):
                fp32_weight, fp32_bias = child.weight, child.bias

                # nn.Linear stores its weight as (out_features, in_features),
                # Conv1D as (in_features, out_features).
                if is_conv:
                    in_dim, out_dim = fp32_weight.shape[0], fp32_weight.shape[1]
                else:
                    in_dim, out_dim = fp32_weight.shape[1], fp32_weight.shape[0]

                quant_module = ReplacedLinearLayer(in_dim, out_dim, if_conv=is_conv)
                setattr(node, child_name, quant_module)

                quant_module.do_quantization(fp32_weight)
                if fp32_bias is not None:
                    quant_module.bias = fp32_bias
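

# Minimal usage sketch, assuming a small Hugging Face causal LM; the model name
# 'gpt2', the prompt, and the generation settings below are illustrative choices,
# not part of the original code.
if __name__ == '__main__':
    model_name = 'gpt2'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Quantize every matching linear / Conv1D projection in place.
    perform_quantization(model, regex='.*')

    inputs = tokenizer('Weight-only int8 quantization', return_tensors='pt')
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))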