"""Gradio demo that completes surnames with a character-level MLP.

The model config and weights are downloaded from the Hugging Face Hub
repository ``jefsnacker/surname_mlp``; the user supplies a prefix and a count
and the app samples that many names starting with the prefix.
"""
import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

# Hugging Face Hub repository holding the trained model artefacts.
REPO_ID = "jefsnacker/surname_mlp"
# Upper bound on characters sampled per name, so a model that never emits the
# terminator (index 0) cannot hang a request.
MAX_GENERATED_CHARS = 50

config_path = huggingface_hub.hf_hub_download(REPO_ID, "torch_mlp_config.yaml")
weights_path = huggingface_hub.hf_hub_download(REPO_ID, "mlp_weights.pt")

with open(config_path, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# stoi: character -> index (from the config); itos: index -> character.
stoi = config["stoi"]
itos = {index: char for char, index in stoi.items()}


class MLP(nn.Module):
    """Character-level MLP that scores the next character from a fixed window.

    Each of the ``window`` context indices is looked up in the embedding table
    ``C``; the embeddings are flattened, projected to ``hidden_nodes`` units,
    passed through ``num_layers`` blocks of Linear(no bias) -> BatchNorm1d ->
    Tanh, and finally mapped to ``num_char`` logits.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super().__init__()
        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings
        # Embedding table of shape (num_char, embeddings); values are replaced
        # when a state dict is loaded.
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1)
        self.first = nn.Linear(embeddings * window, hidden_nodes)
        # One flat Sequential, so parameters are named layers.0, layers.1, ...
        hidden = []
        for _ in range(num_layers):
            hidden.extend([
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ])
        self.layers = nn.Sequential(*hidden)
        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        """Return next-character logits with one row per example.

        ``x`` holds character indices used to index ``C``; the result is
        reshaped to (-1, window * embeddings), so each example is expected to
        contribute exactly ``window`` indices.
        """
        emb = self.C[x]
        h = self.first(emb.view(-1, self.window * self.embeddings))
        h = self.layers(h)
        return self.final(h)

    @torch.no_grad()  # inference only: avoid building an autograd graph
    def sample_char(self, x):
        """Sample one next-character index from the softmax over the logits.

        ``.item()`` requires a single example, so ``x`` should be one window.
        """
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()


mlp = MLP(
    config["num_char"],
    config["hidden_nodes"],
    config["embeddings"],
    config["window"],
    config["num_layers"],
)
# map_location lets weights saved from a GPU session load on CPU-only hosts.
# NOTE: torch.load unpickles the file; only load from a trusted repository.
mlp.load_state_dict(torch.load(weights_path, map_location="cpu"))
mlp.eval()


def _sample_name(prefix):
    """Complete a single name that starts with *prefix*.

    *prefix* must already be lower-cased and contain only characters in
    ``stoi``. Sampling stops after the terminator (index 0) is drawn — its
    character is kept in the returned name — or after MAX_GENERATED_CHARS.
    """
    # Index 0 doubles as start-of-name padding for the initial context.
    context = [0] * config["window"]
    for char in prefix:
        context = context[1:] + [stoi[char]]

    pieces = [prefix]
    for _ in range(MAX_GENERATED_CHARS):
        ix = mlp.sample_char(context)
        context = context[1:] + [ix]
        pieces.append(itos[ix])
        if ix == 0:
            break
    return "".join(pieces)


def generate_names(name_start, number_of_names):
    """Generate ``number_of_names`` names beginning with ``name_start``.

    Args:
        name_start: Prefix typed by the user (lower-cased before use);
            ``None`` is treated as an empty prefix.
        number_of_names: How many names to sample; converted with ``int()``,
            ``None`` counts as 0.

    Returns:
        The generated names, each followed by a newline.

    Raises:
        gr.Error: If the prefix contains characters the model does not know.
    """
    prefix = (name_start or "").lower()
    unknown = sorted({char for char in prefix if char not in stoi})
    if unknown:
        raise gr.Error(f"Unsupported character(s) in prefix: {''.join(unknown)!r}")

    count = int(number_of_names or 0)
    return "".join(_sample_name(prefix) + "\n" for _ in range(count))


app = gr.Interface(
    fn=generate_names,
    inputs=[
        gr.Textbox(placeholder="Start name with..."),
        gr.Number(value=1),
    ],
    outputs="text",
)
app.launch()