from transformers import LlamaForCausalLM

from .configuration_oursvd_llama import CovSVDLlamaConfig

import torch
import torch.nn as nn
import torch.nn.functional as F


class CovSVDLinear(nn.Module):
    """Low-rank replacement for nn.Linear: a rank-`rank` factorization
    (BLinear -> ALinear) plus a frozen full-size residual weight."""

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.BLinear = nn.Linear(in_features, rank, bias=False)
        self.ALinear = nn.Linear(rank, out_features, bias=bias)
        # Frozen residual term; initialized to zero and expected to be
        # overwritten when pretrained weights are loaded.
        self.weight_residual = nn.Parameter(torch.zeros(out_features, in_features))
        self.weight_residual.requires_grad = False

    def forward(self, input):
        # Low-rank path (B then A) plus the frozen residual path.
        y = self.BLinear(input)
        y = self.ALinear(y) + F.linear(input, self.weight_residual)
        return y


class CovSVDLlamaForCausalLM(LlamaForCausalLM):
    config_class = CovSVDLlamaConfig

    def __init__(self, config: CovSVDLlamaConfig):
        super().__init__(config)
        self.lora_r = config.lora_r

        # Map each submodule to its fully qualified name.
        full_name_dict = {module: name for name, module in self.named_modules()}

        # Collect every nn.Linear together with its parent module and the
        # attribute name it is registered under, so it can be swapped in place.
        linear_info = {}
        modules = [self]
        while len(modules) > 0:
            submodule = modules.pop()
            for name, raw_linear in submodule.named_children():
                if isinstance(raw_linear, nn.Linear):
                    full_name = full_name_dict[raw_linear]
                    linear_info[raw_linear] = {
                        "father": submodule,
                        "name": name,
                        "full_name": full_name,
                    }
                else:
                    modules.append(raw_linear)

        # Replace every nn.Linear except lm_head with a CovSVDLinear of rank lora_r.
        for name, module in self.named_modules():
            if "lm_head" not in name and isinstance(module, nn.Linear):
                info = linear_info[module]
                new_layer = CovSVDLinear(
                    module.in_features,
                    module.out_features,
                    self.lora_r,
                    bias=module.bias is not None,
                )
                setattr(info["father"], info["name"], new_layer)
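

# A minimal sanity-check sketch of CovSVDLinear in isolation (an illustrative
# addition, not part of the original file; shapes are arbitrary). Because of the
# relative config import above, this module must be run in its package context,
# e.g. `python -m <package>.<this_module>`, or CovSVDLinear imported elsewhere.
if __name__ == "__main__":
    layer = CovSVDLinear(in_features=128, out_features=64, rank=8, bias=True)
    x = torch.randn(2, 16, 128)
    # Low-rank path plus the zero-initialized frozen residual: the output has
    # the same shape a plain Linear(128, 64) would produce.
    print(layer(x).shape)  # torch.Size([2, 16, 64])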