import torch
import torch.nn as nn
class _MLPVectorProjector(nn.Module):
    """Projects a feature vector into `width` pseudo-tokens in the language
    model's hidden space, using one independent MLP per token."""

    def __init__(
        self,
        input_hidden_size: int = 512,
        lm_hidden_size: int = 2560,
        num_layers: int = 1,
        width: int = 4,
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer maps the input features to the LM hidden size.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            # Each extra layer adds a GELU followed by a hidden-size Linear.
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # x: (..., 1, input_hidden_size) -> (..., width, lm_hidden_size);
        # each MLP emits one pseudo-token, concatenated along the token axis.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
def build_mlp_vector_projector(
    input_hidden_size: int = 512,
    lm_hidden_size: int = 2560,
    num_layers: int = 1,
    num_tokens: int = 4,
):
    """Factory wrapper: `num_tokens` becomes the projector's `width`."""
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )
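

# A minimal usage sketch (an assumption, not part of the original file):
# the input is taken to be a single 512-dim embedding per example, shaped
# (batch, 1, input_hidden_size) so the concat in forward() lands on the
# token axis rather than the batch axis.
if __name__ == "__main__":
    projector = build_mlp_vector_projector(
        input_hidden_size=512, lm_hidden_size=2560, num_layers=2, num_tokens=4
    )
    x = torch.randn(1, 1, 512)  # (batch, 1, input_hidden_size)
    tokens = projector(x)       # (batch, num_tokens, lm_hidden_size)
    print(tokens.shape)         # torch.Size([1, 4, 2560])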