import torch
import torch.nn as nn

class _MLPVectorProjector(nn.Module):
    """Projects a feature vector into `width` LM-space token embeddings.

    Runs `width` parallel MLPs, each mapping `input_hidden_size` features to
    one `lm_hidden_size` embedding, and concatenates their outputs along the
    token axis.
    """

    def __init__(
        self,
        input_hidden_size: int = 512,
        lm_hidden_size: int = 2560,
        num_layers: int = 1,
        width: int = 4,
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer projects encoder features into the LM hidden size.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            # Any additional layers operate within the LM hidden size.
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Concatenate the per-MLP outputs along the token axis: with input of
        # shape (batch, 1, input_hidden_size) this yields
        # (batch, width, lm_hidden_size).
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


def build_mlp_vector_projector(
    input_hidden_size: int = 512,
    lm_hidden_size: int = 2560,
    num_layers: int = 1,
    num_tokens: int = 4,
):
    """Builds a projector that emits `num_tokens` LM token embeddings."""
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )
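

# Usage sketch (illustrative; the batch size and shapes below are assumptions,
# not part of the original module): project one pooled encoder vector per
# example into `num_tokens` embeddings ready to splice into an LM input
# sequence.
if __name__ == "__main__":
    projector = build_mlp_vector_projector(
        input_hidden_size=512, lm_hidden_size=2560, num_layers=2, num_tokens=4
    )
    features = torch.randn(8, 1, 512)  # (batch, 1, input_hidden_size)
    tokens = projector(features)       # (batch, num_tokens, lm_hidden_size)
    assert tokens.shape == (8, 4, 2560)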