import numpy as np
import torch
from torch import nn

from . import utils


class CustomModule(nn.Module):
    """A simple two-layer type-I MLP block built from externally supplied weights.
    """
    def __init__(self, w1_weight=None, w1_bias=None, w2_weight=None, act='gelu'):
        super().__init__()

        # w1_weight has shape (hidden, model): linear1 maps model -> hidden,
        # linear2 maps hidden -> model.
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)

        # Load the supplied parameters; w2_weight is stored transposed and the
        # second projection's bias is fixed to zero.
        self.linear1.weight = nn.Parameter(w1_weight.float())
        self.linear1.bias = nn.Parameter(w1_bias.float())
        self.linear2.weight = nn.Parameter(w2_weight.T.float())
        self.linear2.bias = nn.Parameter(torch.zeros_like(self.linear2.bias))

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))


class CustomNormModule(nn.Module):
    """A two-layer type-I MLP that normalises and re-centres its input before
    the first projection, and can optionally return the post-activation hidden
    representation instead of the final output.
    """
    def __init__(self,
            w1_weight=None,
            w1_bias=None,
            w2_weight=None,
            centroid=None,
            norm_weight=None,
            norm_bias=None,
            add_norm=True,
            return_w1=False,
            act='relu'
        ):
        super().__init__()

        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)

        # Statistics used to undo the preceding normalisation and re-centre x.
        self.centroid = centroid
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        if self.norm_bias is None:
            self.norm_bias = 0
        self.add_norm = add_norm

        self.return_w1 = return_w1

        self.linear1.weight = nn.Parameter(w1_weight)
        if w1_bias is not None:
            self.linear1.bias = nn.Parameter(w1_bias)
        self.linear2.weight = nn.Parameter(w2_weight.T)
        self.linear2.bias = nn.Parameter(
            torch.zeros_like(self.linear2.bias).to(w1_weight.dtype).to(w1_weight.device)
        )

    def forward(self, x):
        # Undo the preceding normalisation (scale and shift), then remove the centroid.
        x = (x - self.norm_bias) / self.norm_weight / np.sqrt(self.centroid.shape[0])
        x = x - self.centroid

        # Optionally project onto the unit sphere.
        if self.add_norm:
            x = x / torch.norm(x, dim=-1, keepdim=True)

        w1_output = self.act(self.linear1(x))

        # Optionally expose the post-activation hidden representation.
        if self.return_w1:
            return w1_output

        w2_output = self.linear2(w1_output)
        return w2_output


class ModifiedMLP(nn.Module):
    """Modified MLP structure: the original MLP plus a residual CustomModule branch.
    """
    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x):
        # Run the original MLP on x ...
        o = self.original_mlp(x)
        # ... and add the CustomModule's output on the same input as a residual.
        return o + self.custom_module(x)


class ModifieMambadMLP(nn.Module):
    """Modified MLP structure for Mamba blocks (forwards cache_params to the original MLP).
    """
    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x, cache_params=None):
        # Run the original (Mamba) MLP on x ...
        o = self.original_mlp(x, cache_params=cache_params)
        # ... and add the CustomModule's output on the same input as a residual.
        return o + self.custom_module(x)
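

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library API).
# The shapes, the random stand-in weights, and the nn.Sequential used as the
# "original" MLP are assumptions for demonstration; in practice the weights
# come from the fitted/edited parameters built elsewhere in the pipeline.
# Because of the relative import above, run with `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d_model, d_hidden = 16, 64

    # Toy MLP standing in for the pretrained block that gets wrapped.
    original_mlp = nn.Sequential(
        nn.Linear(d_model, d_hidden),
        nn.GELU(),
        nn.Linear(d_hidden, d_model),
    )

    # Random stand-ins for the extracted weights.
    w1_weight = torch.randn(d_hidden, d_model)   # first projection, (hidden, model)
    w1_bias = torch.randn(d_hidden)              # bias of the first projection
    w2_weight = torch.randn(d_hidden, d_model)   # second projection, transposed inside

    custom = CustomModule(w1_weight=w1_weight, w1_bias=w1_bias,
                          w2_weight=w2_weight, act='gelu')
    wrapped = ModifiedMLP(original_mlp, custom)

    x = torch.randn(2, 5, d_model)               # (batch, seq, model)
    print(wrapped(x).shape)                      # torch.Size([2, 5, 16])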