import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRAModule(nn.Module):
    """
    LoRA module that replaces the forward method of an original Linear or Conv2d module.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: Optional[float] = None,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
    ):
"""
Args:
lora_name (str): Name of the LoRA module.
org_module (nn.Module): The original module to wrap.
multiplier (float): Scaling factor for the LoRA output.
lora_dim (int): The rank of the LoRA decomposition.
alpha (float, optional): Scaling factor for LoRA weights. Defaults to lora_dim.
dropout (float, optional): Dropout probability. Defaults to None.
rank_dropout (float, optional): Dropout probability for rank reduction. Defaults to None.
module_dropout (float, optional): Probability of completely dropping the module during training. Defaults to None.
"""
        super().__init__()
        self.lora_name = lora_name
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        # Determine layer type (Linear or Conv2d)
        is_conv2d = isinstance(org_module, nn.Conv2d)
        in_dim = org_module.in_channels if is_conv2d else org_module.in_features
        out_dim = org_module.out_channels if is_conv2d else org_module.out_features
        # Define LoRA layers
        if is_conv2d:
            self.lora_down = nn.Conv2d(
                in_dim, lora_dim, kernel_size=org_module.kernel_size,
                stride=org_module.stride, padding=org_module.padding, bias=False,
            )
            self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
        else:
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
        # Initialize weights: lora_up starts at zero so the LoRA branch is a no-op before training
        nn.init.xavier_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)

        # Set alpha scaling factor (scale = alpha / lora_dim)
        self.scale = (alpha if alpha is not None else lora_dim) / lora_dim
        self.register_buffer("alpha", torch.tensor(self.scale, dtype=torch.float32))

        # Store references to the original module and its forward method
        self.org_module = org_module
        self.org_forward = org_module.forward

    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        """
        Forward pass for the LoRA-enhanced module:
        org_forward(x) + lora_up(lora_down(x)) * multiplier * scale.
        """
        # Module dropout: randomly skip the LoRA branch entirely during training
        if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
            return self.org_forward(x)

        # Compute LoRA down projection
        lora_output = self.lora_down(x)

        # Apply dropout only during training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                # Zero random entries of the low-rank activations, then rescale to preserve the expected magnitude
                dropout_mask = torch.rand_like(lora_output) > self.rank_dropout
                lora_output = lora_output * dropout_mask
                lora_output = lora_output / (1.0 - self.rank_dropout)

        # Compute LoRA up projection
        lora_output = self.lora_up(lora_output)

        # Combine with the original module's output
        return self.org_forward(x) + lora_output * self.multiplier * self.scale
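

# --- Minimal usage sketch (illustrative addition, not part of the original file) ---
# Wraps a plain nn.Linear with LoRAModule, redirects its forward via apply_to(), and
# checks that the wrapped layer still matches the original output: lora_up is
# zero-initialized, so the LoRA branch contributes nothing until it is trained.
if __name__ == "__main__":
    base = nn.Linear(16, 32)
    x = torch.randn(4, 16)
    original_out = base(x)

    lora = LoRAModule("lora_demo", base, multiplier=1.0, lora_dim=4, alpha=4.0)
    lora.apply_to()  # base.forward now routes through LoRAModule.forward

    wrapped_out = base(x)
    print(torch.allclose(original_out, wrapped_out))  # expected: True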