# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import math
import torch
from torch import nn
import torch.nn.functional as F
from deepspeed.compression.helper import recursive_getattr, recursive_setattr
import deepspeed


class LinearLayer_LoRA(nn.Module):
    # a simple implementation of LoRA
    # for now, only the Linear layer is supported
    def __init__(self,
                 weight,
                 lora_dim=0,
                 lora_scaling=1,
                 lora_droppout=0,
                 bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias
        if lora_dim <= 0:
            raise ValueError(
                "You are trying to use LoRA, whose reduced dim should be larger than 0"
            )
        try:
            # for ZeRO stage 3, the partitioned weight exposes its full shape via ds_shape
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape
        self.lora_right_weight = nn.Parameter(torch.zeros(
            columns,
            lora_dim))  # apply transpose so in forward we do not need to transpose again
        self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_droppout > 0:
            self.lora_dropout = nn.Dropout(lora_droppout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        # disable the original weight gradient
        self.weight.requires_grad = False
        # fuse LoRA to the original weight
        self.fuse_lora = False
    def eval(self):
        self.lora_dropout.eval()
        # self.fuse_lora_weight()

    def train(self, mode=True):
        self.lora_dropout.train(mode)
        # self.unfuse_lora_weight()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def fuse_lora_weight(self):
        if not self.fuse_lora:
            self.weight.data += self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = False
    def forward(self, input):
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        else:
            return F.linear(
                input, self.weight,
                self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
                              @ self.lora_left_weight) * self.lora_scaling
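

# --- Illustrative sketch (not part of the original module) ---
# LinearLayer_LoRA computes
#     y = x @ W.t() + b + (lora_scaling / lora_dim) * dropout(x) @ A @ B
# where A = lora_right_weight (in_features x lora_dim) and
# B = lora_left_weight (lora_dim x out_features). Since B starts at zero, the
# wrapped layer initially matches the base nn.Linear exactly, and fusing /
# unfusing the (currently zero) delta must not change the output. The helper
# below is a hypothetical self-check, not part of the original API.
def _example_lora_layer_roundtrip():
    torch.manual_seed(0)
    base = nn.Linear(16, 8)
    lora = LinearLayer_LoRA(base.weight,
                            lora_dim=4,
                            lora_scaling=1,
                            lora_droppout=0,
                            bias=base.bias)
    x = torch.randn(2, 16)
    # B (lora_left_weight) is zero-initialized, so the LoRA delta is zero.
    assert torch.allclose(lora(x), base(x), atol=1e-6)
    # Fusing folds lora_scaling * B.t() @ A.t() into the frozen weight;
    # unfusing subtracts it again, so outputs stay identical.
    lora.fuse_lora_weight()
    fused_out = lora(x)
    lora.unfuse_lora_weight()
    assert torch.allclose(fused_out, lora(x), atol=1e-5)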


# convert the linear layer to LoRA
def convert_linear_layer_to_lora(model,
                                 part_module_name,
                                 lora_dim=0,
                                 lora_scaling=1,
                                 lora_droppout=0):
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(
            module.weight, lora_dim, lora_scaling, lora_droppout,
            module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model
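

# --- Illustrative sketch (not part of the original module) ---
# convert_linear_layer_to_lora swaps every nn.Linear whose qualified module
# name contains `part_module_name` for a LinearLayer_LoRA that reuses the same
# weight and bias tensors. A minimal usage example on a hypothetical toy
# block; with real models, `part_module_name` is typically a substring that
# selects the layers to adapt (e.g. a decoder-layer prefix).
def _example_convert_to_lora():
    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn_proj = nn.Linear(16, 16)
            self.mlp_proj = nn.Linear(16, 64)

        def forward(self, x):
            return self.mlp_proj(F.relu(self.attn_proj(x)))

    model = ToyBlock()
    # Only Linear modules whose name contains "attn_proj" get wrapped.
    model = convert_linear_layer_to_lora(model, "attn_proj", lora_dim=4)
    assert isinstance(model.attn_proj, LinearLayer_LoRA)
    assert isinstance(model.mlp_proj, nn.Linear)
    return model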


def _z3_params_to_fetch(param_list):
    return [
        p for p in param_list
        if hasattr(p, 'ds_id') and p.ds_status == deepspeed.runtime.zero.
        partition_parameters.ZeroParamStatus.NOT_AVAILABLE
    ]


# convert the LoRA layer to linear layer
def convert_lora_to_linear_layer(model):
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, LinearLayer_LoRA):
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        zero_stage_3 = hasattr(module.weight, 'ds_id')
        with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([
                module.weight, module.bias, module.lora_left_weight,
                module.lora_right_weight
        ]),
                                               modifier_rank=0,
                                               enabled=zero_stage_3):
            module.fuse_lora_weight()
    return model
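

# --- Illustrative sketch (not part of the original module) ---
# Before plain inference or export, convert_lora_to_linear_layer folds every
# LoRA delta into the frozen base weights (gathering ZeRO stage-3 partitioned
# parameters first when needed), so each wrapped layer's forward then uses
# only the fused weight. `_example_convert_to_lora` is the hypothetical helper
# defined above; the "drop LoRA factors" filter is likewise only illustrative.
def _example_fuse_before_export():
    model = _example_convert_to_lora()
    model = convert_lora_to_linear_layer(model)
    # The wrapper class stays in place, but its LoRA delta is now fused.
    assert model.attn_proj.fuse_lora
    # The low-rank factors are no longer needed in the exported checkpoint.
    return {k: v for k, v in model.state_dict().items() if "lora_" not in k}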


def only_optimize_lora_parameters(model):
    # turn off the gradient of all the parameters except the LoRA parameters
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
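

# --- Illustrative sketch (not part of the original module) ---
# End-to-end LoRA fine-tuning flow with the helpers above: wrap the matching
# Linear layers, freeze everything except the two low-rank factor matrices,
# and build the optimizer from the remaining trainable parameters. The toy
# model, data, and hyperparameters here are hypothetical.
def _example_lora_training_step():
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    # An empty part_module_name matches every Linear layer.
    model = convert_linear_layer_to_lora(model, "", lora_dim=8, lora_scaling=16)
    model = only_optimize_lora_parameters(model)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)

    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()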