import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAModule(nn.Module):
    """
    LoRA module that replaces the forward method of an original Linear or Conv2d module.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: Optional[float] = None,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
    ):
        """
        Args:
            lora_name (str): Name of the LoRA module.
            org_module (nn.Module): The original module to wrap.
            multiplier (float): Scaling factor for the LoRA output.
            lora_dim (int): The rank of the LoRA decomposition.
            alpha (float, optional): Scaling factor for the LoRA weights; the effective scale is alpha / lora_dim. Defaults to lora_dim.
            dropout (float, optional): Dropout probability applied to the down-projection output. Defaults to None.
            rank_dropout (float, optional): Probability of dropping elements of the low-rank activation. Defaults to None.
            module_dropout (float, optional): Probability of skipping the LoRA path entirely during training. Defaults to None.
        """
        super().__init__()
        self.lora_name = lora_name
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        # Determine layer type (Linear or Conv2d)
        is_conv2d = isinstance(org_module, nn.Conv2d)
        in_dim = org_module.in_channels if is_conv2d else org_module.in_features
        out_dim = org_module.out_channels if is_conv2d else org_module.out_features

        # Define LoRA layers: a down projection to rank lora_dim followed by an up projection
        if is_conv2d:
            self.lora_down = nn.Conv2d(
                in_dim, lora_dim, kernel_size=org_module.kernel_size,
                stride=org_module.stride, padding=org_module.padding, bias=False,
            )
            self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
        else:
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        # Initialize weights: zero-init the up projection so the LoRA path starts as a no-op
        nn.init.xavier_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)
        # Set the alpha scaling factor (effective scale = alpha / lora_dim)
        alpha = lora_dim if alpha is None else alpha
        self.scale = alpha / lora_dim
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))
        # Store reference to the original module and its forward method
        self.org_module = org_module
        self.org_forward = org_module.forward

    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        self.org_module.forward = self.forward
        del self.org_module
    def forward(self, x):
        """
        Forward pass for the LoRA-enhanced module.
        """
        # Module dropout: occasionally skip the LoRA path entirely during training
        if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
            return self.org_forward(x)

        # Compute LoRA down projection
        lora_output = self.lora_down(x)

        # Apply dropout only during training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                # Randomly zero elements of the low-rank activation, then rescale
                # by 1 / (1 - p) to keep the expected magnitude unchanged
                dropout_mask = torch.rand_like(lora_output) > self.rank_dropout
                lora_output = lora_output * dropout_mask
                lora_output = lora_output / (1.0 - self.rank_dropout)

        # Compute LoRA up projection
        lora_output = self.lora_up(lora_output)

        # Combine with the original module's output, scaled by multiplier and alpha / lora_dim
        return self.org_forward(x) + lora_output * self.multiplier * self.scale
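

# Usage sketch (illustrative, not part of the original file): wrap a single nn.Linear
# layer with a LoRAModule and route its forward pass through the LoRA path. The layer
# sizes and the "example_linear" name below are assumptions chosen for this example.
if __name__ == "__main__":
    base_layer = nn.Linear(64, 128)
    lora = LoRAModule(
        lora_name="example_linear",
        org_module=base_layer,
        multiplier=1.0,
        lora_dim=4,
        alpha=4.0,
    )
    lora.apply_to()  # base_layer.forward now dispatches to LoRAModule.forward

    x = torch.randn(2, 64)
    y = base_layer(x)  # original output + LoRA delta (zero at init, since lora_up is zero-initialized)
    print(y.shape)  # torch.Size([2, 128])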