tskolm commited on
Commit
febdf4e
·
1 Parent(s): 61e6d0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import streamlit as st
4
  import sys
5
  import urllib
6
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
7
  import json
8
  import torch
9
 
@@ -11,7 +10,7 @@ def generate(tokenizer, model, text, features):
11
  generated = tokenizer("<|startoftext|> <|titlestart|>{}<|titleend|>".format(text), return_tensors="pt").input_ids
12
  sample_outputs = model.generate(
13
  generated, do_sample=True, top_k=50,
14
- max_length=300, top_p=0.95, temperature=2.1, num_return_sequences=2
15
  )
16
  for i, sample_output in enumerate(sample_outputs):
17
  decoded = tokenizer.decode(sample_output, skip_special_tokens=True).split(text)[1]
@@ -19,8 +18,8 @@ def generate(tokenizer, model, text, features):
19
 
20
 
21
  def load_model():
22
- tokenizer = torch.load('./tokenizer.pt')
23
- model = torch.load('./model.pt', map_location=torch.device('cpu'))
24
  return tokenizer, model
25
 
26
 
@@ -71,4 +70,4 @@ def main():
71
  generate(tokenizer, model, data['title'], features)
72
 
73
  if __name__ == "__main__":
74
- main()
 
3
  import streamlit as st
4
  import sys
5
  import urllib
 
6
  import json
7
  import torch
8
 
 
10
  generated = tokenizer("<|startoftext|> <|titlestart|>{}<|titleend|>".format(text), return_tensors="pt").input_ids
11
  sample_outputs = model.generate(
12
  generated, do_sample=True, top_k=50,
13
+ max_length=300, top_p=features['top_p'], temperature=features['t'] / 100.0, num_return_sequences=features['num'],
14
  )
15
  for i, sample_output in enumerate(sample_outputs):
16
  decoded = tokenizer.decode(sample_output, skip_special_tokens=True).split(text)[1]
 
18
 
19
 
20
  def load_model():
21
+ tokenizer = torch.load('./final_model/tokenizer.pt')
22
+ model = torch.load('./final_model/model.pt', map_location=torch.device('cpu'))
23
  return tokenizer, model
24
 
25
 
 
70
  generate(tokenizer, model, data['title'], features)
71
 
72
  if __name__ == "__main__":
73
+ main()