import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
# Download the model config and trained weights from the Hugging Face Hub.
config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_mlp",
    "torch_mlp_config.yaml")

weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_mlp",
    "mlp_weights.pt")

with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

# stoi maps characters to indices; itos is the inverse mapping.
stoi = config['stoi']
itos = {i: s for s, i in stoi.items()}
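# For reference, stoi is expected to look roughly like the following
# (the exact characters are hypothetical; only index 0 acting as the
# terminating/padding character is assumed by the code below):
#   {'.': 0, 'a': 1, 'b': 2, ...}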
class MLP(nn.Module):
    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super(MLP, self).__init__()

        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Character embedding table.
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1)

        self.first = nn.Linear(embeddings * window, hidden_nodes)

        # Stack of Linear -> BatchNorm -> Tanh blocks. Built from a plain list
        # rather than nn.Sequential.extend(), which only exists in recent
        # PyTorch releases; the resulting state_dict keys are unchanged.
        hidden_layers = []
        for _ in range(num_layers):
            hidden_layers += [
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ]
        self.layers = nn.Sequential(*hidden_layers)

        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        # Look up embeddings for the context window and flatten them.
        x = self.C[x]
        x = self.first(x.view(-1, self.window * self.embeddings))
        x = self.layers(x)
        x = self.final(x)
        return x

    def sample_char(self, x):
        # Sample the next character index from the predicted distribution.
        with torch.no_grad():
            logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()
mlp = MLP(config['num_char'],
          config['hidden_nodes'],
          config['embeddings'],
          config['window'],
          config['num_layers'])

# map_location keeps loading working on CPU-only hardware.
mlp.load_state_dict(torch.load(weights_path, map_location="cpu"))
mlp.eval()
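# Optional sanity check (not part of the original app): sampling one character
# from an all-padding context should return a valid index into itos.
# assert mlp.sample_char([0] * config['window']) in itos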
def generate_names(name_start, number_of_names):
    names = ""
    for _ in range(int(number_of_names)):
        # Initialize the name and context window with the user-provided prefix.
        name = ""
        context = [0] * config['window']
        for c in name_start.lower():
            name += c
            context = context[1:] + [stoi[c]]

        # Run inference until the terminating character (index 0) is sampled.
        while True:
            ix = mlp.sample_char(context)
            context = context[1:] + [ix]
            name += itos[ix]

            if ix == 0:
                break

        names += name + "\n"

    return names
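# Example of calling the generator directly (hypothetical prefix, shown for
# local testing without the Gradio UI):
#
#   print(generate_names("ja", 3))
#
# Each generated name ends with the character that itos maps to index 0.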
app = gr.Interface(
    fn=generate_names,
    inputs=[
        gr.Textbox(placeholder="Start name with..."),
        gr.Number(value=1),
    ],
    outputs="text",
)

app.launch()