Spaces:
Running
Running
import torch | |
from torch import nn | |
from abc import ABC, abstractmethod | |
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, set_module | |
from utils.common.log import logger | |
from .base import FMLoRA_Util, LoRA | |
class ToQKV_WrappedWithLoRA(nn.Module):
    """Wrap a q/k/v projection linear layer with an additive low-rank branch.

    ``forward(x)`` returns ``fc(x) + ab(x)``, where ``ab`` is a two-layer
    bottleneck ``in_features -> out_features // ab_r -> out_features`` built
    from the project's ``LoRA`` layers. The second layer is zero-initialised,
    so a freshly wrapped module produces the same output as ``fc`` alone.

    Args:
        fc: the projection layer to wrap; kept as ``self.fc``.
        ab_r: reduction ratio. The bottleneck width is
            ``fc.weight.size(0) // ab_r`` (a divisor, not the rank itself).
    """

    def __init__(self, fc: nn.Linear, ab_r: int):
        super(ToQKV_WrappedWithLoRA, self).__init__()
        self.fc = fc
        self.ab = self.create_ab_as_linear(fc.weight.data, ab_r)

    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        """Build the ``A -> B`` low-rank branch matching ``fc_weight``.

        Args:
            fc_weight: weight of the wrapped linear layer, laid out as
                ``(out_features, in_features)`` (nn.Linear convention).
            ab_r: reduction ratio; see the class docstring.

        Returns:
            ``nn.Sequential(A, B)`` on ``fc_weight``'s device and dtype, with
            ``A`` Kaiming-uniform initialised and ``B`` zeroed.

        Raises:
            ValueError: if ``ab_r`` is not positive, or is so large that the
                bottleneck would have width 0.
        """
        out_features, in_features = fc_weight.size(0), fc_weight.size(1)
        if ab_r <= 0:
            raise ValueError(f'ab_r must be a positive integer, got {ab_r}')
        rank = out_features // ab_r
        if rank < 1:
            raise ValueError(
                f'ab_r={ab_r} exceeds out_features={out_features}; '
                'the LoRA bottleneck would have width 0'
            )

        # Match dtype as well as device: with a half-precision fc, an fp32
        # branch would make both the forward sum and the later weight merge
        # (fc.weight += B @ A) fail on a dtype mismatch.
        res = nn.Sequential(
            LoRA(in_features, rank, bias=False),
            LoRA(rank, out_features, bias=False),
        ).to(device=fc_weight.device, dtype=fc_weight.dtype)

        # A random, B zero: the branch contributes nothing until trained,
        # so wrapping does not change the model's output.
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res

    def forward(self, x):
        """Return ``fc(x) + ab(x)``."""
        return self.fc(x) + self.ab(x)
class FMLoRA_CLIP_Util(FMLoRA_Util):
    """Add LoRA branches to, and absorb them from, a CLIP-style model.

    ``add_lora_ab_to_fm`` wraps every submodule whose name ends with
    ``q_proj`` / ``k_proj`` / ``v_proj`` in ``ToQKV_WrappedWithLoRA``.
    ``absorb_lora_and_recover_net_structure`` folds each branch back into the
    wrapped linear weight and puts the plain layer back. Both methods check
    that the model's ``logits_per_image`` / ``logits_per_text`` are unchanged
    (up to a small tolerance) on ``samples``.
    """

    @staticmethod
    def _samples_on_device(fm: nn.Module, samples: dict) -> dict:
        """Return a copy of ``samples`` with every tensor moved to ``fm``'s device.

        Non-tensor values are passed through unchanged. The caller's dict is
        not modified.
        """
        device = get_model_device(fm)
        return {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in samples.items()
        }

    @staticmethod
    def _logits_sq_diff(o1, o2) -> torch.Tensor:
        """Return the sum of squared differences of image and text logits."""
        image_diff = ((o1.logits_per_image - o2.logits_per_image) ** 2).sum()
        text_diff = ((o1.logits_per_text - o2.logits_per_text) ** 2).sum()
        return image_diff + text_diff

    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: dict):
        """Wrap every q/k/v projection of ``fm`` with a LoRA branch, in place.

        Args:
            fm: the model. Its outputs must expose ``logits_per_image`` and
                ``logits_per_text``. It is switched to eval mode.
            ab_r: reduction ratio forwarded to ``ToQKV_WrappedWithLoRA``.
            samples: keyword inputs for ``fm``. Tensors are moved to the
                model's device on a copy of the dict.

        Returns:
            ``fm`` itself, now containing the wrapped projections.

        Raises:
            AssertionError: if the wrapped model's logits differ from the
                original ones (sum of squared differences >= 1e-5).
        """
        fm.eval()
        samples = self._samples_on_device(fm, samples)

        # The forwards are sanity checks only; skip autograd bookkeeping.
        with torch.no_grad():
            o1 = fm(**samples)

        # Snapshot the module list: set_module mutates the tree being walked.
        for name, module in list(fm.named_modules()):
            if name.endswith(('k_proj', 'q_proj', 'v_proj')):
                set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))

        with torch.no_grad():
            o2 = fm(**samples)

        output_diff = self._logits_sq_diff(o1, o2)
        assert output_diff < 1e-5, output_diff
        return fm

    def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict):
        """Merge each LoRA branch into its linear layer and unwrap it, in place.

        For every ``ToQKV_WrappedWithLoRA`` in ``fm``, the wrapped ``fc``
        receives ``fc.weight += ab[1].weight @ ab[0].weight`` and then replaces
        the wrapper.

        Args:
            fm: the model containing wrapped projections. Its outputs must
                expose ``logits_per_image`` and ``logits_per_text``. It is
                switched to eval mode.
            samples: keyword inputs for ``fm``. Tensors are moved to the
                model's device on a copy of the dict.

        Returns:
            ``fm`` itself, with the wrappers replaced by their merged ``fc``.

        Raises:
            AssertionError: if the merged model's logits differ from the
                wrapped model's (sum of squared differences >= 1e-6).
        """
        fm.eval()
        samples = self._samples_on_device(fm, samples)

        with torch.no_grad():
            o1 = fm(**samples)

        # Snapshot the module list: set_module mutates the tree being walked.
        for name, module in list(fm.named_modules()):
            if not isinstance(module, ToQKV_WrappedWithLoRA):
                continue
            fc = module.fc
            ab = module.ab
            # B @ A has fc.weight's (out, in) layout, assuming LoRA stores its
            # weight like nn.Linear does. The in-place update must not be
            # recorded by autograd: it raises for a leaf that requires grad,
            # and would otherwise attach a grad_fn to the Parameter.
            with torch.no_grad():
                fc.weight.add_(ab[1].weight @ ab[0].weight)
            set_module(fm, name, fc)

        with torch.no_grad():
            o2 = fm(**samples)

        output_diff = self._logits_sq_diff(o1, o2)
        assert output_diff < 1e-6, output_diff
        return fm