Simon Salmon commited on
Commit
e0e7abc
·
1 Parent(s): 210b713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -13,23 +13,26 @@ st.title('KoGPT2 Demo')
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2")
17
- model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
 
18
 
19
  with st.form(key='my_form'):
20
- text_input = st.text_input(label='Enter sentence')
21
  submit_button = st.form_submit_button(label='Submit')
22
 
23
  if submit_button:
24
  with torch.no_grad():
25
- inputs = tokenizer.encode(text_input)
26
- gen_ids = model.generate(torch.tensor([inputs]),
27
- max_length=128,
28
- repetition_penalty=2.0,
29
- pad_token_id=tokenizer.pad_token_id,
30
- eos_token_id=tokenizer.eos_token_id,
31
- bos_token_id=tokenizer.bos_token_id,
32
- use_cache=True)
33
- generated = tokenizer.decode(gen_ids[0,:].tolist())
34
-
35
- st.write(generated)
 
 
 
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ from transformers import AutoTokenizer, AutoModelWithLMHead
17
+ tokenizer = AutoTokenizer.from_pretrained("BigSalmon/SimplifyText")
18
+ model = AutoModelWithLMHead.from_pretrained("BigSalmon/SimplifyText")
19
 
20
  with st.form(key='my_form'):
21
+ prompt = st.text_input(label='Enter sentence')
22
  submit_button = st.form_submit_button(label='Submit')
23
 
24
  if submit_button:
25
  with torch.no_grad():
26
+ text = tokenizer.encode(prompt)
27
+ myinput, past_key_values = torch.tensor([text]), None
28
+ myinput = myinput
29
+ myinput= myinput.to(device)
30
+ logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
31
+ logits = logits[0,-1]
32
+ probabilities = torch.nn.functional.softmax(logits)
33
+ best_logits, best_indices = logits.topk(4)
34
+ best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
35
+ text.append(best_indices[0].item())
36
+ best_probabilities = probabilities[best_indices].tolist()
37
+ words = []
38
+ st.write(best_words)