# AnyModeAssistant / model.py
# Uploaded by venkat-natchi
# Commit f315cdb ("Upload 8 files", verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig
class _MLPVectorProjector(nn.Module):
    """Map one input feature tensor onto ``width`` parallel bias-free MLP outputs.

    The module holds ``width`` independent branches in ``self.mlps``. Each branch
    is an ``nn.Sequential`` that begins with
    ``Linear(input_hidden_size -> lm_hidden_size)`` and is followed by
    ``num_layers - 1`` pairs of ``GELU`` + ``Linear(lm_hidden_size -> lm_hidden_size)``.
    All Linear layers are created with ``bias=False``.

    Args:
        input_hidden_size: Size of the last dimension of the input.
        lm_hidden_size: Output size of every Linear layer.
        num_layers: Number of Linear layers per branch (values < 1 still
            produce the single leading Linear).
        width: Number of parallel branches.
    """

    def __init__(
        self,
        input_hidden_size: int = 512,
        lm_hidden_size: int = 2560,
        num_layers: int = 1,
        width: int = 4
    ):
        super().__init__()

        def make_branch() -> nn.Sequential:
            # The interleaving Linear, GELU, Linear, ... fixes the Sequential
            # indices (0, 2, 4, ... for Linears), and therefore the state_dict
            # key names that checkpoints rely on.
            layers = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(num_layers - 1):
                layers += [
                    nn.GELU(),
                    nn.Linear(lm_hidden_size, lm_hidden_size, bias=False),
                ]
            return nn.Sequential(*layers)

        # Branches are built one after another, so parameter initialization
        # draws from the RNG in the same order as before.
        self.mlps = nn.ModuleList([make_branch() for _ in range(width)])

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim -2.

        Each branch output has shape ``x.shape[:-1] + (lm_hidden_size,)``. After
        concatenation, dim -2 is ``len(self.mlps)`` times its size in ``x``.
        NOTE(review): presumably callers pass x as (..., 1, input_hidden_size),
        giving (..., width, lm_hidden_size). A 2-D x would be concatenated along
        its first (batch) axis. Confirm with callers.
        """
        branch_outputs = [branch(x) for branch in self.mlps]
        return torch.cat(branch_outputs, dim=-2)
def build_mlp_vector_projector(
    input_hidden_size: int = 512,
    lm_hidden_size: int = 2560,
    num_layers: int = 1,
    num_tokens: int = 4
):
    """Create an ``_MLPVectorProjector`` with the given sizes.

    ``num_tokens`` is passed through as the projector's ``width``, i.e. the
    number of parallel MLP branches whose outputs are concatenated.
    """
    return _MLPVectorProjector(
        input_hidden_size=input_hidden_size,
        lm_hidden_size=lm_hidden_size,
        num_layers=num_layers,
        width=num_tokens,
    )