Prince53 committed (verified)
Commit 5b4b7ad · Parent(s): bdb96de

Delete geochat_demo.py

Files changed (1):
  1. geochat_demo.py +0 -706
geochat_demo.py DELETED
@@ -1,706 +0,0 @@
import argparse
import os
import random
from collections import defaultdict

import cv2
import re
import math
import numpy as np
from PIL import Image
import torch
import html
import gradio as gr

import torchvision.transforms as T
import torch.backends.cudnn as cudnn

from geochat.conversation import conv_templates, Chat
from geochat.model.builder import load_pretrained_model
from geochat.mm_utils import get_model_name_from_path


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu-id", type=str, default="0")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=300)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    return args
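# Illustrative launch command; the model path below is a placeholder — point it
# at your local GeoChat checkpoint or its Hugging Face repo id:
#   python geochat_demo.py --model-path <path-or-repo-id> --gpu-id 0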

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
args = parse_args()

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

device = 'cuda:{}'.format(args.gpu_id)

bounding_box_size = 100

model = model.eval()

CONV_VISION = conv_templates['llava_v1'].copy()


def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
    # Calculate center coordinates
    x_ctr = (x1 + x2) / 2
    y_ctr = (y1 + y2) / 2

    # Calculate width and height
    w = abs(x2 - x1)
    h = abs(y2 - y1)

    # Convert the angle to radians
    angle_rad = math.radians(a)

    # Calculate coordinates of the four corners of the rotated bounding box
    cos_a = math.cos(angle_rad)
    sin_a = math.sin(angle_rad)

    x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
    y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr

    x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
    y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr

    x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
    y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr

    x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
    y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr

    # Return the polygon coordinates
    polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))

    return polygon_coords

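# The corner math above is the standard 2-D rotation about the box centre:
#   [x']   [cos a  -sin a] [x - xc]   [xc]
#   [y'] = [sin a   cos a] [y - yc] + [yc]
# rotate_bbox below produces the same four corners via OpenCV's affine helper
# instead of writing the matrix out by hand.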

def rotate_bbox(top_right, bottom_left, angle_degrees):
    # Calculate the center of the rectangle
    center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)

    # Create a rotation matrix about the center (scale 1)
    rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)

    # Create an array of the rectangle corners
    rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
                                 [top_right[0], bottom_left[1]],
                                 [top_right[0], top_right[1]],
                                 [bottom_left[0], top_right[1]]], dtype=np.float32)

    # Rotate the rectangle points
    rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]

    return rotated_rectangle


def extract_substrings(string):
    # First trim anything after the last closing brace (an unfinished bbox)
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    # Grab everything from each '<p>' up to a '}' that is not followed by '<'
    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings

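# For a grounded answer such as '<p>pool</p>{<3><5><20><30>}' the pattern yields
# 'pool</p>{<3><5><20><30>'; the visualiser later splits this on '</p>' into the
# phrase and its bbox string.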

def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou

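# Note the inclusive-pixel (+1) convention: identical boxes give IoU 1.0, e.g.
# computeIoU((0, 0, 9, 9), (0, 0, 9, 9)) == 1.0, while disjoint boxes give 0.0.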

def save_tmp_img(visual_img):
    # Write the rendered image to a throwaway file that the chatbot can display.
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    os.makedirs("/tmp/gradio", exist_ok=True)
    file_path = os.path.join("/tmp/gradio", file_name)
    visual_img.save(file_path)
    return file_path


def mask2bbox(mask):
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox

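# A sketch drawn over the image thus becomes a bbox string in the model's 0-100
# normalised coordinates, e.g. '{<8><26><22><37>}', matching the [identify]
# example further down.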

def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text

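# escape_markdown and reverse_escape must remain exact inverses: the chat
# history stores the escaped text, and gradio_visualize un-escapes it before
# parsing bbox strings out of the answer.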

colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]

# Channel order is reversed when building the hex map (OpenCV-style BGR -> web RGB)
color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}"
    for color_id, color in enumerate(colors)
}

used_colors = colors


def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)

    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            if "}{" in string:
                string = string.replace("}{", "}<delim>{")
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    angle = 0
                else:
                    # a fifth integer is the rotation angle of an oriented box
                    angle = integers[4]
                    integers = integers[:-1]

                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top, angle])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        # No '<p>...</p>' phrases: the whole generation may be a single bbox (refer mode)
        integers = re.findall(r'-?\d+', generation)
        if len(integers) > 4:
            angle = integers[4]
            integers = integers[:4]
        else:
            angle = 0
        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top, angle])
        else:
            # didn't detect any valid bbox to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invalid image path, {image}")
    elif isinstance(image, torch.Tensor):
        # Undo CLIP-style normalisation before converting back to PIL
        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.4
    # thickness of text
    text_line = 1
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    used_colors = colors

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm, angle) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2, angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)

            color = used_colors[entity_idx % len(used_colors)]
            top_right = (orig_x1, orig_y1)
            bottom_left = (orig_x2, orig_y2)
            rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
            new_image = cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True, thickness=2, color=color)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        # slide the label down until it stops colliding with earlier labels
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    # alpha-blend the label background onto the image
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored

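# Illustrative call, assuming `img` is a PIL image:
#   visualize_all_bbox_together(img, '<p>plane</p>{<10><20><40><50>}')
# returns (the image with one labelled box, an HTML-colourised copy of the text).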

def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # Set the upload flag when a new image arrives; if an old image (and an old
    # conversation) exists, set the replace flag so the conversation is reset later.
    upload_flag = 1
    if img_list:
        replace_flag = 1
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # Same flag logic as image_upload_trigger, fired when an example is selected.
    upload_flag = 1
    if img_list or replace_flag == 1:
        replace_flag = 1

    return upload_flag, replace_flag


def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if len(user_message) == 0:
        text_box_show = 'Input should not be empty!'
    else:
        text_box_show = ''

    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if the user provided a bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in the text, so derive one from the sketch mask
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    if upload_flag:
        if replace_flag:
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
            img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    if len(img_list) > 0:
        if not isinstance(img_list[0], torch.Tensor):
            chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
                                  img_list=img_list,
                                  temperature=temperature,
                                  max_new_tokens=500,
                                  max_length=2000)

    output = ''
    for new_output in streamer:
        # accumulate the raw text, but escape it only for display so that
        # repeated passes do not double-escape earlier chunks
        output = output + new_output
        chatbot[-1][1] = escape_markdown(output)
        yield chatbot, chat_state
    chat_state.messages[-1][1] = '</s>'
    return chatbot, chat_state

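# Because this handler is a generator, Gradio re-renders the chatbot after every
# yield, which is what produces the token-by-token streaming effect in the UI.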

def gradio_visualize(chatbot, gr_img):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']

    unescaped = reverse_escape(chatbot[-1][1])
    visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
    if visual_img is not None:
        if len(generation_color):
            chatbot[-1][1] = generation_color
        file_path = save_tmp_img(visual_img)
        chatbot = chatbot + [[None, (file_path,)]]

    return chatbot


def gradio_taskselect(idx):
    prompt_list = [
        '',
        'Classify the image in the following classes: ',
        '[identify] what is this ',
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Type in the classes you want the model to choose from',
        '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button in the top right of the image before redrawing.',
    ]
    return prompt_list[idx], instruct_list[idx]


chat = Chat(model, image_processor, tokenizer, device=device)


title = """<h1 align="center">GeoChat Demo</h1>"""
description = 'Welcome to our GeoChat chatbot demo!'
article = """<div style="display: flex;"><p style="display: inline-block;"><a href='https://mbzuai-oryx.github.io/GeoChat'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p style="display: inline-block;"><a href='https://arxiv.org/abs/2311.15826'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p style="display: inline-block;"><a href='https://github.com/mbzuai-oryx/GeoChat/tree/main'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p style="display: inline-block;"><a href='https://youtu.be/KOKtkkKpNDk?feature=shared'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p></div>"""

introduction = '''
1. Identify: Draw a bounding box in the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK the "clear" button before re-drawing next time.)
2. No Tag: Input whatever you want and CLICK **Send** without any tagging.

You can also simply chat in free form!
'''


text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=12)

with gr.Blocks() as demo:
    gr.Markdown(title)
    # gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(introduction)

        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot(label='GeoChat')

            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['No Tag'], ['Scene Classification'], ['Identify']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    with gr.Row():
        with gr.Column():
            gr.Examples(examples=[
                ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?",
                 upload_flag, replace_flag, img_list],
                ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])
        with gr.Column():
            gr.Examples(examples=[
                ["demo_images/church_183.png",
                 "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
                 upload_flag, replace_flag, img_list],
                ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])

    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

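    # Both the textbox's submit event and the Send button run the same
    # three-step chain: gradio_ask posts the question, gradio_stream_answer
    # streams the reply, and gradio_visualize draws any predicted boxes.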
673
-
674
- text_input.submit(
675
- gradio_ask,
676
- [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
677
- [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
678
- ).success(
679
- gradio_stream_answer,
680
- [chatbot, chat_state, img_list, temperature],
681
- [chatbot, chat_state]
682
- ).success(
683
- gradio_visualize,
684
- [chatbot, image],
685
- [chatbot],
686
- queue=False,
687
- )
688
-
689
- send.click(
690
- gradio_ask,
691
- [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
692
- [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
693
- ).success(
694
- gradio_stream_answer,
695
- [chatbot, chat_state, img_list, temperature],
696
- [chatbot, chat_state]
697
- ).success(
698
- gradio_visualize,
699
- [chatbot, image],
700
- [chatbot],
701
- queue=False,
702
- )
703
-
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True, server_name='0.0.0.0')
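# share=True requests a temporary public gradio.live URL and server_name='0.0.0.0'
# binds all interfaces; enable_queue is a Gradio 3.x-era flag (queueing moved to
# demo.queue() in Gradio 4), consistent with the 3.x gr.Image tool='sketch' API.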