openfree committed on
Commit
b785f7c
·
verified ·
1 Parent(s): 5c85beb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +288 -6
app.py CHANGED
@@ -60,6 +60,20 @@ CUSTOM_CSS = """
60
  transform: translateY(-1px);
61
  }
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  .output-box {
64
  background: #ffffff;
65
  border-radius: 8px;
@@ -87,6 +101,14 @@ CUSTOM_CSS = """
87
  background: #f8f9fa;
88
  border-radius: 6px;
89
  }
 
 
 
 
 
 
 
 
90
  """
91
 
92
  DESCRIPTION = """
@@ -101,7 +123,148 @@ DESCRIPTION = """
101
  if not torch.cuda.is_available():
102
  DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires GPU to function properly.</p>"
103
 
104
- # ๋ชจ๋ธ ์„ค์ • ๋ถ€๋ถ„์€ ๋™์ผํ•˜๊ฒŒ ์œ ์ง€...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def create_interface():
107
  with gr.Blocks(css=CUSTOM_CSS) as demo:
@@ -132,6 +295,8 @@ def create_interface():
132
  chatbot = gr.Chatbot(
133
  elem_classes="chatbot-message"
134
  )
 
 
135
  vqa_input = gr.Textbox(
136
  placeholder="Ask me anything about the image...",
137
  elem_classes="input-box"
@@ -148,7 +313,6 @@ def create_interface():
148
  )
149
 
150
  with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
151
- # Advanced settings controls...
152
  with gr.Row():
153
  with gr.Column():
154
  text_decoding_method = gr.Radio(
@@ -161,28 +325,146 @@ def create_interface():
161
  maximum=1.0,
162
  value=1.0,
163
  label="Temperature",
 
164
  elem_classes="slider-container"
165
  )
166
- with gr.Column():
167
  length_penalty = gr.Slider(
168
  minimum=-1.0,
169
  maximum=2.0,
170
  value=1.0,
171
  label="Length Penalty",
 
172
  elem_classes="slider-container"
173
  )
 
174
  repetition_penalty = gr.Slider(
175
  minimum=1.0,
176
  maximum=5.0,
177
  value=1.5,
178
  label="Repetition Penalty",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  elem_classes="slider-container"
180
  )
181
 
182
- # Wire up event handlers...
183
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  return demo
185
 
186
  if __name__ == "__main__":
187
  demo = create_interface()
188
- demo.queue(max_size=10).launch()
 
 
 
 
 
 
 
60
  transform: translateY(-1px);
61
  }
62
 
63
+ .button-secondary {
64
+ background: #f8f9fa;
65
+ color: #1a73e8;
66
+ border: 1px solid #1a73e8;
67
+ padding: 0.75rem 1.5rem;
68
+ border-radius: 8px;
69
+ cursor: pointer;
70
+ transition: all 0.3s ease;
71
+ }
72
+
73
+ .button-secondary:hover {
74
+ background: #e8f0fe;
75
+ }
76
+
77
  .output-box {
78
  background: #ffffff;
79
  border-radius: 8px;
 
101
  background: #f8f9fa;
102
  border-radius: 6px;
103
  }
104
+
105
+ .examples-container {
106
+ margin-top: 2rem;
107
+ padding: 1rem;
108
+ background: #ffffff;
109
+ border-radius: 8px;
110
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
111
+ }
112
  """
113
 
114
  DESCRIPTION = """
 
123
  if not torch.cuda.is_available():
124
  DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires GPU to function properly.</p>"
125
 
126
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
127
+
128
+ MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b"
129
+ MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
130
+ MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl"
131
+ MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
132
+ MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
133
+
134
+ if MODEL_ID not in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]:
135
+ error_message = f"Invalid MODEL_ID: {MODEL_ID}"
136
+ raise ValueError(error_message)
137
+
138
+ if torch.cuda.is_available():
139
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
140
+ model = Blip2ForConditionalGeneration.from_pretrained(
141
+ MODEL_ID,
142
+ device_map="auto",
143
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True)
144
+ )
145
+
146
+ @spaces.GPU
147
+ def generate_caption(
148
+ image: PIL.Image.Image,
149
+ decoding_method: str = "Nucleus sampling",
150
+ temperature: float = 1.0,
151
+ length_penalty: float = 1.0,
152
+ repetition_penalty: float = 1.5,
153
+ max_length: int = 50,
154
+ min_length: int = 1,
155
+ num_beams: int = 5,
156
+ top_p: float = 0.9,
157
+ ) -> str:
158
+ inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
159
+ generated_ids = model.generate(
160
+ pixel_values=inputs.pixel_values,
161
+ do_sample=decoding_method == "Nucleus sampling",
162
+ temperature=temperature,
163
+ length_penalty=length_penalty,
164
+ repetition_penalty=repetition_penalty,
165
+ max_length=max_length,
166
+ min_length=min_length,
167
+ num_beams=num_beams,
168
+ top_p=top_p,
169
+ )
170
+ return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
171
+
172
+ @spaces.GPU
173
+ def answer_question(
174
+ image: PIL.Image.Image,
175
+ prompt: str,
176
+ decoding_method: str = "Nucleus sampling",
177
+ temperature: float = 1.0,
178
+ length_penalty: float = 1.0,
179
+ repetition_penalty: float = 1.5,
180
+ max_length: int = 50,
181
+ min_length: int = 1,
182
+ num_beams: int = 5,
183
+ top_p: float = 0.9,
184
+ ) -> str:
185
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
186
+ generated_ids = model.generate(
187
+ **inputs,
188
+ do_sample=decoding_method == "Nucleus sampling",
189
+ temperature=temperature,
190
+ length_penalty=length_penalty,
191
+ repetition_penalty=repetition_penalty,
192
+ max_length=max_length,
193
+ min_length=min_length,
194
+ num_beams=num_beams,
195
+ top_p=top_p,
196
+ )
197
+ return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
198
+
199
+ def postprocess_output(output: str) -> str:
200
+ if output and output[-1] not in string.punctuation:
201
+ output += "."
202
+ return output
203
+
204
+ def chat(
205
+ image: PIL.Image.Image,
206
+ text: str,
207
+ decoding_method: str = "Nucleus sampling",
208
+ temperature: float = 1.0,
209
+ length_penalty: float = 1.0,
210
+ repetition_penalty: float = 1.5,
211
+ max_length: int = 50,
212
+ min_length: int = 1,
213
+ num_beams: int = 5,
214
+ top_p: float = 0.9,
215
+ history_orig: list[str] | None = None,
216
+ history_qa: list[str] | None = None,
217
+ ) -> tuple[list[tuple[str, str]], list[str], list[str]]:
218
+ history_orig = history_orig or []
219
+ history_qa = history_qa or []
220
+ history_orig.append(text)
221
+ text_qa = f"Question: {text} Answer:"
222
+ history_qa.append(text_qa)
223
+ prompt = " ".join(history_qa)
224
+
225
+ output = answer_question(
226
+ image=image,
227
+ prompt=prompt,
228
+ decoding_method=decoding_method,
229
+ temperature=temperature,
230
+ length_penalty=length_penalty,
231
+ repetition_penalty=repetition_penalty,
232
+ max_length=max_length,
233
+ min_length=min_length,
234
+ num_beams=num_beams,
235
+ top_p=top_p,
236
+ )
237
+ output = postprocess_output(output)
238
+ history_orig.append(output)
239
+ history_qa.append(output)
240
+
241
+ chat_val = list(zip(history_orig[0::2], history_orig[1::2], strict=False))
242
+ return chat_val, history_orig, history_qa
243
+
244
+ chat.zerogpu = True # type: ignore
245
+
246
+ examples = [
247
+ [
248
+ "images/house.png",
249
+ "How could someone get out of the house?",
250
+ ],
251
+ [
252
+ "images/flower.jpg",
253
+ "What is this flower and where is it's origin?",
254
+ ],
255
+ [
256
+ "images/pizza.jpg",
257
+ "What are steps to cook it?",
258
+ ],
259
+ [
260
+ "images/sunset.jpg",
261
+ "Here is a romantic message going along the photo:",
262
+ ],
263
+ [
264
+ "images/forbidden_city.webp",
265
+ "In what dynasties was this place built?",
266
+ ],
267
+ ]
268
 
269
  def create_interface():
270
  with gr.Blocks(css=CUSTOM_CSS) as demo:
 
295
  chatbot = gr.Chatbot(
296
  elem_classes="chatbot-message"
297
  )
298
+ history_orig = gr.State(value=[])
299
+ history_qa = gr.State(value=[])
300
  vqa_input = gr.Textbox(
301
  placeholder="Ask me anything about the image...",
302
  elem_classes="input-box"
 
313
  )
314
 
315
  with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
 
316
  with gr.Row():
317
  with gr.Column():
318
  text_decoding_method = gr.Radio(
 
325
  maximum=1.0,
326
  value=1.0,
327
  label="Temperature",
328
+ info="Used with nucleus sampling",
329
  elem_classes="slider-container"
330
  )
 
331
  length_penalty = gr.Slider(
332
  minimum=-1.0,
333
  maximum=2.0,
334
  value=1.0,
335
  label="Length Penalty",
336
+ info="Set to larger for longer sequence",
337
  elem_classes="slider-container"
338
  )
339
+ with gr.Column():
340
  repetition_penalty = gr.Slider(
341
  minimum=1.0,
342
  maximum=5.0,
343
  value=1.5,
344
  label="Repetition Penalty",
345
+ info="Larger value prevents repetition",
346
+ elem_classes="slider-container"
347
+ )
348
+ max_length = gr.Slider(
349
+ minimum=20,
350
+ maximum=512,
351
+ value=50,
352
+ label="Max Length",
353
+ elem_classes="slider-container"
354
+ )
355
+ min_length = gr.Slider(
356
+ minimum=1,
357
+ maximum=100,
358
+ value=1,
359
+ label="Min Length",
360
+ elem_classes="slider-container"
361
+ )
362
+ num_beams = gr.Slider(
363
+ minimum=1,
364
+ maximum=10,
365
+ value=5,
366
+ label="Number of Beams",
367
+ elem_classes="slider-container"
368
+ )
369
+ top_p = gr.Slider(
370
+ minimum=0.5,
371
+ maximum=1.0,
372
+ value=0.9,
373
+ label="Top P",
374
+ info="Used with nucleus sampling",
375
  elem_classes="slider-container"
376
  )
377
 
378
+ with gr.Group(elem_classes="examples-container"):
379
+ gr.Examples(
380
+ examples=examples,
381
+ inputs=[image, vqa_input],
382
+ label="Try these examples"
383
+ )
384
+
385
+ # Event handlers
386
+ caption_button.click(
387
+ fn=generate_caption,
388
+ inputs=[
389
+ image,
390
+ text_decoding_method,
391
+ temperature,
392
+ length_penalty,
393
+ repetition_penalty,
394
+ max_length,
395
+ min_length,
396
+ num_beams,
397
+ top_p,
398
+ ],
399
+ outputs=caption_output,
400
+ api_name="caption",
401
+ )
402
+
403
+ chat_inputs = [
404
+ image,
405
+ vqa_input,
406
+ text_decoding_method,
407
+ temperature,
408
+ length_penalty,
409
+ repetition_penalty,
410
+ max_length,
411
+ min_length,
412
+ num_beams,
413
+ top_p,
414
+ history_orig,
415
+ history_qa,
416
+ ]
417
+ chat_outputs = [
418
+ chatbot,
419
+ history_orig,
420
+ history_qa,
421
+ ]
422
+
423
+ vqa_input.submit(
424
+ fn=chat,
425
+ inputs=chat_inputs,
426
+ outputs=chat_outputs,
427
+ api_name="chat",
428
+ ).success(
429
+ fn=lambda: "",
430
+ outputs=vqa_input,
431
+ queue=False,
432
+ api_name=False,
433
+ )
434
+
435
+ clear_button.click(
436
+ fn=lambda: ("", [], [], []),
437
+ inputs=None,
438
+ outputs=[
439
+ vqa_input,
440
+ chatbot,
441
+ history_orig,
442
+ history_qa,
443
+ ],
444
+ queue=False,
445
+ api_name="clear",
446
+ )
447
+
448
+ image.change(
449
+ fn=lambda: ("", [], [], []),
450
+ inputs=None,
451
+ outputs=[
452
+ caption_output,
453
+ chatbot,
454
+ history_orig,
455
+ history_qa,
456
+ ],
457
+ queue=False,
458
+ )
459
+
460
  return demo
461
 
462
  if __name__ == "__main__":
463
  demo = create_interface()
464
+ demo.queue(max_size=10).launch(),
465
+ ).success(
466
+ fn=lambda: "",
467
+ outputs=vqa_input,
468
+ queue=False,
469
+ api_name=False,
470
+ )