import torch
import torch.nn as nn
import torch.nn.init as init
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)
class Adapter(nn.Module):
    """Bottleneck-style adapter: dropout -> up-projection -> Swish -> down-projection -> norm."""

    def __init__(self, input_size, output_size, adapter_norm="layer_norm",
                 init_type="glorot", query_length=32, dropout_prob=0.1):
        super().__init__()
        self.query_length = query_length
        self.dropout_prob = dropout_prob
        self.adapter_norm = adapter_norm

        self.dropout = nn.Dropout(p=self.dropout_prob)
        self.c_fc = nn.Linear(input_size, input_size * 2)
        self.act = Swish()
        self.c_proj = nn.Linear(input_size * 2, output_size)

        # Normalize over (query_length, output_size) for layer norm, or over the
        # query dimension for batch norm. Reject anything else up front so that
        # forward() cannot fail later with a missing self.norm attribute.
        if adapter_norm == "layer_norm":
            self.norm = nn.LayerNorm([self.query_length, output_size])
        elif adapter_norm == "batch_norm":
            self.norm = nn.BatchNorm1d(self.query_length)
        else:
            raise ValueError(f"Invalid adapter_norm: {adapter_norm!r}")

        self.init_type = init_type.lower()
        self._initialize_weights()
    def forward(self, hidden_states):
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.c_fc(hidden_states)    # expand: input_size -> 2 * input_size
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)  # project: 2 * input_size -> output_size
        hidden_states = self.norm(hidden_states)
        return hidden_states
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if self.init_type == "glorot":
                    init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif self.init_type == "normal":
                    init.normal_(m.weight, mean=0, std=0.01)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                else:
                    raise ValueError("Invalid initialization type specified.")
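

# A minimal usage sketch (not part of the original file): the shapes below
# assume inputs of shape (batch, query_length, input_size), which matches how
# LayerNorm/BatchNorm1d are constructed over query_length above.
if __name__ == "__main__":
    adapter = Adapter(input_size=768, output_size=512, adapter_norm="layer_norm",
                      init_type="glorot", query_length=32, dropout_prob=0.1)
    x = torch.randn(4, 32, 768)  # (batch, query_length, input_size)
    out = adapter(x)
    print(out.shape)  # torch.Size([4, 32, 512])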