1inkusFace committed · verified
Commit 221091b · Parent: 6886b34

Update app.py

Files changed (1): app.py (+25, -11)
app.py CHANGED
@@ -260,6 +260,7 @@ def captioning(img):
         "Describe this image with a caption to be used for image generation."
     )
     inputsa = processor5(images=img, text=cap_prompt, return_tensors="pt").to('cuda')
+    '''
     generated_ids = model5.generate(
         **inputsa,
         do_sample=False,
@@ -271,6 +272,19 @@ def captioning(img):
         length_penalty=1.0,
         temperature=1,
     )
+    '''
+    generated_ids = model5.generate(
+        **inputsa,
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
+        max_length=128,
+        min_length=64,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        length_penalty=2.0,
+        temperature=0.5,
+    )
+

     generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     generated_text = generated_text.replace(cap_prompt, "").strip() #Or could try .split(prompt, 1)[-1].strip()
@@ -282,14 +296,14 @@ def captioning(img):
     #with torch.no_grad():
     generated_ids = model5.generate(
         **inputs,
-        do_sample=False,
-        num_beams=5,
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
         max_length=64,
         #min_length=16,
-        top_p=0.9,
-        repetition_penalty=1.5,
-        length_penalty=1.0,
-        temperature=1,
+        top_p=0.1,
+        repetition_penalty=1.0,
+        length_penalty=2.0,
+        temperature=0.5,
     )
     generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     response_text = generated_text.replace(prompt, "").strip() #Or could try .split(prompt, 1)[-1].strip()
@@ -302,14 +316,14 @@ def captioning(img):
     ).to('cuda')
     generated_ids = model5.generate(
         **inputf,
-        do_sample=False,
-        num_beams=5,
+        do_sample=text_decoding_method == "Nucleus sampling",
+        num_beams=1,
         max_length=96,
         min_length=64,
-        top_p=0.9,
-        repetition_penalty=1.5,
+        top_p=0.1,
+        repetition_penalty=1.0,
         length_penalty=1.0,
-        temperature=1,
+        temperature=0.5,
     )
     generated_texta = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     response_text = generated_texta.replace(generated_text, "").strip()
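
The change repeated across all three model5.generate() calls is the same idiom: do_sample=text_decoding_method == "Nucleus sampling" switches between deterministic decoding and nucleus sampling from a single UI-selected string, while num_beams drops from 5 to 1, disabling beam search in both modes. Below is a minimal sketch of that idiom against the standard transformers generate() API, assuming text_decoding_method is a string such as "Beam search" or "Nucleus sampling"; the decoding_kwargs helper is hypothetical and not part of this commit.

# Sketch of the decoding toggle used in this commit. Everything except the
# generate() keyword arguments is an illustrative assumption, not app.py code.

def decoding_kwargs(text_decoding_method: str, max_length: int = 128) -> dict:
    """Map a UI choice ("Beam search" / "Nucleus sampling") to generate() kwargs."""
    sampling = text_decoding_method == "Nucleus sampling"
    return {
        "do_sample": sampling,      # True -> nucleus sampling, False -> greedy
        "num_beams": 1,             # beam width 1: no beam search either way
        "max_length": max_length,
        "min_length": 64,
        "top_p": 0.9,               # consulted only when do_sample=True
        "repetition_penalty": 1.0,
        "length_penalty": 2.0,      # affects only beam-based decoding
        "temperature": 0.5,         # consulted only when do_sample=True
    }

# Usage, mirroring the first call in captioning():
#   generated_ids = model5.generate(**inputsa, **decoding_kwargs(text_decoding_method))

Note that with num_beams=1, do_sample=False falls back to plain greedy decoding, so the length_penalty values in this diff have no effect on that path, and recent transformers releases log a warning when sampling-only flags such as top_p or temperature are passed alongside do_sample=False.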