import torch
import torch.nn as nn
import torch.nn.init as init


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class Adapter(nn.Module):
    """Adapter block: dropout -> expand (2x) -> Swish -> project -> normalize."""

    def __init__(self, input_size, output_size, adapter_norm="layer_norm",
                 init_type="glorot", query_length=32, dropout_prob=0.1):
        super().__init__()
        self.query_length = query_length
        self.dropout_prob = dropout_prob
        self.adapter_norm = adapter_norm

        self.dropout = nn.Dropout(p=self.dropout_prob)
        # Expand to twice the input width, apply Swish, then project to the output width.
        self.c_fc = nn.Linear(input_size, input_size * 2)
        self.act = Swish()
        self.c_proj = nn.Linear(input_size * 2, output_size)

        # Normalization over the (query_length, output_size) activations.
        if adapter_norm == "layer_norm":
            self.norm = nn.LayerNorm([self.query_length, output_size])
        elif adapter_norm == "batch_norm":
            self.norm = nn.BatchNorm1d(self.query_length)
        else:
            raise ValueError(f"Invalid adapter_norm specified: {adapter_norm!r}")

        self.init_type = init_type.lower()
        self._initialize_weights()

    def forward(self, hidden_states):
        # Expected input shape: (batch, query_length, input_size).
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _initialize_weights(self):
        # Initialize all linear layers according to the requested scheme.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if self.init_type == "glorot":
                    init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif self.init_type == "normal":
                    init.normal_(m.weight, mean=0, std=0.01)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                else:
                    raise ValueError("Invalid initialization type specified.")
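

# Minimal usage sketch (illustrative only; the batch size and the 64-wide
# feature dimension below are assumptions, not part of the module).
if __name__ == "__main__":
    # A batch of 4 sequences, each with query_length=32 tokens of width 64.
    adapter = Adapter(input_size=64, output_size=64)
    x = torch.randn(4, 32, 64)
    out = adapter(x)
    print(out.shape)  # torch.Size([4, 32, 64])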