# face-to-all-666 / lora.py
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAModule(nn.Module):
"""
LoRA module that replaces the forward method of an original Linear or Conv2D module.
"""
def __init__(
self,
lora_name: str,
org_module: nn.Module,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: Optional[float] = None,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
):
"""
Args:
lora_name (str): Name of the LoRA module.
org_module (nn.Module): The original module to wrap.
multiplier (float): Scaling factor for the LoRA output.
lora_dim (int): The rank of the LoRA decomposition.
alpha (float, optional): Scaling factor for LoRA weights. Defaults to lora_dim.
dropout (float, optional): Dropout probability. Defaults to None.
rank_dropout (float, optional): Dropout probability for rank reduction. Defaults to None.
module_dropout (float, optional): Probability of completely dropping the module during training. Defaults to None.
"""
super().__init__()
self.lora_name = lora_name
self.multiplier = multiplier
self.lora_dim = lora_dim
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
# Determine layer type (Linear or Conv2D)
is_conv2d = isinstance(org_module, nn.Conv2d)
in_dim = org_module.in_channels if is_conv2d else org_module.in_features
out_dim = org_module.out_channels if is_conv2d else org_module.out_features
# Define LoRA layers
if is_conv2d:
self.lora_down = nn.Conv2d(in_dim, lora_dim, kernel_size=org_module.kernel_size,
stride=org_module.stride, padding=org_module.padding, bias=False)
self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
else:
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
# Initialize weights
nn.init.xavier_uniform_(self.lora_down.weight)
nn.init.zeros_(self.lora_up.weight)
        # Alpha rescales the LoRA update; the effective scale is alpha / lora_dim (1.0 by default)
        alpha = lora_dim if alpha is None else alpha
        self.scale = alpha / lora_dim
        # Register alpha as a buffer so it is saved and restored with the state dict
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))
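        # After apply_to(), the wrapped layer computes
        #   y = org_module(x) + multiplier * (alpha / lora_dim) * lora_up(lora_down(x))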
# Store reference to the original module
self.org_module = org_module
self.org_forward = org_module.forward

    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        self.org_module.forward = self.forward
        # Drop the reference so the wrapped module is not registered as a submodule of this one
        del self.org_module
def forward(self, x):
"""
Forward pass for LoRA-enhanced module.
"""
if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
return self.org_forward(x)
# Compute LoRA down projection
lora_output = self.lora_down(x)
        # Apply dropout only during training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                # Randomly zero activations and rescale to preserve the expected magnitude
                dropout_mask = (torch.rand_like(lora_output) > self.rank_dropout).to(lora_output.dtype)
                lora_output = lora_output * dropout_mask / (1.0 - self.rank_dropout)
# Compute LoRA up projection
lora_output = self.lora_up(lora_output)
# Combine with original output
return self.org_forward(x) + lora_output * self.multiplier * self.scale
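

# A minimal usage sketch (not part of the original file): it wraps a plain nn.Linear
# with LoRAModule and monkey-patches its forward via apply_to(). The layer sizes and
# the "example_linear" name are illustrative assumptions, not taken from any model.
if __name__ == "__main__":
    base = nn.Linear(64, 32)
    lora = LoRAModule("example_linear", base, multiplier=1.0, lora_dim=4, alpha=4.0)
    lora.apply_to()  # base.forward now routes through lora.forward

    x = torch.randn(2, 64)
    # lora_up is zero-initialized, so the LoRA branch contributes nothing before training
    out = base(x)
    print(out.shape)  # torch.Size([2, 32])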