import torch
import torch.nn as nn
import transformers  # needed for transformers.modeling_utils._init_weights below
from transformers import BloomForCausalLM as LM

from .gptq import *
# from modelutils import *
from .quant import *


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    """Recursively collect all sub-modules of the given types, keyed by their dotted module path."""
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
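# Illustrative example (module names assume a Bloom checkpoint; adjust to the
# actual model you pass in):
#   find_layers(model)
#   # -> {"transformer.h.0.self_attention.query_key_value": Linear(...),
#   #     ..., "lm_head": Linear(...)}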


class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        # Skip PyTorch's random weight initialization; the real weights are
        # expected to be loaded afterwards (e.g. from a quantized checkpoint).
        def noop(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        transformers.modeling_utils._init_weights = False

        # Build the base Bloom model in half precision.
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)
        self.eval()

        # Replace every Linear layer except the output head with its 4-bit
        # quantized counterpart (bits=4, groupsize=-1).
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 4, -1)
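

# Minimal usage sketch (the checkpoint name and state-dict path below are
# illustrative assumptions; loading would normally happen from another script
# that imports this module):
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("bigscience/bloom-7b1")
#   model = SakuraForCausalLM(config)    # Linear layers are already swapped for 4-bit ones
#   model.load_state_dict(torch.load("bloom-4bit.pt"), strict=False)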