import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig


class _MLPVectorProjector(nn.Module):
    """Projects an encoder embedding into `width` language-model token embeddings,
    using one independent MLP per output token."""

    def __init__(
        self,
        input_hidden_size: int = 512,
        lm_hidden_size: int = 2560,
        num_layers: int = 1,
        width: int = 4,
    ):
        super(_MLPVectorProjector, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # The first layer maps the encoder embedding into the LM hidden size;
            # each additional layer adds GELU followed by an
            # lm_hidden_size -> lm_hidden_size linear layer.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Run every per-token MLP on the same input and concatenate the outputs
        # along the token dimension (dim=-2), yielding `width` projected tokens.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


def build_mlp_vector_projector(
    input_hidden_size: int = 512,
    lm_hidden_size: int = 2560,
    num_layers: int = 1,
    num_tokens: int = 4,
):
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )
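

# Minimal usage sketch (an assumption, not part of the original module): the shapes
# below follow the defaults above, i.e. a pooled 512-d encoder vector projected into
# 4 tokens in a 2560-d LM embedding space. The actual encoder and language model are
# not defined in this file, so a random tensor stands in for the encoder output.
if __name__ == "__main__":
    projector = build_mlp_vector_projector(
        input_hidden_size=512, lm_hidden_size=2560, num_layers=2, num_tokens=4
    )
    # One pooled encoder vector per batch item, kept as a length-1 token dimension
    # so the concatenation along dim=-2 produces one row per projector head.
    dummy_embedding = torch.randn(1, 1, 512)
    projected_tokens = projector(dummy_embedding)
    print(projected_tokens.shape)  # torch.Size([1, 4, 2560])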