jefsnacker committed
Commit 3f362c0
1 Parent(s): 3027a6c

add gpt nano model

Files changed (1)
app.py +207 -3
app.py CHANGED
@@ -1,3 +1,6 @@
+import math
+import yaml
+
 import gradio as gr
 
 import huggingface_hub
@@ -6,8 +9,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-import yaml
-
 
 mlp_config_path = huggingface_hub.hf_hub_download(
     "jefsnacker/surname_generator",
@@ -25,12 +26,27 @@ wavenet_weights_path = huggingface_hub.hf_hub_download(
     "jefsnacker/surname_generator",
     "wavenet_weights.pt")
 
+gpt_nano_config_path = huggingface_hub.hf_hub_download(
+    "jefsnacker/surname_generator",
+    "gpt_config.yaml")
+
+gpt_nano_weights_path = huggingface_hub.hf_hub_download(
+    "jefsnacker/surname_generator",
+    "gpt_weights.pt")
+
 with open(mlp_config_path, 'r') as file:
     mlp_config = yaml.safe_load(file)
 
 with open(wavenet_config_path, 'r') as file:
     wavenet_config = yaml.safe_load(file)
+
+with open(gpt_nano_config_path, 'r') as file:
+    gpt_nano_config = yaml.safe_load(file)
 
+##################################################################################
+## MLP
+##################################################################################
+
 class MLP(nn.Module):
     def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
         super(MLP, self).__init__()
@@ -75,6 +91,10 @@ mlp = MLP(mlp_config['num_char'],
 mlp.load_state_dict(torch.load(mlp_weights_path))
 mlp.eval()
 
+##################################################################################
+## WaveNet
+##################################################################################
+
 class WaveNet(nn.Module):
     def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
         super(WaveNet, self).__init__()
@@ -119,6 +139,185 @@ wavenet = WaveNet(wavenet_config['num_char'],
 wavenet.load_state_dict(torch.load(wavenet_weights_path))
 wavenet.eval()
 
+##################################################################################
+## Transformer
+##################################################################################
+
+class NewGELU(nn.Module):
+    """
+    Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+    """
+    def forward(self, x):
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+class GptAttention(nn.Module):
+    """
+    Self-attention: q, k, and v are all projections of the same input.
+    The causal mask below makes it suitable for decoder-only transformers.
+    """
+    def __init__(self, config):
+        super(GptAttention, self).__init__()
+        self.config = config
+
+        assert self.config["d_model"] % self.config["heads"] == 0
+        self.heads = self.config["heads"]
+
+        self.w_attn = nn.Linear(self.config["d_model"], 3*self.config["d_model"])
+        self.head = nn.Linear(self.config["d_model"], self.config["d_model"])
+
+        self.attn_dropout = nn.Dropout(config["attn_pdrop"])
+        self.resid_dropout = nn.Dropout(config["resid_pdrop"])
+
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer(
+            "bias",
+            torch.tril(
+                torch.ones(
+                    self.config["window"],
+                    self.config["window"])
+            ).view(1, 1, self.config["window"], self.config["window"])
+        )
+
+    def forward(self, x):
+        B, window, embs = x.shape
+
+        q, v, k = self.w_attn(x).split(self.config["d_model"], dim=2)  # one projection, split into q, v, k
+
+        # reshape each to (B, heads, window, embs // heads)
+        q = q.view(
+            B,
+            window,
+            self.config["heads"],
+            embs // self.config["heads"]
+        ).transpose(1, 2)
+        k = k.view(
+            B,
+            window,
+            self.config["heads"],
+            embs // self.config["heads"]
+        ).transpose(1, 2)
+        v = v.view(
+            B,
+            window,
+            self.config["heads"],
+            embs // self.config["heads"]
+        ).transpose(1, 2)
+
+        # Self-attend: (B, heads, window, embs) x (B, heads, embs, window) -> (B, heads, window, window)
+        scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
+        masked = scores.masked_fill(self.bias[:,:,:window,:window] == 0, float('-inf'))
+        probs = F.softmax(masked, dim=-1)
+        probs = self.attn_dropout(probs)  # dropout on the attention weights
+        attn = probs @ v
+        attn = attn.transpose(1, 2).contiguous().view(B, window, embs)
+
+        return self.resid_dropout(self.head(attn))
+
+class FeedForward(nn.Module):
+    def __init__(self, config):
+        super(FeedForward, self).__init__()
+        self.l1 = nn.Linear(config["d_model"], 4*config["d_model"])
+        self.l2 = nn.Linear(4*config["d_model"], config["d_model"])
+        self.dropout = nn.Dropout(config["resid_pdrop"])
+
+    def forward(self, x):
+        x = NewGELU()(self.l1(x))
+        return self.dropout(self.l2(x))
+
+class Block(nn.Module):
+    def __init__(self, config):
+        super(Block, self).__init__()
+        self.attn = GptAttention(config)
+        self.norm1 = nn.LayerNorm(config["d_model"])
+        self.ff = FeedForward(config)
+        self.norm2 = nn.LayerNorm(config["d_model"])
+
+    def forward(self, x):
+        x = self.norm1(x + self.attn(x))
+        x = self.norm2(x + self.ff(x))
+        return x
+
+class GPT(nn.Module):
+    def __init__(self, config):
+        super(GPT, self).__init__()
+        self.config = config
+
+        self.vocab_emb = nn.Embedding(self.config["vocab"], self.config["d_model"])
+        self.pos_emb = nn.Embedding(self.config["window"], self.config["d_model"])
+        self.emb_dropout = nn.Dropout(config["embd_pdrop"])
+
+        self.blocks = nn.ModuleList([Block(self.config) for _ in range(self.config["blocks"])])
+        self.head_layer_norm = nn.LayerNorm(config["d_model"])
+        self.head = nn.Linear(self.config["d_model"], self.config["vocab"])
+
+    def forward(self, x):
+        vocab_emb = self.vocab_emb(x)
+        pos_emb = self.pos_emb(torch.arange(0, x.shape[1], dtype=torch.long, device=x.device))
+
+        x = self.emb_dropout(vocab_emb + pos_emb)
+
+        for b in self.blocks:
+            x = b(x)
+
+        x = self.head_layer_norm(x)
+        x = self.head(x)
+
+        return x
+
+    def configure_opt(self):
+        p_decay = set()
+        p_no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                # because named_modules and named_parameters are recursive,
+                # the same tensor p is visited many times; iterating this way
+                # lets us record which parent module each tensor belongs to
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    p_no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    p_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    p_no_decay.add(fpn)
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = p_decay & p_no_decay
+        union_params = p_decay | p_no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+            % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(p_decay))], "weight_decay": self.config["weight_decay"]},
+            {"params": [param_dict[pn] for pn in sorted(list(p_no_decay))], "weight_decay": 0.0},
+        ]
+        optimizer = torch.optim.AdamW(
+            optim_groups,
+            lr=self.config["lr"],
+            betas=(self.config["b1"], self.config["b2"])
+        )
+        return optimizer
+
+    def sample_char(self, x):
+        logits = self(x)
+        probs = F.softmax(logits[:,-1,:], dim=1)
+        return torch.multinomial(probs, num_samples=1).item()
+
+gpt_nano = GPT(gpt_nano_config)
+gpt_nano.load_state_dict(torch.load(gpt_nano_weights_path))
+gpt_nano.eval()
+
+##################################################################################
+## Gradio App
+##################################################################################
+
 def generate_names(name_start, number_of_names, model):
     if model == "MLP":
         stoi = mlp_config['stoi']
@@ -126,6 +325,9 @@ def generate_names(name_start, number_of_names, model):
     elif model == "WaveNet":
         stoi = wavenet_config['stoi']
         window = wavenet_config['window']
+    elif model == "GPT Nano":
+        stoi = gpt_nano_config['stoi']
+        window = gpt_nano_config['window']
     else:
         raise Exception("Model not selected")
 
@@ -148,6 +350,8 @@ def generate_names(name_start, number_of_names, model):
             ix = mlp.sample_char(x)
         elif model == "WaveNet":
            ix = wavenet.sample_char(x)
+        elif model == "GPT Nano":
+            ix = gpt_nano.sample_char(x)
         else:
             raise Exception("Model not selected")
 
@@ -166,7 +370,7 @@ demo = gr.Interface(
     inputs=[
         gr.Textbox(placeholder="Start name with..."),
         gr.Number(value=5),
-        gr.Dropdown(["MLP", "WaveNet"], value="WaveNet"),
+        gr.Dropdown(["MLP", "WaveNet", "GPT Nano"], value="GPT Nano"),
     ],
     outputs="text",
 )
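
For reference, the new model can be exercised outside the Gradio UI along the following lines. This is a minimal sketch, not part of the commit: sample_gpt_name is a hypothetical helper, and it assumes that gpt_nano_config['stoi'] maps characters to integer token ids and that id 0 marks the end of a name, mirroring how generate_names drives sample_char for the other two models.

import torch

def sample_gpt_name(start="a", max_len=30):
    stoi = gpt_nano_config['stoi']
    itos = {i: c for c, i in stoi.items()}  # invert the character-to-id map
    window = gpt_nano_config['window']

    ids = [stoi[c] for c in start]
    name = start
    with torch.no_grad():
        for _ in range(max_len):
            # feed at most the last `window` token ids, shape (1, T)
            x = torch.tensor([ids[-window:]], dtype=torch.long)
            ix = gpt_nano.sample_char(x)
            if ix == 0:  # assumed end-of-name token id
                break
            ids.append(ix)
            name += itos[ix]
    return name

print(sample_gpt_name("jo"))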