BigSalmon committed on
Commit
a8a914a
·
1 Parent(s): 39e9427

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -1
app.py CHANGED
@@ -166,12 +166,58 @@ def LogProbs(prompt):
166
  print(df)
167
  st.write(df)
168
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  with st.form(key='my_form'):
171
  prompt = st.text_area(label='Enter sentence', value=g)
172
  submit_button = st.form_submit_button(label='Submit')
173
  submit_button2 = st.form_submit_button(label='Fast Forward')
174
  submit_button3 = st.form_submit_button(label='Fast Forward 2.0')
 
175
 
176
  if submit_button:
177
  with torch.no_grad():
@@ -198,4 +244,6 @@ with st.form(key='my_form'):
198
  if submit_button3:
199
  print("----")
200
  st.write("___")
201
- st.write(BestProbs)
 
 
 
166
  print(df)
167
  st.write(df)
168
  return df
169
+
170
def BestProbs5(prompt):
    """Show sampled continuations for each of the five likeliest next tokens.

    The stripped prompt is run through the model once, the five highest-scoring
    next tokens are taken from the logits at the last position, and for each
    one the extended prompt and the result of ``run_generate`` on it are
    written to the Streamlit page.

    Args:
        prompt: Text entered by the user; surrounding whitespace is removed.

    Returns:
        None. Output goes to the Streamlit page (and a blank line to stdout
        per candidate).
    """
    prompt = prompt.strip()
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # An empty id list would become an empty float tensor, which the
        # model's embedding lookup rejects; bail out with a hint instead.
        st.warning("Enter a sentence first.")
        return

    input_ids = torch.tensor([token_ids])
    # Inference only: no need to track gradients for the forward pass.
    with torch.no_grad():
        logits, _ = model(input_ids, past_key_values=None, return_dict=False)

    # Scores for the token following the prompt (first batch row, last position).
    next_token_logits = logits[0, -1]
    _, best_indices = next_token_logits.topk(5)

    for idx in best_indices:
        word = tokenizer.decode([idx.item()])
        print("\n")
        candidate = prompt + word
        st.write(candidate)
        st.write(run_generate(candidate, "hey"))
187
+
188
def run_generate(text, bad_words, num_new_tokens=5, num_sequences=3):
    """Sample short continuations of ``text`` while banning given words.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated.
            Each is tokenized with a leading space (presumably so it matches
            the word-initial token form; confirm against the tokenizer).
        num_new_tokens: Exact number of tokens to generate after the prompt.
        num_sequences: Number of continuations to sample.

    Returns:
        A list of ``num_sequences`` strings, each holding only the newly
        generated text (the prompt is not included).
    """
    input_ids = tokenizer.encode(text, return_tensors='pt')
    prompt_len = input_ids.shape[-1]

    # Two token ids are always banned; their meaning is not visible here.
    banned_ids = [[7829], [40940]]
    banned_ids.extend(
        tokenizer(" " + word).input_ids for word in bad_words.split()
    )

    # min_length == max_length forces exactly num_new_tokens new tokens.
    target_len = prompt_len + num_new_tokens
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=target_len,
        min_length=target_len,
        top_k=50,
        temperature=1.0,
        num_return_sequences=num_sequences,
        bad_words_ids=banned_ids,
    )

    # Drop the prompt by token position rather than by string replacement:
    # replacing the prompt text would also delete any repeat of it inside the
    # continuation, and would fail if decoding did not reproduce it exactly.
    continuations = [
        tokenizer.decode(seq[prompt_len:]) for seq in sample_outputs
    ]
    print(continuations)
    return continuations
214
 
215
  with st.form(key='my_form'):
216
  prompt = st.text_area(label='Enter sentence', value=g)
217
  submit_button = st.form_submit_button(label='Submit')
218
  submit_button2 = st.form_submit_button(label='Fast Forward')
219
  submit_button3 = st.form_submit_button(label='Fast Forward 2.0')
220
+ submit_button4 = st.form_submit_button(label='Get Top')
221
 
222
  if submit_button:
223
  with torch.no_grad():
 
244
  if submit_button3:
245
  print("----")
246
  st.write("___")
247
+ st.write(BestProbs)
248
+ if submit_button4:
249
+ BestProbs5(prompt)