import torch
import torch.nn as nn

from .gptq import *
from .modelutils import *
from .quant import *
from transformers import BloomForCausalLM as LM
    
class SakuraForCausalLM(LM):
    """Causal LM that is built in half precision and then passed to ``make_quant``.

    Construction skips random weight initialisation and creates parameters
    with ``torch.half`` as the default dtype. Afterwards every layer reported
    by ``find_layers`` except ``lm_head`` is handed to ``make_quant``.
    """

    def __init__(self, *args, **kwargs):
        """Build the base model without weight init, then call ``make_quant``.

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.

        Process-wide state touched while constructing (the ``torch.nn.init``
        functions and the default dtype) is restored before returning, even
        if construction raises.
        """
        # Imported locally so the name is guaranteed to be bound; the module
        # only imports ``BloomForCausalLM`` from transformers at file level.
        import transformers.modeling_utils

        def noop(*args, **kwargs):
            pass

        init = torch.nn.init
        patched_names = ('kaiming_uniform_', 'uniform_', 'normal_')
        saved_inits = {name: getattr(init, name) for name in patched_names}
        previous_dtype = torch.get_default_dtype()

        # Random initialisation is disabled while the model is being built.
        for name in patched_names:
            setattr(init, name, noop)
        try:
            # NOTE(review): this flag is intentionally left set after
            # construction, as in the original code — presumably the
            # subsequent checkpoint loading relies on it; confirm.
            transformers.modeling_utils._init_weights = False

            torch.set_default_dtype(torch.half)
            try:
                super().__init__(*args, **kwargs)
            finally:
                # Never leave the process-wide default dtype stuck at half.
                torch.set_default_dtype(previous_dtype)

            self.eval()
            layers = find_layers(self)
            # The output head is excluded from the layers given to make_quant.
            if 'lm_head' in layers:
                del layers['lm_head']
            # NOTE(review): 4 and -1 are presumably the bit width and group
            # size expected by make_quant — confirm against its signature.
            make_quant(self, layers, 4, -1)
        finally:
            # Undo the monkeypatches so unrelated modules created later are
            # initialised normally again.
            for name, original in saved_inits.items():
                setattr(init, name, original)