openfree committed on
Commit
30a0d3e
·
verified ·
1 Parent(s): 206e2f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -191
app.py CHANGED
@@ -266,143 +266,126 @@ examples = [
266
  ],
267
  ]
268
 
269
- def create_interface():
270
- with gr.Blocks(css=CUSTOM_CSS) as demo:
271
- gr.Markdown(DESCRIPTION)
272
-
273
- with gr.Group(elem_classes="container"):
274
- with gr.Row():
275
- with gr.Column(scale=1):
276
- image = gr.Image(
277
- type="pil",
278
- label="Upload Image",
279
- elem_classes="input-box"
280
- )
281
-
282
- with gr.Column(scale=2):
283
- with gr.Tabs(elem_classes="tab-nav"):
284
- with gr.Tab(label="✨ Image Captioning"):
285
- caption_button = gr.Button(
286
- "Generate Caption",
287
- elem_classes="button-primary"
288
- )
289
- caption_output = gr.Textbox(
290
- label="Generated Caption",
291
- elem_classes="output-box"
292
- )
293
-
294
- with gr.Tab(label="💭 Visual Q&A"):
295
- chatbot = gr.Chatbot(
296
- elem_classes="chatbot-message"
297
- )
298
- history_orig = gr.State(value=[])
299
- history_qa = gr.State(value=[])
300
- vqa_input = gr.Textbox(
301
- placeholder="Ask me anything about the image...",
302
- elem_classes="input-box"
303
- )
304
-
305
- with gr.Row():
306
- clear_button = gr.Button(
307
- "Clear Chat",
308
- elem_classes="button-secondary"
309
- )
310
- submit_button = gr.Button(
311
- "Send Message",
312
- elem_classes="button-primary"
313
- )
314
-
315
- with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
316
- with gr.Row():
317
- with gr.Column():
318
- text_decoding_method = gr.Radio(
319
- choices=["Beam search", "Nucleus sampling"],
320
- value="Nucleus sampling",
321
- label="Decoding Method"
322
- )
323
- temperature = gr.Slider(
324
- minimum=0.5,
325
- maximum=1.0,
326
- value=1.0,
327
- label="Temperature",
328
- info="Used with nucleus sampling",
329
- elem_classes="slider-container"
330
- )
331
- length_penalty = gr.Slider(
332
- minimum=-1.0,
333
- maximum=2.0,
334
- value=1.0,
335
- label="Length Penalty",
336
- info="Set to larger for longer sequence",
337
- elem_classes="slider-container"
338
- )
339
- with gr.Column():
340
- repetition_penalty = gr.Slider(
341
- minimum=1.0,
342
- maximum=5.0,
343
- value=1.5,
344
- label="Repetition Penalty",
345
- info="Larger value prevents repetition",
346
- elem_classes="slider-container"
347
- )
348
- max_length = gr.Slider(
349
- minimum=20,
350
- maximum=512,
351
- value=50,
352
- label="Max Length",
353
- elem_classes="slider-container"
354
  )
355
- min_length = gr.Slider(
356
- minimum=1,
357
- maximum=100,
358
- value=1,
359
- label="Min Length",
360
- elem_classes="slider-container"
361
  )
362
- num_beams = gr.Slider(
363
- minimum=1,
364
- maximum=10,
365
- value=5,
366
- label="Number of Beams",
367
- elem_classes="slider-container"
368
  )
369
- top_p = gr.Slider(
370
- minimum=0.5,
371
- maximum=1.0,
372
- value=0.9,
373
- label="Top P",
374
- info="Used with nucleus sampling",
375
- elem_classes="slider-container"
376
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
- with gr.Group(elem_classes="examples-container"):
379
- gr.Examples(
380
- examples=examples,
381
- inputs=[image, vqa_input],
382
- label="Try these examples"
383
- )
384
-
385
- # Event handlers
386
- caption_button.click(
387
- fn=generate_caption,
388
- inputs=[
389
- image,
390
- text_decoding_method,
391
- temperature,
392
- length_penalty,
393
- repetition_penalty,
394
- max_length,
395
- min_length,
396
- num_beams,
397
- top_p,
398
- ],
399
- outputs=caption_output,
400
- api_name="caption",
401
  )
402
 
403
- chat_inputs = [
 
 
 
404
  image,
405
- vqa_input,
406
  text_decoding_method,
407
  temperature,
408
  length_penalty,
@@ -411,72 +394,70 @@ def create_interface():
411
  min_length,
412
  num_beams,
413
  top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  history_orig,
415
  history_qa,
416
- ]
417
- chat_outputs = [
 
 
 
 
 
 
 
 
418
  chatbot,
419
  history_orig,
420
  history_qa,
421
- ]
422
-
423
- vqa_input.submit(
424
- fn=chat,
425
- inputs=chat_inputs,
426
- outputs=chat_outputs,
427
- api_name="chat",
428
- ).success(
429
- fn=lambda: "",
430
- outputs=vqa_input,
431
- queue=False,
432
- api_name=False,
433
- )
434
-
435
- clear_button.click(
436
- fn=lambda: ("", [], [], []),
437
- inputs=None,
438
- outputs=[
439
- vqa_input,
440
- chatbot,
441
- history_orig,
442
- history_qa,
443
- ],
444
- queue=False,
445
- api_name="clear",
446
- )
447
-
448
- image.change(
449
- fn=lambda: ("", [], [], []),
450
- inputs=None,
451
- outputs=[
452
- caption_output,
453
- chatbot,
454
- history_orig,
455
- history_qa,
456
- ],
457
- queue=False,
458
- )
459
-
460
- return demo
461
 
462
  if __name__ == "__main__":
463
- demo = create_interface()
464
- demo.queue(max_size=10).launch(),
465
- ).success(
466
- fn=lambda: "",
467
- outputs=vqa_input,
468
- queue=False,
469
- api_name=False,
470
- )
471
-
472
- submit_button.click(
473
- fn=chat,
474
- inputs=chat_inputs,
475
- outputs=chat_outputs,
476
- api_name="chat"
477
- ).success(
478
- fn=lambda: "",
479
- outputs=vqa_input,
480
- queue=False,
481
- api_name=False
482
- )
 
266
  ],
267
  ]
268
 
269
+ with gr.Blocks(css=CUSTOM_CSS) as demo:
270
+ gr.Markdown(DESCRIPTION)
271
+
272
+ with gr.Group(elem_classes="container"):
273
+ with gr.Row():
274
+ with gr.Column(scale=1):
275
+ image = gr.Image(
276
+ type="pil",
277
+ label="Upload Image",
278
+ elem_classes="input-box"
279
+ )
280
+
281
+ with gr.Column(scale=2):
282
+ with gr.Tabs(elem_classes="tab-nav"):
283
+ with gr.Tab(label="✨ Image Captioning"):
284
+ caption_button = gr.Button(
285
+ "Generate Caption",
286
+ elem_classes="button-primary"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  )
288
+ caption_output = gr.Textbox(
289
+ label="Generated Caption",
290
+ elem_classes="output-box"
 
 
 
291
  )
292
+
293
+ with gr.Tab(label="💭 Visual Q&A"):
294
+ chatbot = gr.Chatbot(
295
+ elem_classes="chatbot-message"
 
 
296
  )
297
+ history_orig = gr.State(value=[])
298
+ history_qa = gr.State(value=[])
299
+ vqa_input = gr.Textbox(
300
+ placeholder="Ask me anything about the image...",
301
+ elem_classes="input-box"
 
 
302
  )
303
+
304
+ with gr.Row():
305
+ clear_button = gr.Button(
306
+ "Clear Chat",
307
+ elem_classes="button-secondary"
308
+ )
309
+ submit_button = gr.Button(
310
+ "Send Message",
311
+ elem_classes="button-primary"
312
+ )
313
+
314
+ with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
315
+ with gr.Row():
316
+ with gr.Column():
317
+ text_decoding_method = gr.Radio(
318
+ choices=["Beam search", "Nucleus sampling"],
319
+ value="Nucleus sampling",
320
+ label="Decoding Method"
321
+ )
322
+ temperature = gr.Slider(
323
+ minimum=0.5,
324
+ maximum=1.0,
325
+ value=1.0,
326
+ label="Temperature",
327
+ info="Used with nucleus sampling",
328
+ elem_classes="slider-container"
329
+ )
330
+ length_penalty = gr.Slider(
331
+ minimum=-1.0,
332
+ maximum=2.0,
333
+ value=1.0,
334
+ label="Length Penalty",
335
+ info="Set to larger for longer sequence",
336
+ elem_classes="slider-container"
337
+ )
338
+ with gr.Column():
339
+ repetition_penalty = gr.Slider(
340
+ minimum=1.0,
341
+ maximum=5.0,
342
+ value=1.5,
343
+ label="Repetition Penalty",
344
+ info="Larger value prevents repetition",
345
+ elem_classes="slider-container"
346
+ )
347
+ max_length = gr.Slider(
348
+ minimum=20,
349
+ maximum=512,
350
+ value=50,
351
+ label="Max Length",
352
+ elem_classes="slider-container"
353
+ )
354
+ min_length = gr.Slider(
355
+ minimum=1,
356
+ maximum=100,
357
+ value=1,
358
+ label="Min Length",
359
+ elem_classes="slider-container"
360
+ )
361
+ num_beams = gr.Slider(
362
+ minimum=1,
363
+ maximum=10,
364
+ value=5,
365
+ label="Number of Beams",
366
+ elem_classes="slider-container"
367
+ )
368
+ top_p = gr.Slider(
369
+ minimum=0.5,
370
+ maximum=1.0,
371
+ value=0.9,
372
+ label="Top P",
373
+ info="Used with nucleus sampling",
374
+ elem_classes="slider-container"
375
+ )
376
 
377
+ with gr.Group(elem_classes="examples-container"):
378
+ gr.Examples(
379
+ examples=examples,
380
+ inputs=[image, vqa_input],
381
+ label="Try these examples"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  )
383
 
384
+ # Event handlers
385
+ caption_button.click(
386
+ fn=generate_caption,
387
+ inputs=[
388
  image,
 
389
  text_decoding_method,
390
  temperature,
391
  length_penalty,
 
394
  min_length,
395
  num_beams,
396
  top_p,
397
+ ],
398
+ outputs=caption_output,
399
+ api_name="caption",
400
+ )
401
+
402
+ chat_inputs = [
403
+ image,
404
+ vqa_input,
405
+ text_decoding_method,
406
+ temperature,
407
+ length_penalty,
408
+ repetition_penalty,
409
+ max_length,
410
+ min_length,
411
+ num_beams,
412
+ top_p,
413
+ history_orig,
414
+ history_qa,
415
+ ]
416
+ chat_outputs = [
417
+ chatbot,
418
+ history_orig,
419
+ history_qa,
420
+ ]
421
+
422
+ vqa_input.submit(
423
+ fn=chat,
424
+ inputs=chat_inputs,
425
+ outputs=chat_outputs
426
+ ).success(
427
+ fn=lambda: "",
428
+ outputs=vqa_input,
429
+ queue=False,
430
+ api_name=False
431
+ )
432
+
433
+ clear_button.click(
434
+ fn=lambda: ("", [], [], []),
435
+ inputs=None,
436
+ outputs=[
437
+ vqa_input,
438
+ chatbot,
439
  history_orig,
440
  history_qa,
441
+ ],
442
+ queue=False,
443
+ api_name="clear"
444
+ )
445
+
446
+ image.change(
447
+ fn=lambda: ("", [], [], []),
448
+ inputs=None,
449
+ outputs=[
450
+ caption_output,
451
  chatbot,
452
  history_orig,
453
  history_qa,
454
+ ],
455
+ queue=False
456
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
  if __name__ == "__main__":
459
+ demo.queue(max_size=10).launch()
460
+ outputs=vqa_input,
461
+ queue=False,
462
+ api_name=False
463
+ )