# Spaces: Runtime error
import gradio as gr | |
import huggingface_hub | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import yaml | |
# Every artifact (configs and weights for both models) is fetched from the
# same Hugging Face Hub repository, so the repo id is named once.
REPO_ID = "jefsnacker/surname_generator"

# hf_hub_download returns the local filesystem path of the requested file.
mlp_config_path = huggingface_hub.hf_hub_download(
    repo_id=REPO_ID,
    filename="torch_mlp_config.yaml",
)
mlp_weights_path = huggingface_hub.hf_hub_download(
    repo_id=REPO_ID,
    filename="mlp_weights.pt",
)
wavenet_config_path = huggingface_hub.hf_hub_download(
    repo_id=REPO_ID,
    filename="wavenet_config.yaml",
)
wavenet_weights_path = huggingface_hub.hf_hub_download(
    repo_id=REPO_ID,
    filename="wavenet_weights.pt",
)

# Parse the YAML configs. The encoding is given explicitly so decoding does
# not depend on the host locale's default. safe_load is used, so the files
# cannot instantiate arbitrary Python objects.
# The resulting dicts are read later in this file for the model
# hyper-parameters ('num_char', 'hidden_nodes', 'embeddings', 'window',
# 'num_layers') and the character-to-index table ('stoi').
with open(mlp_config_path, 'r', encoding='utf-8') as file:
    mlp_config = yaml.safe_load(file)
with open(wavenet_config_path, 'r', encoding='utf-8') as file:
    wavenet_config = yaml.safe_load(file)
class MLP(nn.Module):
    """Character-level MLP that predicts the next character from a window.

    The network looks up an embedding for each of the ``window`` context
    indices, concatenates them, and passes the result through one input
    projection, ``num_layers`` blocks of Linear -> BatchNorm1d -> Tanh, and a
    final projection to ``num_char`` logits.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        """Build the layers.

        Args:
            num_char: Vocabulary size (rows of the embedding table and number
                of output logits).
            hidden_nodes: Width of every hidden layer.
            embeddings: Size of each character embedding vector.
            window: Number of context characters fed to the model.
            num_layers: Number of Linear/BatchNorm1d/Tanh hidden blocks.
        """
        super().__init__()

        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Embedding table, indexed directly with the integer context tensor.
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1, requires_grad=True)

        # Projects the concatenated window of embeddings to the hidden width.
        self.first = nn.Linear(embeddings * window, hidden_nodes)

        # Collect the hidden blocks in order, then wrap them once; the
        # sub-module indices (0, 1, 2, ...) match a Sequential grown block by
        # block, so state-dict keys are unchanged.
        hidden_stack = []
        for _ in range(num_layers):
            hidden_stack += [
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ]
        self.layers = nn.Sequential(*hidden_stack)

        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        """Return logits of shape (N, num_char) for integer context ``x``."""
        embedded = self.C[x]
        # Flatten each window of embeddings into a single feature row.
        flat = embedded.view(-1, self.window * self.embeddings)
        hidden = self.layers(self.first(flat))
        return self.final(hidden)

    def sample_char(self, x):
        """Sample one character index from the predicted distribution.

        ``.item()`` requires a single sampled element, so ``x`` must describe
        exactly one context window.
        """
        probs = F.softmax(self(x), dim=1)
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn.item()
# Instantiate the MLP from its own config. The hyper-parameters must come
# from `mlp_config` and the weights from `mlp_weights_path`; the bare names
# `config` / `weights_path` are never defined in this file and raise
# NameError at import time.
mlp = MLP(
    mlp_config['num_char'],
    mlp_config['hidden_nodes'],
    mlp_config['embeddings'],
    mlp_config['window'],
    mlp_config['num_layers'],
)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host.
mlp.load_state_dict(torch.load(mlp_weights_path, map_location="cpu"))
# Inference mode: BatchNorm1d uses its running statistics, which is required
# for the single-row batches used during sampling.
mlp.eval()
class WaveNet(nn.Module):
    """Convolutional next-character model built as one ``nn.Sequential``.

    The stack is: Embedding, then ``num_layers`` blocks of
    Conv1d(kernel_size=2) -> BatchNorm1d -> Tanh, then Flatten and a Linear
    layer producing ``num_char`` logits.

    The first convolution takes ``window`` input channels, so the context
    positions act as channels and the convolutions slide along the embedding
    dimension. Each kernel-size-2 convolution shortens that dimension by one,
    which is why the final Linear expects
    ``hidden_nodes * (embeddings - num_layers)`` features.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        """Build the layer stack.

        Args:
            num_char: Vocabulary size (embedding rows and output logits).
            hidden_nodes: Output channels of every convolution.
            embeddings: Size of each character embedding vector.
            window: Number of context characters; input channels of the
                first convolution.
            num_layers: Number of Conv1d/BatchNorm1d/Tanh blocks.
        """
        super().__init__()

        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Assemble the modules in order and wrap them once; indices inside the
        # Sequential (and thus state-dict keys) are the same as when the
        # container is grown block by block.
        stack = [nn.Embedding(num_char, embeddings)]

        in_channels = window  # only the first block sees `window` channels
        for _ in range(num_layers):
            stack += [
                nn.Conv1d(in_channels, hidden_nodes, kernel_size=2, stride=1, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ]
            in_channels = hidden_nodes

        stack += [
            nn.Flatten(),
            nn.Linear(hidden_nodes * (embeddings - num_layers), num_char),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return logits for the integer context tensor ``x``."""
        logits = self.layers(x)
        return logits

    def sample_char(self, x):
        """Sample one character index from the predicted distribution.

        ``.item()`` requires a single sampled element, so ``x`` must describe
        exactly one context window.
        """
        probs = F.softmax(self(x), dim=1)
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn.item()
# Instantiate the WaveNet model from its downloaded config.
wavenet = WaveNet(
    wavenet_config['num_char'],
    wavenet_config['hidden_nodes'],
    wavenet_config['embeddings'],
    wavenet_config['window'],
    wavenet_config['num_layers'],
)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host.
wavenet.load_state_dict(torch.load(wavenet_weights_path, map_location="cpu"))
# Inference mode: BatchNorm1d uses its running statistics, which is required
# for the single-row batches used during sampling.
wavenet.eval()
def generate_names(name_start, number_of_names, model):
    """Generate surnames that begin with a user-supplied prefix.

    Args:
        name_start: Prefix typed by the user. It is lower-cased before use;
            ``None`` is treated as an empty prefix.
        number_of_names: How many names to generate; converted with ``int()``.
        model: ``"MLP"`` or ``"WaveNet"``, selecting which loaded network and
            config to use.

    Returns:
        All generated names in one string, each followed by ``"\\n"``.
        An empty string when the requested count is not positive.

    Raises:
        Exception: If ``model`` is neither ``"MLP"`` nor ``"WaveNet"``.
        KeyError: If the prefix contains a character that is missing from the
            selected model's ``stoi`` table.
    """
    # Resolve the model once instead of re-dispatching on every sampled char.
    if model == "MLP":
        config, net = mlp_config, mlp
    elif model == "WaveNet":
        config, net = wavenet_config, wavenet
    else:
        raise Exception("Model not selected")

    stoi = config['stoi']
    window = config['window']
    # Inverse lookup: character index -> character.
    itos = {index: char for char, index in stoi.items()}

    count = int(number_of_names)
    if count <= 0:
        return ""

    # Encode the prefix once; it is identical for every generated name.
    # The context starts as `window` zeros and is shifted left by one for
    # each prefix character, so only the most recent characters are kept.
    prefix = (name_start or "").lower()
    start_context = [0] * window
    for c in prefix:
        if c not in stoi:
            raise KeyError(
                f"character {c!r} is not in the {model} vocabulary")
        start_context = start_context[1:] + [stoi[c]]

    names = []
    # Sampling never needs gradients, so skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(count):
            pieces = [prefix]
            context = list(start_context)

            # Run inference to finish off the name. Index 0 ends the name;
            # its character is appended before the loop stops.
            # NOTE(review): there is no length cap - the loop ends only when
            # the model samples index 0.
            while True:
                x = torch.tensor(context).view(1, -1)
                ix = net.sample_char(x)
                context = context[1:] + [ix]
                pieces.append(itos[ix])
                if ix == 0:
                    break

            names.append("".join(pieces))

    return "".join(name + "\n" for name in names)
# Gradio UI: the three input components map, in order, onto the
# (name_start, number_of_names, model) parameters of generate_names, and the
# returned string is shown as plain text.
demo = gr.Interface(
    fn=generate_names,
    inputs=[
        # Prefix the generated names should start with.
        gr.Textbox(placeholder="Start name with..."),
        # Number of names to generate (defaults to 5).
        gr.Number(value=5),
        # Which model to sample from (defaults to WaveNet).
        gr.Dropdown(choices=["MLP", "WaveNet"], value="WaveNet"),
    ],
    outputs="text",
)

# Start the web server for the interface.
demo.launch()