Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -166,12 +166,58 @@ def LogProbs(prompt):
|
|
166 |
print(df)
|
167 |
st.write(df)
|
168 |
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
with st.form(key='my_form'):
|
171 |
prompt = st.text_area(label='Enter sentence', value=g)
|
172 |
submit_button = st.form_submit_button(label='Submit')
|
173 |
submit_button2 = st.form_submit_button(label='Fast Forward')
|
174 |
submit_button3 = st.form_submit_button(label='Fast Forward 2.0')
|
|
|
175 |
|
176 |
if submit_button:
|
177 |
with torch.no_grad():
|
@@ -198,4 +244,6 @@ with st.form(key='my_form'):
|
|
198 |
if submit_button3:
|
199 |
print("----")
|
200 |
st.write("___")
|
201 |
-
st.write(BestProbs)
|
|
|
|
|
|
166 |
print(df)
|
167 |
st.write(df)
|
168 |
return df
|
169 |
+
|
170 |
+
def BestProbs5(prompt):
    """Show sampled continuations for the five most likely next tokens.

    The stripped prompt is run through ``model`` once, the five
    highest-scoring next tokens are decoded, and for each of them the
    prompt plus that token is written to the Streamlit page, followed by
    the list returned by ``run_generate`` for that extended prompt.

    Args:
        prompt: Text entered by the user; surrounding whitespace is ignored.

    Returns:
        None. All output goes to the Streamlit page (and stdout).
    """
    prompt = prompt.strip()
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # An empty id list would build an empty tensor and fail inside the
        # model; tell the user instead of crashing.
        st.warning("Enter a sentence first.")
        return

    # Inference only: skip autograd bookkeeping, as the form handler does.
    with torch.no_grad():
        logits, _ = model(
            torch.tensor([token_ids]),
            past_key_values=None,
            return_dict=False,
        )

    # Scores for the token that would follow the prompt. Ranking the raw
    # logits gives the same order as ranking their softmax, so the
    # probabilities (previously computed and discarded) are not needed.
    next_token_logits = logits[0, -1]
    _, best_indices = next_token_logits.topk(5)
    best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]

    for word in best_words:
        print("\n")
        candidate = prompt + word
        st.write(candidate)
        # "hey" is the bad-words string this caller has always passed.
        st.write(run_generate(candidate, "hey"))
|
187 |
+
|
188 |
+
def run_generate(text, bad_words):
    """Sample three short continuations of ``text``.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated.
            Each word is tokenized with a leading space and added to the
            ``bad_words_ids`` passed to ``model.generate``.

    Returns:
        A list of three strings: each sampled sequence decoded, with the
        echoed prompt removed from its start.
    """
    num_sequences = 3
    continuation_tokens = 5

    input_ids = tokenizer.encode(text, return_tensors='pt')
    # Prompt length in tokens, read from the tensor already built instead
    # of encoding the text a second time.
    prompt_len = input_ids.shape[-1]

    # Two token ids are always banned.
    # NOTE(review): what 7829 / 40940 decode to is not visible here - confirm.
    bad_word_ids = [[7829], [40940]]
    for bad_word in bad_words.split():
        bad_word_ids.append(tokenizer(" " + bad_word).input_ids)

    # min_length == max_length forces exactly `continuation_tokens` new tokens.
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=prompt_len + continuation_tokens,
        min_length=prompt_len + continuation_tokens,
        top_k=50,
        temperature=1.0,
        num_return_sequences=num_sequences,
        bad_words_ids=bad_word_ids,
    )

    continuations = []
    for sequence in sample_outputs:
        decoded = tokenizer.decode(sequence)
        # Drop only the leading echo of the prompt. A blanket
        # str.replace would also delete the prompt text if it happened to
        # reappear inside the generated continuation.
        if decoded.startswith(text):
            decoded = decoded[len(text):]
        continuations.append(decoded)

    print(continuations)
    return continuations
|
214 |
|
215 |
with st.form(key='my_form'):
|
216 |
prompt = st.text_area(label='Enter sentence', value=g)
|
217 |
submit_button = st.form_submit_button(label='Submit')
|
218 |
submit_button2 = st.form_submit_button(label='Fast Forward')
|
219 |
submit_button3 = st.form_submit_button(label='Fast Forward 2.0')
|
220 |
+
submit_button4 = st.form_submit_button(label='Get Top')
|
221 |
|
222 |
if submit_button:
|
223 |
with torch.no_grad():
|
|
|
244 |
if submit_button3:
|
245 |
print("----")
|
246 |
st.write("___")
|
247 |
+
st.write(BestProbs)
|
248 |
+
if submit_button4:
|
249 |
+
BestProbs5(prompt)
|