import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAModule(nn.Module):
    """
    LoRA module that replaces the forward method of an original Linear or Conv2d module.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: Optional[float] = None,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
    ):
        """
        Args:
            lora_name (str): Name of the LoRA module.
            org_module (nn.Module): The original module to wrap.
            multiplier (float): Scaling factor for the LoRA output.
            lora_dim (int): The rank of the LoRA decomposition.
            alpha (float, optional): Scaling factor for LoRA weights. Defaults to lora_dim.
            dropout (float, optional): Dropout probability applied to the down-projected activations. Defaults to None.
            rank_dropout (float, optional): Dropout probability applied across the rank dimension. Defaults to None.
            module_dropout (float, optional): Probability of skipping the LoRA branch entirely during training. Defaults to None.
        """
        super().__init__()
        self.lora_name = lora_name
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        # Determine layer type (Linear or Conv2d)
        is_conv2d = isinstance(org_module, nn.Conv2d)
        in_dim = org_module.in_channels if is_conv2d else org_module.in_features
        out_dim = org_module.out_channels if is_conv2d else org_module.out_features

        # Define LoRA layers: a low-rank down projection followed by an up projection
        if is_conv2d:
            self.lora_down = nn.Conv2d(
                in_dim,
                lora_dim,
                kernel_size=org_module.kernel_size,
                stride=org_module.stride,
                padding=org_module.padding,
                bias=False,
            )
            self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
        else:
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        # Initialize weights: the up projection starts at zero, so the LoRA branch
        # contributes nothing at initialization and training starts from the original model.
        nn.init.xavier_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)

        # Set alpha scaling factor; store alpha as a buffer so it is saved with the state dict
        alpha = float(alpha) if alpha is not None else float(lora_dim)
        self.scale = alpha / lora_dim
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))

        # Store references to the original module and its forward method
        self.org_module = org_module
        self.org_forward = org_module.forward

    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        """Forward pass for the LoRA-enhanced module."""
        # Module dropout: skip the LoRA branch entirely with the given probability during training
        if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
            return self.org_forward(x)

        # Compute LoRA down projection
        lora_output = self.lora_down(x)

        # Apply dropout only during training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                # Randomly zero entries of the low-rank activation and rescale
                # the survivors to keep the expected magnitude unchanged.
                dropout_mask = (torch.rand_like(lora_output) > self.rank_dropout).to(lora_output.dtype)
                lora_output = lora_output * dropout_mask
                lora_output = lora_output / (1.0 - self.rank_dropout)

        # Compute LoRA up projection
        lora_output = self.lora_up(lora_output)

        # Combine with original output
        return self.org_forward(x) + lora_output * self.multiplier * self.scale
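

# --- Usage sketch ---
# A minimal, illustrative example of wrapping a single nn.Linear layer with
# LoRAModule and routing calls through the LoRA branch via apply_to(). The toy
# layer, name, and dimensions below are assumptions for demonstration, not part
# of the module above.
if __name__ == "__main__":
    base = nn.Linear(64, 64)
    lora = LoRAModule("lora_demo_linear", base, multiplier=1.0, lora_dim=4, alpha=4.0)
    lora.apply_to()  # base.forward now computes original output + LoRA branch

    x = torch.randn(2, 64)
    y = base(x)  # matches the original layer's output at init, since lora_up starts at zero
    print(y.shape)  # torch.Size([2, 64])

    # Typically only the LoRA parameters are trained; the base layer stays frozen.
    trainable = [p for n, p in lora.named_parameters() if "lora_" in n]
    print(sum(p.numel() for p in trainable), "trainable LoRA parameters")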