Update app.py

app.py CHANGED
@@ -260,6 +260,7 @@ def captioning(img):
         "Describe this image with a caption to be used for image generation."
     )
     inputsa = processor5(images=img, text=cap_prompt, return_tensors="pt").to('cuda')
+    '''
     generated_ids = model5.generate(
         **inputsa,
         do_sample=False,
@@ -271,6 +272,19 @@ def captioning(img):
         length_penalty=1.0,
         temperature=1,
     )
+    '''
+    generated_ids = model5.generate(
+        **inputsa,
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
+        max_length=128,
+        min_length=64,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        length_penalty=2.0,
+        temperature=0.5,
+    )
+
     generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     generated_text = generated_text.replace(cap_prompt, "").strip() #Or could try .split(prompt, 1)[-1].strip()
@@ -282,14 +296,14 @@ def captioning(img):
     #with torch.no_grad():
     generated_ids = model5.generate(
         **inputs,
-        do_sample=
-        num_beams=
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
         max_length=64,
         #min_length=16,
-        top_p=0.
-        repetition_penalty=1.
-        length_penalty=
-        temperature=
+        top_p=0.1,
+        repetition_penalty=1.0,
+        length_penalty=2.0,
+        temperature=0.5,
     )
     generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     response_text = generated_text.replace(prompt, "").strip() #Or could try .split(prompt, 1)[-1].strip()
@@ -302,14 +316,14 @@ def captioning(img):
     ).to('cuda')
     generated_ids = model5.generate(
         **inputf,
-        do_sample=
-        num_beams=
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
         max_length=96,
         min_length=64,
-        top_p=0.
-        repetition_penalty=1.
+        top_p=0.1,
+        repetition_penalty=1.0,
         length_penalty=1.0,
-        temperature=
+        temperature=0.5,
     )
     generated_texta = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     response_text = generated_texta.replace(generated_text, "").strip()
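The recurring change in all three hunks is the same: hardcoded decoding parameters are replaced by a toggle driven by the text_decoding_method string, so sampling is only enabled when the user selects "Nucleus sampling" (the first hunk also retires the old deterministic call by wrapping it in a ''' block rather than deleting it). A minimal sketch of the pattern, assuming model5 and processor5 are an already-loaded transformers vision-language model and processor as in app.py; the caption_image helper name is illustrative, not part of this diff:

# Sketch of the decoding-method toggle used throughout this diff.
# Assumes model5/processor5 are an already-loaded transformers
# model/processor pair on CUDA; names follow app.py.

def caption_image(img, text_decoding_method="Nucleus sampling"):
    # Hypothetical helper, not from app.py: shows how one string flag
    # switches between nucleus sampling and deterministic decoding.
    prompt = "Describe this image with a caption to be used for image generation."
    inputs = processor5(images=img, text=prompt, return_tensors="pt").to('cuda')
    generated_ids = model5.generate(
        **inputs,
        # True only when the UI option reads "Nucleus sampling";
        # otherwise generate() decodes greedily.
        do_sample=text_decoding_method == "Nucleus sampling",
        num_beams=1,            # no beam search; beams and sampling rarely mix
        max_length=128,
        min_length=64,
        top_p=0.9,              # nucleus cutoff, only consulted when do_sample=True
        repetition_penalty=1.0,
        length_penalty=2.0,     # favors longer captions
        temperature=0.5,        # likewise only consulted when sampling
    )
    text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    # The prompt is echoed back in the decoded text, so strip it off.
    return text.replace(prompt, "").strip()

Note the two presets in the diff: the captioning call samples broadly (top_p=0.9), while the two follow-up calls keep top_p=0.1, which makes sampling nearly deterministic even when the toggle is on.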