import torch
import torch.nn as nn
import transformers
from transformers import BloomForCausalLM as LM

# GPTQ quantization helpers shipped with this repo.
from .gptq import *
# from modelutils import *
from .quant import *

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    """Recursively collect submodules of the given types, keyed by dotted module name."""
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res

class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        # Skip the slow random weight initialization: the real weights are
        # loaded from a checkpoint afterwards, so the initializers can be no-ops.
        def noop(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        transformers.modeling_utils._init_weights = False

        # Build the model skeleton in half precision, then restore the default dtype.
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)
        self.eval()

        # Replace every Linear/Conv2d layer except the LM head with its
        # 4-bit quantized counterpart (bits=4, groupsize=-1, i.e. no grouping).
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 4, -1)
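
# A minimal, hypothetical usage sketch, not part of the original file: because
# __init__ swaps the Linear layers for 4-bit quantized modules, the intended flow
# is presumably to build the model from a config and then load a pre-quantized
# state dict. The model id and checkpoint path below are placeholders.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bigscience/bloom-7b1")  # placeholder model id
    model = SakuraForCausalLM(config)
    model.load_state_dict(torch.load("bloom-4bit.pt"))  # placeholder quantized checkpoint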