import torch
import numpy as np
from torch import nn

from . import utils


class CustomModule(nn.Module):
    """A simple two-layer, type-I MLP structure built from supplied weights."""

    def __init__(self, w1_weight=None, w2_bias=None, w2_weight=None, act='gelu'):
        super().__init__()
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        # Load the supplied weights; the provided bias is applied after linear1,
        # and linear2's bias is fixed at zero.
        self.linear1.weight = nn.Parameter(w1_weight.float())
        self.linear1.bias = nn.Parameter(w2_bias.float())
        self.linear2.weight = nn.Parameter(w2_weight.T.float())
        self.linear2.bias = nn.Parameter(torch.zeros_like(self.linear2.bias))

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))
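

# Illustrative usage sketch: in practice the weight/bias tensors come from a
# trained checkpoint; the random tensors and dimensions below are placeholders
# chosen only to show the expected shapes.
def _example_custom_module():
    d_model, d_hidden = 768, 3072
    w1 = torch.randn(d_hidden, d_model)   # linear1 maps d_model -> d_hidden
    b1 = torch.randn(d_hidden)            # bias applied after linear1
    w2 = torch.randn(d_hidden, d_model)   # transposed internally: d_hidden -> d_model
    mlp = CustomModule(w1_weight=w1, w2_bias=b1, w2_weight=w2, act='gelu')
    return mlp(torch.randn(1, 16, d_model))   # -> (1, 16, d_model)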


class CustomNormModule(nn.Module):
    """A simple two-layer, type-I MLP structure with input normalisation."""

    def __init__(self,
                 w1_weight=None,
                 w1_bias=None,
                 w2_weight=None,
                 centroid=None,
                 norm_weight=None,
                 norm_bias=None,
                 add_norm=True,
                 return_w1=False,
                 act='relu'):
        super().__init__()
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        self.centroid = centroid
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias if norm_bias is not None else 0
        self.add_norm = add_norm
        self.return_w1 = return_w1
        self.linear1.weight = nn.Parameter(w1_weight)
        if w1_bias is not None:
            self.linear1.bias = nn.Parameter(w1_bias)
        self.linear2.weight = nn.Parameter(w2_weight.T)
        self.linear2.bias = nn.Parameter(
            torch.zeros_like(self.linear2.bias).to(w1_weight.dtype).cuda())

    def forward(self, x):
        # Normalisation (part I): rescale by the stored norm parameters and
        # the square root of the hidden dimension.
        x = (x - self.norm_bias) / self.norm_weight / np.sqrt(self.centroid.shape[0])
        # Centre on the stored centroid and, optionally, project to unit norm.
        x = x - self.centroid
        if self.add_norm:
            x = x / torch.norm(x, dim=-1)[:, :, None]
        w1_output = self.act(self.linear1(x))
        if self.return_w1:
            return w1_output
        return self.linear2(w1_output)
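

# Illustrative usage sketch: centroid, norm_weight and norm_bias are stored as
# plain tensors (not buffers), so they must already live on the target device;
# the placeholder tensors and dimensions below only show the expected shapes,
# and the input is assumed to be (batch, seq, d_model). A GPU is required
# because linear2's bias is placed on CUDA at construction.
def _example_custom_norm_module():
    d_model, d_hidden = 768, 3072
    mod = CustomNormModule(
        w1_weight=torch.randn(d_hidden, d_model),
        w2_weight=torch.randn(d_hidden, d_model),
        centroid=torch.zeros(d_model).cuda(),
        norm_weight=torch.ones(d_model).cuda(),
        norm_bias=torch.zeros(d_model).cuda(),
        act='relu',
    ).cuda()
    return mod(torch.randn(1, 16, d_model).cuda())   # -> (1, 16, d_model)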


class ModifiedMLP(nn.Module):
    """Modified MLP structure: the original MLP plus a custom residual module."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x):
        # Get the output from the original MLP.
        o = self.original_mlp(x)
        # Add the CustomModule's output as a residual correction.
        return o + self.custom_module(x)
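

# Illustrative patching sketch: the wrapper replaces a block's MLP so that the
# custom module's output is added as a residual correction. The
# model.transformer.h[...] layout below is an assumption (GPT-2-style); the
# actual attribute path depends on the host model.
def _example_patch_mlp(model, custom_module, layer_idx=0):
    block = model.transformer.h[layer_idx]
    block.mlp = ModifiedMLP(block.mlp, custom_module)
    return model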


class ModifiedMambaMLP(nn.Module):
    """Modified Mamba block: the original forward (with cache) plus a custom residual module."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x, cache_params=None):
        # Get the output from the original block, forwarding the cache state.
        o = self.original_mlp(x, cache_params=cache_params)
        # Add the CustomModule's output as a residual correction.
        return o + self.custom_module(x)
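

# Illustrative patching sketch for Mamba-style models: the wrapped block also
# receives cache_params, which is passed through unchanged. The
# model.backbone.layers[...].mixer layout is an assumption; the actual
# attribute path depends on the host model.
def _example_patch_mamba(model, custom_module, layer_idx=0):
    layer = model.backbone.layers[layer_idx]
    layer.mixer = ModifiedMambaMLP(layer.mixer, custom_module)
    return model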