surnamerator / app.py
jefsnacker's picture
fix imports
78ed328
raw
history blame
2.63 kB
import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
# Fetch the model's YAML hyperparameter file and its trained weights from the
# Hugging Face Hub; each call yields a path that is opened/loaded below.
config_path = huggingface_hub.hf_hub_download(
    repo_id="jefsnacker/surname_mlp",
    filename="torch_mlp_config.yaml",
)
weights_path = huggingface_hub.hf_hub_download(
    repo_id="jefsnacker/surname_mlp",
    filename="mlp_weights.pt",
)

# Read the config with an explicit encoding so parsing does not depend on the
# platform's default locale encoding.
with open(config_path, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# stoi: character -> integer index (as stored in the config).
# itos: the inverse table, integer index -> character, used to decode samples.
stoi = config['stoi']
itos = {index: char for char, index in stoi.items()}
class MLP(nn.Module):
    """Character-level MLP that predicts the next character from a fixed window.

    Every context index selects a row of the embedding table ``C``. The
    ``window`` selected rows are flattened into one vector, projected by
    ``first``, passed through ``num_layers`` repetitions of
    Linear (no bias) -> BatchNorm1d -> Tanh, and finally mapped by ``final``
    to one logit per character.

    Args:
        num_char: Number of rows in the embedding table and number of output
            logits.
        hidden_nodes: Width of every hidden layer.
        embeddings: Length of each character's embedding vector.
        window: Number of context characters consumed per prediction.
        num_layers: Number of Linear/BatchNorm1d/Tanh groups in the hidden
            stack.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super().__init__()

        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Learnable embedding table, one row per character.
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1, requires_grad=True)

        # Projects the flattened window of embeddings to the hidden width.
        self.first = nn.Linear(embeddings * window, hidden_nodes)

        # Build the hidden stack as one flat Sequential. Modules are created
        # in Linear, BatchNorm1d, Tanh order per group, so sub-module indices
        # (and therefore state-dict keys such as ``layers.0.weight``) stay
        # 0, 1, 2, 3, ... exactly as when groups are appended one at a time.
        hidden_stack = []
        for _ in range(num_layers):
            hidden_stack += [
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ]
        self.layers = nn.Sequential(*hidden_stack)

        # Maps the last hidden activation to one logit per character.
        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        """Return next-character logits for one or more context windows.

        Args:
            x: Integer indices into ``C`` (anything tensor indexing accepts,
                e.g. a tensor or a list of ints). The indices are grouped
                into rows of ``window`` entries each.

        Returns:
            A tensor with one row per context window and ``num_char`` logits
            per row.
        """
        x = self.C[x]
        # Flatten each window of embeddings into a single feature vector.
        x = self.first(x.view(-1, self.window * self.embeddings))
        x = self.layers(x)
        x = self.final(x)
        return x

    def sample_char(self, x):
        """Sample the index of the next character for a single context window.

        Args:
            x: Indices for exactly one context window; ``.item()`` below
                requires the sampled tensor to hold a single element.

        Returns:
            The sampled character index as a Python ``int``.
        """
        # Sampling never needs gradients, so skip building the autograd graph.
        with torch.no_grad():
            logits = self(x)
            probs = F.softmax(logits, dim=1)
            return torch.multinomial(probs, num_samples=1).item()
# Rebuild the network with the hyperparameters stored in the downloaded config
# so its parameter shapes match the saved weights.
mlp = MLP(
    num_char=config['num_char'],
    hidden_nodes=config['hidden_nodes'],
    embeddings=config['embeddings'],
    window=config['window'],
    num_layers=config['num_layers'],
)

# map_location="cpu" lets the checkpoint load on hosts without a GPU even if
# the tensors were saved from a CUDA device; the model itself lives on the CPU.
# NOTE(review): torch.load unpickles the file, so it should only be used on
# trusted checkpoints. Consider passing weights_only=True if the installed
# torch version supports it.
mlp.load_state_dict(torch.load(weights_path, map_location="cpu"))

# Inference mode: BatchNorm1d layers use their stored running statistics, which
# is required because names are generated one context window at a time.
mlp.eval()
def generate_names(name_start, number_of_names):
    """Generate names that begin with a user-supplied prefix.

    Args:
        name_start: Text every generated name starts with; it is lower-cased
            before use.
        number_of_names: How many names to generate. Converted with ``int()``;
            zero or negative values yield an empty result.

    Returns:
        A string containing one generated name per line, each line terminated
        by a newline.

    Raises:
        gr.Error: If the prefix contains characters missing from the model's
            vocabulary (``stoi``).
    """
    count = int(number_of_names)
    if count <= 0:
        return ""

    prefix = name_start.lower()

    # Report unsupported characters with a readable message instead of letting
    # the vocabulary lookup fail with a bare KeyError.
    unsupported = sorted({ch for ch in prefix if ch not in stoi})
    if unsupported:
        raise gr.Error(
            "Unsupported character(s) in name start: "
            + ", ".join(repr(ch) for ch in unsupported)
        )

    # Encode the prefix once: start from a window of index 0 and slide the
    # prefix through it, keeping only the last `window` indices.
    prefix_ids = [stoi[ch] for ch in prefix]
    seed_context = ([0] * config['window'] + prefix_ids)[len(prefix_ids):]

    names = []
    for _ in range(count):
        # Rebinding (never mutating) `context` keeps seed_context reusable.
        context = seed_context
        pieces = [prefix]

        # Run inference to finish off the name; index 0 marks the end.
        # NOTE(review): the character for index 0 is appended before the loop
        # stops, so every name ends with itos[0] - confirm this is intended.
        while True:
            ix = mlp.sample_char(context)
            context = context[1:] + [ix]
            pieces.append(itos[ix])
            if ix == 0:
                break

        names.append("".join(pieces))

    return "".join(name + "\n" for name in names)
# Web UI: a textbox for the name prefix and a number field (default 1) for how
# many names to produce; the generated names are shown as plain text.
app = gr.Interface(
    generate_names,
    [gr.Textbox(placeholder="Start name with..."), gr.Number(value=1)],
    "text",
)

# Start serving the interface.
app.launch()