# Hugging Face Space: character-level MLP surname generator (Gradio app).
import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
# Fetch the model configuration and trained weights from the Hugging Face Hub.
config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_mlp",
    "torch_mlp_config.yaml")
weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_mlp",
    "mlp_weights.pt")

with open(config_path, 'r') as fh:
    config = yaml.safe_load(fh)

# stoi maps character -> integer index; itos is its inverse (index -> character).
stoi = config['stoi']
itos = {index: char for char, index in stoi.items()}
class MLP(nn.Module):
    """Character-level MLP language model.

    Embeds a fixed-size window of character indices, flattens the
    embeddings, and passes them through ``num_layers`` hidden blocks
    (Linear -> BatchNorm1d -> Tanh) before projecting to per-character
    logits.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        """
        Args:
            num_char: vocabulary size (number of distinct characters).
            hidden_nodes: width of each hidden layer.
            embeddings: embedding dimension per character.
            window: number of context characters fed to the model.
            num_layers: number of hidden Linear/BatchNorm/Tanh blocks.
        """
        super().__init__()
        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings
        # Character embedding table; the 0.1 scale keeps initial logits small.
        # (nn.Parameter already sets requires_grad=True by default.)
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1)
        self.first = nn.Linear(embeddings * window, hidden_nodes)
        # Build the hidden stack as a plain list first: this avoids relying on
        # nn.Sequential.extend() (only available in newer torch releases, and
        # the original code depended on it returning self), while producing
        # identical submodule names (layers.0, layers.1, ...) so existing
        # checkpoints still load via load_state_dict.
        blocks = []
        for _ in range(num_layers):
            blocks.append(nn.Linear(hidden_nodes, hidden_nodes, bias=False))
            blocks.append(nn.BatchNorm1d(hidden_nodes))
            blocks.append(nn.Tanh())
        self.layers = nn.Sequential(*blocks)
        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        """Return logits of shape (batch, num_char).

        ``x`` may be a tensor or a plain Python list of character indices;
        ``self.C[x]`` embeds it, then each example's window of embeddings is
        flattened into a single vector.
        """
        x = self.C[x]
        x = self.first(x.view(-1, self.window * self.embeddings))
        x = self.layers(x)
        return self.final(x)

    def sample_char(self, x):
        """Sample one character index from the softmax distribution over logits."""
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()
# Instantiate the network with the architecture recorded in the config,
# restore the pretrained weights, and switch to eval mode so BatchNorm
# uses its running statistics during inference.
mlp = MLP(config['num_char'],
          config['hidden_nodes'],
          config['embeddings'],
          config['window'],
          config['num_layers'])
mlp.load_state_dict(torch.load(weights_path))
mlp.eval()
def generate_names(name_start, number_of_names):
    """Generate surnames with the MLP, optionally seeded with a user prefix.

    Args:
        name_start: prefix typed by the user. Characters must exist in the
            model vocabulary (``stoi``); an unknown character raises KeyError.
            NOTE(review): consider validating/ignoring unknown characters —
            confirm desired UX before changing behavior.
        number_of_names: how many names to generate (coerced with int()).

    Returns:
        All generated names concatenated, one per line; each name ends with
        the terminator character itos[0].
    """
    names = []
    for _ in range(int(number_of_names)):
        # Start from an all-terminator context (index 0) and feed in the
        # user-supplied prefix one character at a time.
        name = ""
        context = [0] * config['window']
        for c in name_start.lower():
            name += c
            context = context[1:] + [stoi[c]]
        # Sample characters until the terminator (index 0) is produced.
        while True:
            ix = mlp.sample_char(context)
            context = context[1:] + [ix]
            name += itos[ix]
            if ix == 0:
                break
        names.append(name)
    # Join once instead of repeated string concatenation (avoids O(n^2) growth).
    return "".join(name + "\n" for name in names)
# Wire the generator into a simple Gradio UI: a text box for the name prefix,
# a number box for how many names to produce, and a plain-text output.
# (Removed a stray trailing "|" page artifact that made the launch line
# a syntax error.)
app = gr.Interface(
    fn=generate_names,
    inputs=[
        gr.Textbox(placeholder="Start name with..."),
        gr.Number(value=1),
    ],
    outputs="text",
)
app.launch()