ZebangCheng committed
Commit ea07ffb · 1 Parent(s): b46ef65
Files changed (1)
  1. app.py +693 -4
app.py CHANGED
@@ -1,7 +1,696 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "! I'm gradio, nice to meet you ! Have a good day !"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
 
+ import argparse
+ import os
+ import random
+ from collections import defaultdict
+
+ import cv2
+ import re
+
+ import numpy as np
+ from PIL import Image
+ import torch
+ import html
  import gradio as gr

+ import torchvision.transforms as T
+ import torch.backends.cudnn as cudnn
+
+ from minigpt4.common.config import Config
+
+ from minigpt4.common.registry import registry
+ from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
+
+ # imports modules for registration
+ from minigpt4.datasets.builders import *
+ from minigpt4.models import *
+ from minigpt4.processors import *
+ from minigpt4.runners import *
+ from minigpt4.tasks import *
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--cfg-path", default='eval_configs/demo.yaml',
+                         help="path to configuration file.")
+     parser.add_argument(
+         "--options",
+         nargs="+",
+         help="override some settings in the used config, the key-value pairs "
+              "in xxx=yyy format will be merged into the config file (deprecated), "
+              "change to --cfg-options instead.",
+     )
+     args = parser.parse_args()
+     return args
+
+
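+ # Fix the random seeds and build the model, visual processor, and conversation template once at import time so the demo runs deterministically.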
+ random.seed(42)
+ np.random.seed(42)
+ torch.manual_seed(42)
+
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+ print('Initializing Chat')
+ args = parse_args()
+ cfg = Config(args)
+
+ device = 'cuda'
+
+ model_config = cfg.model_cfg
+
+ print("model_config:", model_config)
+ model_cls = registry.get_model_class(model_config.arch)
+ model = model_cls.from_config(model_config).to(device)
+ bounding_box_size = 100
+
+ vis_processor_cfg = cfg.datasets_cfg.feature_face_caption.vis_processor.train
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+ model = model.eval()
+
+ CONV_VISION = Conversation(
+     system="",
+     roles=(r"<s>[INST] ", r" [/INST]"),
+     messages=[],
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="",
+ )
+
+
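+ # The model marks grounded phrases as '<p>phrase</p>{<x1><y1><x2><y2>}' with '<delim>' between multiple boxes; the helpers below extract and compare those boxes.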
+ def extract_substrings(string):
+     # drop anything after the last closed '}' so an unfinished bracket at the end is ignored
+     index = string.rfind('}')
+     if index != -1:
+         string = string[:index + 1]
+
+     pattern = r'<p>(.*?)\}(?!<)'
+     matches = re.findall(pattern, string)
+     substrings = [match for match in matches]
+
+     return substrings
+
+
+ def is_overlapping(rect1, rect2):
+     x1, y1, x2, y2 = rect1
+     x3, y3, x4, y4 = rect2
+     return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+ def computeIoU(bbox1, bbox2):
+     x1, y1, x2, y2 = bbox1
+     x3, y3, x4, y4 = bbox2
+     intersection_x1 = max(x1, x3)
+     intersection_y1 = max(y1, y3)
+     intersection_x2 = min(x2, x4)
+     intersection_y2 = min(y2, y4)
+     intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+     bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+     bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+     union_area = bbox1_area + bbox2_area - intersection_area
+     iou = intersection_area / union_area
+     return iou
+
+
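+ # save_tmp_img writes an annotated frame under /tmp so the chatbot can display it; mask2bbox converts a sketched mask into the '{<x1><y1><x2><y2>}' box string on a 100x100 grid.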
+ def save_tmp_img(visual_img):
+     file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+     file_path = "/tmp/gradio" + file_name
+     visual_img.save(file_path)
+     return file_path
+
+
+ def mask2bbox(mask):
+     if mask is None:
+         return ''
+     mask = mask.resize([100, 100], resample=Image.NEAREST)
+     mask = np.array(mask)[:, :, 0]
+
+     rows = np.any(mask, axis=1)
+     cols = np.any(mask, axis=0)
+
+     if rows.sum():
+         # Get the top, bottom, left, and right boundaries
+         rmin, rmax = np.where(rows)[0][[0, -1]]
+         cmin, cmax = np.where(cols)[0][[0, -1]]
+         bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+     else:
+         bbox = ''
+
+     return bbox
+
+
+ def escape_markdown(text):
+     # List of Markdown special characters that need to be escaped
+     md_chars = ['<', '>']
+
+     # Escape each special character
+     for char in md_chars:
+         text = text.replace(char, '\\' + char)
+
+     return text
+
+
+ def reverse_escape(text):
+     md_chars = ['\\<', '\\>']
+
+     for char in md_chars:
+         text = text.replace(char, char[1:])
+
+     return text
+
+
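+ # Fixed color palette used when drawing boxes; color_map exposes the same colors as hex strings keyed by index.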
+ colors = [
+     (255, 0, 0),
+     (0, 255, 0),
+     (0, 0, 255),
+     (210, 210, 0),
+     (255, 0, 255),
+     (0, 255, 255),
+     (114, 128, 250),
+     (0, 165, 255),
+     (0, 128, 0),
+     (144, 238, 144),
+     (238, 238, 175),
+     (255, 191, 0),
+     (0, 128, 0),
+     (226, 43, 138),
+     (255, 0, 255),
+     (0, 215, 255),
+ ]
+
+ color_map = {
+     f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+     color_id, color in enumerate(colors)
+ }
+
+ used_colors = colors
+
+ def get_first_frame(video_path):
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Cannot open video.")
+         return None
+
+     ret, frame = cap.read()
+     cap.release()
+
+     if ret:
+         return frame
+     else:
+         print("Error: Cannot read frame from video.")
+         return None
+
+
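+ # Draw every box mentioned in `generation` on the uploaded frame (for a video path, its first frame) and return the annotated image plus an HTML-colored copy of the text.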
+ def visualize_all_bbox_together(image, generation):
+     if image is None:
+         return None, ''
+
+     if isinstance(image, str):  # a video file path; use its first frame
+         raw_image = get_first_frame(image)
+         frame_rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
+         image = Image.fromarray(frame_rgb)
+
+     generation = html.unescape(generation)
+
+     image_width, image_height = image.size
+     image = image.resize([500, int(500 / image_width * image_height)])
+     image_width, image_height = image.size
+
+     string_list = extract_substrings(generation)
+     if string_list:  # it is grounding or detection
+         mode = 'all'
+         entities = defaultdict(list)
+         i = 0
+         j = 0
+         for string in string_list:
+             try:
+                 obj, string = string.split('</p>')
+             except ValueError:
+                 print('wrong string: ', string)
+                 continue
+             bbox_list = string.split('<delim>')
+             flag = False
+             for bbox_string in bbox_list:
+                 integers = re.findall(r'-?\d+', bbox_string)
+                 if len(integers) == 4:
+                     x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+                     left = x0 / bounding_box_size * image_width
+                     bottom = y0 / bounding_box_size * image_height
+                     right = x1 / bounding_box_size * image_width
+                     top = y1 / bounding_box_size * image_height
+
+                     entities[obj].append([left, bottom, right, top])
+
+                     j += 1
+                     flag = True
+             if flag:
+                 i += 1
+     else:
+         integers = re.findall(r'-?\d+', generation)
+
+         if len(integers) == 4:  # it is refer
+             mode = 'single'
+
+             entities = list()
+             x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+             left = x0 / bounding_box_size * image_width
+             bottom = y0 / bounding_box_size * image_height
+             right = x1 / bounding_box_size * image_width
+             top = y1 / bounding_box_size * image_height
+             entities.append([left, bottom, right, top])
+         else:
+             # no valid bbox detected, nothing to visualize
+             return None, ''
+
+     if len(entities) == 0:
+         return None, ''
+
+     if isinstance(image, Image.Image):
+         image_h = image.height
+         image_w = image.width
+         image = np.array(image)
+
+     elif isinstance(image, str):
+         if os.path.exists(image):
+             pil_img = Image.open(image).convert("RGB")
+             image = np.array(pil_img)[:, :, [2, 1, 0]]
+             image_h = pil_img.height
+             image_w = pil_img.width
+         else:
+             raise ValueError(f"invalid image path, {image}")
+     elif isinstance(image, torch.Tensor):
+
+         image_tensor = image.cpu()
+         reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+         reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+         image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+         pil_img = T.ToPILImage()(image_tensor)
+         image_h = pil_img.height
+         image_w = pil_img.width
+         image = np.array(pil_img)[:, :, [2, 1, 0]]
+     else:
+         raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+     indices = list(range(len(entities)))
+
+     new_image = image.copy()
+
+     previous_bboxes = []
+     # size of text
+     text_size = 0.5
+     # thickness of text
+     text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
+     box_line = 2
+     (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+     base_height = int(text_height * 0.675)
+     text_offset_original = text_height - base_height
+     text_spaces = 2
+
+     # num_bboxes = sum(len(x[-1]) for x in entities)
+     used_colors = colors  # random.sample(colors, k=num_bboxes)
+
+     color_id = -1
+     for entity_idx, entity_name in enumerate(entities):
+         if mode == 'single' or mode == 'identify':
+             bboxes = entity_name
+             bboxes = [bboxes]
+         else:
+             bboxes = entities[entity_name]
+         color_id += 1
+         for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
+             skip_flag = False
+             orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
+
+             color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
+             new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+             if mode == 'all':
+                 l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+                 x1 = orig_x1 - l_o
+                 y1 = orig_y1 - l_o
+
+                 if y1 < text_height + text_offset_original + 2 * text_spaces:
+                     y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+                     x1 = orig_x1 + r_o
+
+                 # add text background
+                 (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+                                                                text_line)
+                 text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+                         text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+                 for prev_bbox in previous_bboxes:
+                     if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+                             prev_bbox['phrase'] == entity_name:
+                         skip_flag = True
+                         break
+                     while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+                         text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+                         text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+                         y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+                         if text_bg_y2 >= image_h:
+                             text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+                             text_bg_y2 = image_h
+                             y1 = image_h
+                             break
+                 if not skip_flag:
+                     alpha = 0.5
+                     for i in range(text_bg_y1, text_bg_y2):
+                         for j in range(text_bg_x1, text_bg_x2):
+                             if i < image_h and j < image_w:
+                                 if j < text_bg_x1 + 1.35 * c_width:
+                                     # original color
+                                     bg_color = color
+                                 else:
+                                     # white
+                                     bg_color = [255, 255, 255]
+                                 new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+                                     np.uint8)
+
+                     cv2.putText(
+                         new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+                         cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+                     )
+
+                     previous_bboxes.append(
+                         {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+     if mode == 'all':
+         def color_iterator(colors):
+             while True:
+                 for color in colors:
+                     yield color
+
+         color_gen = color_iterator(colors)
+
+         # Add colors to phrases and remove <p></p>
+         def colored_phrases(match):
+             phrase = match.group(1)
+             color = next(color_gen)
+             return f'<span style="color:rgb{color}">{phrase}</span>'
+
+         generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+         generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
+     else:
+         generation_colored = ''
+
+     pil_image = Image.fromarray(new_image)
+     return pil_image, generation_colored
+
+
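+ # Gradio callbacks: conversation reset, upload bookkeeping, asking, answering (streamed and non-streamed), and box visualization.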
+ def gradio_reset(chat_state, img_list):
+     if chat_state is not None:
+         chat_state.messages = []
+     if img_list is not None:
+         img_list = []
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your video and chat',
+                                                                     interactive=True), chat_state, img_list
+
+
+ def image_upload_trigger(upload_flag, replace_flag, img_list):
+     # set the upload flag to true when receiving a new video.
+     # if there is an old video (and an old conversation), set the replace flag to true so the conversation is reset later.
+     upload_flag = 1
+     if img_list:
+         replace_flag = 1
+     return upload_flag, replace_flag
+
+
+ def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+     # set the upload flag to true when receiving a new video.
+     # if there is an old video (and an old conversation), set the replace flag to true so the conversation is reset later.
+     upload_flag = 1
+     if img_list or replace_flag == 1:
+         replace_flag = 1
+
+     return upload_flag, replace_flag
+
+
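+ # Validate the prompt, re-initialize the conversation when a new video was uploaded, and append the user turn to the chatbot history.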
+ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+     print("+++gradio_ask+++")
+
+     if len(user_message) == 0:
+         text_box_show = 'Input should not be empty!'
+     else:
+         text_box_show = ''
+
+     print('user_message:', user_message)
+     print('chatbot:', chatbot)
+     print('chat_state:', chat_state)
+
+
+     if isinstance(gr_img, dict):
+         gr_img, mask = gr_img['image'], gr_img['mask']
+     else:
+         mask = None
+
+     if '[identify]' in user_message:
+         # check whether the user provided a bbox in the text input
+         integers = re.findall(r'-?\d+', user_message)
+         if len(integers) != 4:  # no bbox in the text
+             bbox = mask2bbox(mask)
+             user_message = user_message + bbox
+
+     if chat_state is None:
+         chat_state = CONV_VISION.copy()
+
+     if upload_flag:
+         if replace_flag:
+             chat_state = CONV_VISION.copy()  # new video, reset everything
+             replace_flag = 0
+             chatbot = []
+             img_list = []
+         llm_message = chat.upload_img(gr_img, chat_state, img_list)
+         upload_flag = 0
+
+     chat.ask(user_message, chat_state)
+     print('user_message: ', user_message)
+     print('chat_state: ', chat_state)
+
+     chatbot = chatbot + [[user_message, None]]
+
+     if '[identify]' in user_message:
+         visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+         if visual_img is not None:
+             file_path = save_tmp_img(visual_img)
+             chatbot = chatbot + [[(file_path,), None]]
+
+     return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
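+ # Non-streaming answer: generate the full reply in one call and place it in the last chat turn.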
+ def gradio_answer(chatbot, chat_state, img_list, temperature):
+     print("--gradio_answer--")
+     # print('img_list: ', img_list)
+     llm_message = chat.answer(conv=chat_state,
+                               img_list=img_list,
+                               temperature=temperature,
+                               max_new_tokens=500,
+                               max_length=2000)[0]
+     chatbot[-1][1] = llm_message
+     print('gradio_answer: ', llm_message)
+
+     return chatbot, chat_state
+
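+ # Light post-processing of generated English text: capitalize sentences and normalize the trailing punctuation.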
+ def process_english_text(text):
+     if len(text) < 2:
+         return text
+     text = text[0].upper() + text[1:]
+
+     sentences = text.split('. ')
+     corrected_sentences = [s.capitalize() for s in sentences]
+     text = '. '.join(corrected_sentences)
+
+     if text.endswith(','):
+         text = text[:-1]
+     if not text.endswith('.'):
+         text += '.'
+
+     return text
+
+
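+ # Streaming answer: encode the uploaded video on first use, then yield partial replies so the chatbot updates as tokens arrive.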
+ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+     print('---gradio_stream_answer---')
+     if len(img_list) > 0:
+         if not isinstance(img_list[0], torch.Tensor):
+             chat.encode_img(img_list)
+     print(chat)
+     streamer = chat.stream_answer(conv=chat_state,
+                                   img_list=img_list,
+                                   temperature=temperature,
+                                   max_new_tokens=500,
+                                   max_length=2000)
+     output = ''
+     print('streamer:', streamer)
+     for new_output in streamer:
+         escaped = escape_markdown(new_output)
+         output += escaped
+         chatbot[-1][1] = output
+         chatbot[-1][1] = process_english_text(chatbot[-1][1])
+         yield chatbot, chat_state
+     chat_state.messages[-1][1] = '</s>'
+     print('output:', output)
+     return chatbot, chat_state
+
+
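+ # After an answer finishes, look for box strings in the last reply, draw them on the video frame, and append the annotated image to the chat.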
+ def gradio_visualize(chatbot, gr_img):
+     if isinstance(gr_img, dict):
+         gr_img, mask = gr_img['image'], gr_img['mask']
+
+     unescaped = reverse_escape(chatbot[-1][1])
+     visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+     if visual_img is not None:
+         if len(generation_color):
+             chatbot[-1][1] = generation_color
+         file_path = save_tmp_img(visual_img)
+         chatbot = chatbot + [[None, (file_path,)]]
+
+     return chatbot
+
+
+ def gradio_taskselect(idx):
+     prompt_list = [
+         '',
+         '[reason] ',
+         '[emotion] ',
+         '[visual] ',
+         '[audio] '
+     ]
+     instruct_list = [
+         '**Hint:** Type in whatever you want',
+         '**Hint:** Send the command for multimodal emotion reasoning',
+         '**Hint:** Send the command for multimodal emotion recognition',
+         '**Hint:** Send the command to generate a visual description',
+         '**Hint:** Send the command to generate an audio description'
+     ]
+     return prompt_list[idx], instruct_list[idx]
+
+
+
+
+
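+ # Instantiate the Chat wrapper around the loaded model and build the Gradio interface.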
+ chat = Chat(model, vis_processor, device=device)
+
+ title = """<h1 align="center">Emotion-LLaMA Demo</h1>"""
+ description = 'Welcome to our Emotion-LLaMA chatbot demo!'
+ article = """<p><a href='https://anonymous.4open.science/r/Emotion-LLaMA'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
+
+ introduction = '''
+ For Abilities Involving Multimodal Emotion Understanding:
+ 1. Reason: Click **Send** to generate a multimodal emotion description.
+ 2. Emotion: Click **Send** to generate an emotion label.
+ 3. Visual: Click **Send** to generate a visual description.
+ 4. Audio: Click **Send** to generate an audio description.
+ 5. No Tag: Input whatever you want and click **Send** without any tag.
+
+ You can also simply chat in free form!
+ '''
+
+ text_input = gr.Textbox(placeholder='Upload your video and chat', interactive=True, show_label=False, container=False, scale=8)
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     # gr.Markdown(description)
+     gr.Markdown(article)
+
+     with gr.Row():
+         with gr.Column(scale=0.5):
+             # image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+             image = gr.Video(sources=["upload", "webcam"])
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.5,
+                 value=0.2,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+
+             clear = gr.Button("Restart")
+
+             gr.Markdown(introduction)
+
+         with gr.Column():
+             chat_state = gr.State(value=None)
+             img_list = gr.State(value=[])
+             chatbot = gr.Chatbot(label='Emotion-LLaMA')
+
+             dataset = gr.Dataset(
+                 components=[gr.Textbox(visible=False)],
+                 samples=[['No Tag'], ['reason'], ['emotion'], ['visual'], ['audio']],
+                 type="index",
+                 label='Task Shortcuts',
+             )
+             task_inst = gr.Markdown('**Hint:** Upload your video and chat')
+             with gr.Row():
+                 text_input.render()
+                 send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+     upload_flag = gr.State(value=0)
+     replace_flag = gr.State(value=0)
+     image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+
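+     # Clickable examples: emotion-recognition prompts in the left column, reasoning / visual / audio prompts in the right column.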
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(examples=[
+                 ["examples/samplenew_00004251.mp4", "[detection] face", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00000338.mp4", "The person in video says: Oh no, my phone and wallet are all in my bag. [emotion] Please determine which emotion label in the video represents: happy, sad, neutral, angry, worried, surprise.", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00000669.mp4", "The person in video says: Why are you looking at me like this? It's just a woman, so you have to have something to do with me. [emotion] Determine the emotional state shown in the video, choosing from happy, sad, neutral, angry, worried, or surprise.", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00003462.mp4", "The person in video says: Do you believe that you push me around? [emotion] Assess and label the emotion evident in the video: could it be happy, sad, neutral, angry, worried, surprise?", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00000727.mp4", "The person in video says: No, this, I have to get up! You, I'm sorry, everyone. I'm sorry, it's from the German side. [emotion] Identify the displayed emotion in the video: is it happy, sad, neutral, angry, worried, or surprise?", upload_flag, replace_flag, img_list],
+                 ["examples/samplenew_00061200.mp4", "The person in video says: Me: I'm not going in anymore, scared. [emotion] Identify the displayed emotion in the video: is it happy, sad, neutral, angry, fear, contempt, doubt, worried, or surprise?", upload_flag, replace_flag, img_list],
+             ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+                 outputs=[upload_flag, replace_flag])
+         with gr.Column():
+             gr.Examples(examples=[
+                 ["examples/samplenew_00051251.mp4", "In what state is the person in the video, say the following: \"Do you really think so?\"", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00004735.mp4", "[visual] What are the emotions of the woman in the video?", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00002422.mp4", "[audio] Analyze the speaker's voice in the video.", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00001073.mp4", "The person in video says: Make him different from before. I like the way you are now. [reason] Please analyze all the clues in the video and reason out the emotional label of the person in the video.", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00004671.mp4", "The person in video says: Won't you? Impossible! Fan Xiaomei is not such a person. [reason] What are the facial expressions and vocal tone used in the video? What is the intended meaning behind his words? Which emotion does this reflect?", upload_flag, replace_flag, img_list],
+                 ["examples/sample_00005854.mp4", "The person in video says: Bastard! Boss, you don't choose, you prefer. [reason] Please integrate information from various modalities to infer the emotional category of the person in the video.", upload_flag, replace_flag, img_list],
+             ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+                 outputs=[upload_flag, replace_flag])
+
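+     # Event wiring: task shortcuts fill the textbox; submit/Send run ask -> stream answer -> visualize; Restart clears all state.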
+     dataset.click(
+         gradio_taskselect,
+         inputs=[dataset],
+         outputs=[text_input, task_inst],
+         show_progress="hidden",
+         postprocess=False,
+         queue=False,
+     )
+
+     text_input.submit(
+         gradio_ask,
+         [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+         [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+     ).success(
+         gradio_stream_answer,
+         [chatbot, chat_state, img_list, temperature],
+         [chatbot, chat_state]
+     ).success(
+         gradio_visualize,
+         [chatbot, image],
+         [chatbot],
+         queue=False,
+     )
+
+     send.click(
+         gradio_ask,
+         [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+         [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+     ).success(
+         gradio_stream_answer,
+         [chatbot, chat_state, img_list, temperature],
+         [chatbot, chat_state]
+     ).success(
+         gradio_visualize,
+         [chatbot, image],
+         [chatbot],
+         queue=False,
+     )
+
+     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

+ demo.launch(share=True, enable_queue=True)