File size: 3,846 Bytes
be190eb
 
125014f
 
be190eb
125014f
be190eb
125014f
be190eb
 
 
 
125014f
 
ce92feb
 
125014f
ce92feb
 
 
125014f
be190eb
125014f
 
 
 
 
 
 
 
 
be190eb
 
125014f
be190eb
 
 
 
 
 
125014f
 
 
 
be190eb
125014f
 
 
 
 
be190eb
125014f
 
be190eb
125014f
 
 
be190eb
125014f
 
 
be190eb
125014f
 
 
be190eb
125014f
 
 
 
be190eb
125014f
 
 
 
 
 
be190eb
125014f
 
be190eb
125014f
 
 
 
 
 
 
 
 
be190eb
125014f
 
be190eb
125014f
b3c8e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRAModule(nn.Module):
    """
    LoRA module that replaces the forward method of an original Linear or Conv2D module.

    After ``apply_to()`` the wrapped module computes::

        org_forward(x) + lora_up(lora_down(x)) * multiplier * scale

    where ``scale = alpha / lora_dim``. ``lora_up`` is zero-initialized, so a
    freshly created module leaves the original output unchanged.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: Optional[float] = None,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
    ):
        """
        Args:
            lora_name (str): Name of the LoRA module.
            org_module (nn.Module): The original module to wrap. It must be an
                nn.Conv2d, or expose ``in_features``/``out_features`` like nn.Linear.
            multiplier (float): Scaling factor for the LoRA output.
            lora_dim (int): The rank of the LoRA decomposition.
            alpha (float, optional): LoRA alpha. The effective scale is
                ``alpha / lora_dim``. Defaults to lora_dim (scale 1.0).
            dropout (float, optional): Element-wise dropout probability applied to
                the down-projection output during training. Defaults to None.
            rank_dropout (float, optional): Probability of dropping each rank
                channel (per sample) during training. Defaults to None.
            module_dropout (float, optional): Probability of skipping the LoRA
                branch entirely during training. Defaults to None.
        """
        super().__init__()
        self.lora_name = lora_name
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        # Determine layer type (Linear or Conv2D)
        is_conv2d = isinstance(org_module, nn.Conv2d)
        in_dim = org_module.in_channels if is_conv2d else org_module.in_features
        out_dim = org_module.out_channels if is_conv2d else org_module.out_features

        # Define LoRA layers
        if is_conv2d:
            # The down projection must reproduce the original conv's spatial
            # geometry (kernel, stride, padding AND dilation) so its output can
            # be added element-wise to org_forward(x).
            self.lora_down = nn.Conv2d(
                in_dim,
                lora_dim,
                kernel_size=org_module.kernel_size,
                stride=org_module.stride,
                padding=org_module.padding,
                dilation=org_module.dilation,
                bias=False,
            )
            self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
        else:
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        # Initialize weights. Zeroing lora_up makes the LoRA branch a no-op at start.
        nn.init.xavier_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)

        # Set alpha scaling factor. float() also accepts a 0-d tensor alpha.
        alpha = float(lora_dim if alpha is None else alpha)
        self.scale = alpha / lora_dim
        # Store alpha itself (not the derived scale) so that the scale can be
        # recomputed as alpha / lora_dim from a saved state dict.
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))

        # Store reference to the original module
        self.org_module = org_module
        self.org_forward = org_module.forward

    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        # org_forward was bound before this reassignment, so calling it from
        # self.forward does not recurse.
        self.org_module.forward = self.forward
        del self.org_module

    def _rank_dropout_mask(self, t: torch.Tensor) -> torch.Tensor:
        """Return a boolean keep-mask that drops whole rank channels of ``t`` per sample."""
        # Rank channels sit at dim -3 for conv ((N, C, H, W) or unbatched (C, H, W))
        # and at the last dim for linear ((..., lora_dim)).
        if isinstance(self.lora_down, nn.Conv2d):
            rank_axis = t.dim() - 3
        else:
            rank_axis = t.dim() - 1
        shape = [1] * t.dim()
        shape[rank_axis] = t.shape[rank_axis]
        if rank_axis > 0:
            # Draw an independent mask for each entry of the leading (batch) dim.
            shape[0] = t.shape[0]
        return torch.rand(shape, device=t.device) > self.rank_dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for LoRA-enhanced module.

        During training the LoRA branch may be skipped entirely (module_dropout),
        and its rank-space activations may be dropped element-wise (dropout) or
        per rank channel (rank_dropout). In eval mode none of the dropouts apply.
        """
        if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
            return self.org_forward(x)

        # Compute LoRA down projection
        lora_output = self.lora_down(x)

        # Apply dropout if training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                keep = self._rank_dropout_mask(lora_output)
                # Rescale kept channels so the expected output is unchanged.
                lora_output = lora_output * keep.to(lora_output.dtype) / (1.0 - self.rank_dropout)

        # Compute LoRA up projection
        lora_output = self.lora_up(lora_output)

        # Combine with original output
        return self.org_forward(x) + lora_output * self.multiplier * self.scale