Prince53 committed
Commit e8b3653 · verified · 1 Parent(s): d2d61b9

Upload 3 files

Files changed (3)
  1. app.py +35 -0
  2. geochat_demo.py +706 -0
  3. pyproject.toml +39 -0
app.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import gradio as gr
+ from torchvision import transforms
+ from PIL import Image
+
+ # Load model
+ class MyModel(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Define layers here
+
+     def forward(self, x):
+         # Forward pass
+         return x
+
+ model = MyModel()
+ model.load_state_dict(torch.load("model.pth"))
+ model.eval()
+
+ # Define image preprocessing
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+ ])
+
+ # Define prediction function
+ def predict(image):
+     image = transform(image).unsqueeze(0)  # Add batch dimension
+     with torch.no_grad():
+         output = model(image)
+     return output.numpy().tolist()
+
+ # Create Gradio interface
+ iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")  # receive a PIL image so the torchvision transforms apply directly
+ iface.launch()
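For a quick check of the preprocessing and inference path without the Gradio UI, a minimal sketch is to call predict() directly in the same session after the definitions above have been executed (and before iface.launch()); model.pth and test.jpg are placeholder names here, not files shipped with this commit.

from PIL import Image

img = Image.open("test.jpg").convert("RGB")  # placeholder local test image
print(predict(img))                          # nested list of raw model outputs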
geochat_demo.py ADDED
@@ -0,0 +1,706 @@
+ import argparse
+ import os
+ import random
+ from collections import defaultdict
+
+ import cv2
+ import re
+ import math
+ import numpy as np
+ from PIL import Image
+ import torch
+ import html
+ import gradio as gr
+
+ import torchvision.transforms as T
+ import torch.backends.cudnn as cudnn
+
+ from geochat.conversation import conv_templates, Chat
+ from geochat.model.builder import load_pretrained_model
+ from geochat.mm_utils import get_model_name_from_path
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     # parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--gpu-id", type=str, default=0)
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--max-new-tokens", type=int, default=300)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     # args = parser.parse_args()
+     args = parser.parse_args()
+     return args
+
+
+ random.seed(42)
+ np.random.seed(42)
+ torch.manual_seed(42)
+
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+ print('Initializing Chat')
+ args = parse_args()
+ # cfg = Config(args)
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+ device = 'cuda:{}'.format(args.gpu_id)
+
+ # model_config = cfg.model_cfg
+ # model_config.device_8bit = args.gpu_id
+ # model_cls = registry.get_model_class(model_config.arch)
+ # model = model_cls.from_config(model_config).to(device)
+ bounding_box_size = 100
+
+ # vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+ # vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+ model = model.eval()
+
+ CONV_VISION = conv_templates['llava_v1'].copy()
+
+ def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
+     # Calculate center coordinates
+     x_ctr = (x1 + x2) / 2
+     y_ctr = (y1 + y2) / 2
+
+     # Calculate width and height
+     w = abs(x2 - x1)
+     h = abs(y2 - y1)
+
+     # Calculate the angle in radians
+     angle_rad = math.radians(a)
+
+     # Calculate coordinates of the four corners of the rotated bounding box
+     cos_a = math.cos(angle_rad)
+     sin_a = math.sin(angle_rad)
+
+     x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
+     y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr
+
+     x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
+     y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr
+
+     x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
+     y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr
+
+     x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
+     y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr
+
+     # Return the polygon coordinates
+     polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))
+
+     return polygon_coords
+
+ def rotate_bbox(top_right, bottom_left, angle_degrees):
+     # Convert angle to radians
+     angle_radians = np.radians(angle_degrees)
+
+     # Calculate the center of the rectangle
+     center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)
+
+     # Calculate the width and height of the rectangle
+     width = top_right[0] - bottom_left[0]
+     height = top_right[1] - bottom_left[1]
+
+     # Create a rotation matrix
+     rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)
+
+     # Create an array of the rectangle corners
+     rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
+                                  [top_right[0], bottom_left[1]],
+                                  [top_right[0], top_right[1]],
+                                  [bottom_left[0], top_right[1]]], dtype=np.float32)
+
+     # Rotate the rectangle points
+     rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]
+
+     return rotated_rectangle
+ def extract_substrings(string):
+     # first check if there is an unfinished bracket
+     index = string.rfind('}')
+     if index != -1:
+         string = string[:index + 1]
+
+     pattern = r'<p>(.*?)\}(?!<)'
+     matches = re.findall(pattern, string)
+     substrings = [match for match in matches]
+
+     return substrings
+
+
+ def is_overlapping(rect1, rect2):
+     x1, y1, x2, y2 = rect1
+     x3, y3, x4, y4 = rect2
+     return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+ def computeIoU(bbox1, bbox2):
+     x1, y1, x2, y2 = bbox1
+     x3, y3, x4, y4 = bbox2
+     intersection_x1 = max(x1, x3)
+     intersection_y1 = max(y1, y3)
+     intersection_x2 = min(x2, x4)
+     intersection_y2 = min(y2, y4)
+     intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+     bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+     bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+     union_area = bbox1_area + bbox2_area - intersection_area
+     iou = intersection_area / union_area
+     return iou
+
+
+ def save_tmp_img(visual_img):
+     file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+     file_path = "/tmp/gradio" + file_name
+     visual_img.save(file_path)
+     return file_path
+
+
+ def mask2bbox(mask):
+     if mask is None:
+         return ''
+     mask = mask.resize([100, 100], resample=Image.NEAREST)
+     mask = np.array(mask)[:, :, 0]
+
+     rows = np.any(mask, axis=1)
+     cols = np.any(mask, axis=0)
+
+     if rows.sum():
+         # Get the top, bottom, left, and right boundaries
+         rmin, rmax = np.where(rows)[0][[0, -1]]
+         cmin, cmax = np.where(cols)[0][[0, -1]]
+         bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+     else:
+         bbox = ''
+
+     return bbox
+
+
+ def escape_markdown(text):
+     # List of Markdown special characters that need to be escaped
+     md_chars = ['<', '>']
+
+     # Escape each special character
+     for char in md_chars:
+         text = text.replace(char, '\\' + char)
+
+     return text
+
+
+ def reverse_escape(text):
+     md_chars = ['\\<', '\\>']
+
+     for char in md_chars:
+         text = text.replace(char, char[1:])
+
+     return text
+
+
+ colors = [
+     (255, 0, 0),
+     (0, 255, 0),
+     (0, 0, 255),
+     (210, 210, 0),
+     (255, 0, 255),
+     (0, 255, 255),
+     (114, 128, 250),
+     (0, 165, 255),
+     (0, 128, 0),
+     (144, 238, 144),
+     (238, 238, 175),
+     (255, 191, 0),
+     (0, 128, 0),
+     (226, 43, 138),
+     (255, 0, 255),
+     (0, 215, 255),
+ ]
+
+ color_map = {
+     f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+     color_id, color in enumerate(colors)
+ }
+
+ used_colors = colors
+
+
+ def visualize_all_bbox_together(image, generation):
+     if image is None:
+         return None, ''
+
+     generation = html.unescape(generation)
+
+     image_width, image_height = image.size
+     image = image.resize([500, int(500 / image_width * image_height)])
+     image_width, image_height = image.size
+
+     string_list = extract_substrings(generation)
+     if string_list:  # it is grounding or detection
+         mode = 'all'
+         entities = defaultdict(list)
+         i = 0
+         j = 0
+         for string in string_list:
+             try:
+                 obj, string = string.split('</p>')
+             except ValueError:
+                 print('wrong string: ', string)
+                 continue
+             if "}{" in string:
+                 string = string.replace("}{", "}<delim>{")
+             bbox_list = string.split('<delim>')
+             flag = False
+             for bbox_string in bbox_list:
+                 integers = re.findall(r'-?\d+', bbox_string)
+                 if len(integers) == 4:
+                     angle = 0
+                 else:
+                     angle = integers[4]
+                     integers = integers[:-1]
+
+                 if len(integers) == 4:
+                     x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+                     left = x0 / bounding_box_size * image_width
+                     bottom = y0 / bounding_box_size * image_height
+                     right = x1 / bounding_box_size * image_width
+                     top = y1 / bounding_box_size * image_height
+
+                     entities[obj].append([left, bottom, right, top, angle])
+
+                     j += 1
+                     flag = True
+             if flag:
+                 i += 1
+     else:
+         integers = re.findall(r'-?\d+', generation)
+         # if len(integers)==4:
+         angle = 0
+         # else:
+         #     angle=integers[4]
+         integers = integers[:-1]
+         if len(integers) == 4:  # it is refer
+             mode = 'single'
+
+             entities = list()
+             x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+             left = x0 / bounding_box_size * image_width
+             bottom = y0 / bounding_box_size * image_height
+             right = x1 / bounding_box_size * image_width
+             top = y1 / bounding_box_size * image_height
+             entities.append([left, bottom, right, top, angle])
+         else:
+             # don't detect any valid bbox to visualize
+             return None, ''
+
+     if len(entities) == 0:
+         return None, ''
+
+     if isinstance(image, Image.Image):
+         image_h = image.height
+         image_w = image.width
+         image = np.array(image)
+
+     elif isinstance(image, str):
+         if os.path.exists(image):
+             pil_img = Image.open(image).convert("RGB")
+             image = np.array(pil_img)[:, :, [2, 1, 0]]
+             image_h = pil_img.height
+             image_w = pil_img.width
+         else:
+             raise ValueError(f"invalid image path, {image}")
+     elif isinstance(image, torch.Tensor):
+
+         image_tensor = image.cpu()
+         reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+         reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+         image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+         pil_img = T.ToPILImage()(image_tensor)
+         image_h = pil_img.height
+         image_w = pil_img.width
+         image = np.array(pil_img)[:, :, [2, 1, 0]]
+     else:
+         raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+     indices = list(range(len(entities)))
+
+     new_image = image.copy()
+
+     previous_bboxes = []
+     # size of text
+     text_size = 0.4
+     # thickness of text
+     text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
+     box_line = 2
+     (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+     base_height = int(text_height * 0.675)
+     text_offset_original = text_height - base_height
+     text_spaces = 2
+
+     # num_bboxes = sum(len(x[-1]) for x in entities)
+     used_colors = colors  # random.sample(colors, k=num_bboxes)
+
+     color_id = -1
+     for entity_idx, entity_name in enumerate(entities):
+         if mode == 'single' or mode == 'identify':
+             bboxes = entity_name
+             bboxes = [bboxes]
+         else:
+             bboxes = entities[entity_name]
+         color_id += 1
+         for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm, angle) in enumerate(bboxes):
+             skip_flag = False
+             orig_x1, orig_y1, orig_x2, orig_y2, angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)
+
+             color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
+             top_right = (orig_x1, orig_y1)
+             bottom_left = (orig_x2, orig_y2)
+             angle = angle
+             rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
+             new_image = cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True, thickness=2, color=color)
+
+             # new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+             if mode == 'all':
+                 l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+                 x1 = orig_x1 - l_o
+                 y1 = orig_y1 - l_o
+
+                 if y1 < text_height + text_offset_original + 2 * text_spaces:
+                     y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+                     x1 = orig_x1 + r_o
+
+                 # add text background
+                 (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+                                                                text_line)
+                 text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+                         text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+                 for prev_bbox in previous_bboxes:
+                     if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+                             prev_bbox['phrase'] == entity_name:
+                         skip_flag = True
+                         break
+                     while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+                         text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+                         text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+                         y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+                         if text_bg_y2 >= image_h:
+                             text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+                             text_bg_y2 = image_h
+                             y1 = image_h
+                             break
+                 if not skip_flag:
+                     alpha = 0.5
+                     for i in range(text_bg_y1, text_bg_y2):
+                         for j in range(text_bg_x1, text_bg_x2):
+                             if i < image_h and j < image_w:
+                                 if j < text_bg_x1 + 1.35 * c_width:
+                                     # original color
+                                     bg_color = color
+                                 else:
+                                     # white
+                                     bg_color = [255, 255, 255]
+                                 new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+                                     np.uint8)
+
+                     cv2.putText(
+                         new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+                         cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+                     )
+
+                     previous_bboxes.append(
+                         {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+     if mode == 'all':
+         def color_iterator(colors):
+             while True:
+                 for color in colors:
+                     yield color
+
+         color_gen = color_iterator(colors)
+
+         # Add colors to phrases and remove <p></p>
+         def colored_phrases(match):
+             phrase = match.group(1)
+             color = next(color_gen)
+             return f'<span style="color:rgb{color}">{phrase}</span>'
+
+         generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+         generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
+     else:
+         generation_colored = ''
+
+     pil_image = Image.fromarray(new_image)
+     return pil_image, generation_colored
+
+
+ def gradio_reset(chat_state, img_list):
+     if chat_state is not None:
+         chat_state.messages = []
+     if img_list is not None:
+         img_list = []
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
+                                                                     interactive=True), chat_state, img_list
+
+
+ def image_upload_trigger(upload_flag, replace_flag, img_list):
+     # set the upload flag to true when receiving a new image.
+     # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+     upload_flag = 1
+     if img_list:
+         replace_flag = 1
+     return upload_flag, replace_flag
+
+
+ def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+     # set the upload flag to true when receiving a new image.
+     # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+     upload_flag = 1
+     if img_list or replace_flag == 1:
+         replace_flag = 1
+
+     return upload_flag, replace_flag
+
+
+ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+     if len(user_message) == 0:
+         text_box_show = 'Input should not be empty!'
+     else:
+         text_box_show = ''
+
+     if isinstance(gr_img, dict):
+         gr_img, mask = gr_img['image'], gr_img['mask']
+     else:
+         mask = None
+
+     if '[identify]' in user_message:
+         # check if the user provided a bbox in the text input
+         integers = re.findall(r'-?\d+', user_message)
+         if len(integers) != 4:  # no bbox in text
+             bbox = mask2bbox(mask)
+             user_message = user_message + bbox
+
+     if chat_state is None:
+         chat_state = CONV_VISION.copy()
+
+     if upload_flag:
+         if replace_flag:
+             chat_state = CONV_VISION.copy()  # new image, reset everything
+             replace_flag = 0
+             chatbot = []
+             img_list = []
+         llm_message = chat.upload_img(gr_img, chat_state, img_list)
+         upload_flag = 0
+
+     chat.ask(user_message, chat_state)
+
+     chatbot = chatbot + [[user_message, None]]
+
+     if '[identify]' in user_message:
+         visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+         if visual_img is not None:
+             file_path = save_tmp_img(visual_img)
+             chatbot = chatbot + [[(file_path,), None]]
+
+     return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
+ # def gradio_answer(chatbot, chat_state, img_list, temperature):
+ #     llm_message = chat.answer(conv=chat_state,
+ #                               img_list=img_list,
+ #                               temperature=temperature,
+ #                               max_new_tokens=500,
+ #                               max_length=2000)[0]
+ #     chatbot[-1][1] = llm_message
+ #     return chatbot, chat_state
+
+
+ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+     if len(img_list) > 0:
+         if not isinstance(img_list[0], torch.Tensor):
+             chat.encode_img(img_list)
+     streamer = chat.stream_answer(conv=chat_state,
+                                   img_list=img_list,
+                                   temperature=temperature,
+                                   max_new_tokens=500,
+                                   max_length=2000)
+     # chatbot[-1][1] = output
+     # chat_state.messages[-1][1] = '</s>'
+
+     output = ''
+     for new_output in streamer:
+         # print(new_output)
+         output = output + new_output
+         print(output)
+         # if "{" in output:
+         #     chatbot[-1][1] = "Grounding and referring expression is still under work."
+         # else:
+         output = escape_markdown(output)
+         # output += escapped
+         chatbot[-1][1] = output
+         yield chatbot, chat_state
+     chat_state.messages[-1][1] = '</s>'
+     return chatbot, chat_state
+
+
+ def gradio_visualize(chatbot, gr_img):
+     if isinstance(gr_img, dict):
+         gr_img, mask = gr_img['image'], gr_img['mask']
+
+     unescaped = reverse_escape(chatbot[-1][1])
+     visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+     if visual_img is not None:
+         if len(generation_color):
+             chatbot[-1][1] = generation_color
+         file_path = save_tmp_img(visual_img)
+         chatbot = chatbot + [[None, (file_path,)]]
+
+     return chatbot
+
+
+ def gradio_taskselect(idx):
+     prompt_list = [
+         '',
+         'Classify the image in the following classes: ',
+         '[identify] what is this ',
+     ]
+     instruct_list = [
+         '**Hint:** Type in whatever you want',
+         '**Hint:** Type in the classes you want the model to classify in',
+         '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" button on the top right of the image before redrawing',
+     ]
+     return prompt_list[idx], instruct_list[idx]
+
+
+
+
+ chat = Chat(model, image_processor, tokenizer, device=device)
+
+
+ title = """<h1 align="center">GeoChat Demo</h1>"""
+ description = 'Welcome to Our GeoChat Chatbot Demo!'
+ article = """<div style="display: flex;"><p style="display: inline-block;"><a href='https://mbzuai-oryx.github.io/GeoChat'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p style="display: inline-block;"><a href='https://arxiv.org/abs/2311.15826'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p style="display: inline-block;"><a href='https://github.com/mbzuai-oryx/GeoChat/tree/main'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p style="display: inline-block;"><a href='https://youtu.be/KOKtkkKpNDk?feature=shared'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p></div>"""
+ # article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
+
+ introduction = '''
+ 1. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
+ 2. No Tag: Input whatever you want and CLICK **Send** without any tagging
+
+ You can also simply chat in free form!
+ '''
+
+
+ text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
+                         scale=12)
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     # gr.Markdown(description)
+     gr.Markdown(article)
+
+     with gr.Row():
+         with gr.Column(scale=0.5):
+             image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.5,
+                 value=0.6,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+
+             clear = gr.Button("Restart")
+
+             gr.Markdown(introduction)
+
+         with gr.Column():
+             chat_state = gr.State(value=None)
+             img_list = gr.State(value=[])
+             chatbot = gr.Chatbot(label='GeoChat')
+
+             dataset = gr.Dataset(
+                 components=[gr.Textbox(visible=False)],
+                 samples=[['No Tag'], ['Scene Classification'], ['Identify']],
+                 type="index",
+                 label='Task Shortcuts',
+             )
+             task_inst = gr.Markdown('**Hint:** Upload your image and chat')
+             with gr.Row():
+                 text_input.render()
+                 send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+     upload_flag = gr.State(value=0)
+     replace_flag = gr.State(value=0)
+     image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(examples=[
+                 ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?", upload_flag, replace_flag,
+                  img_list],
+                 ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
+                  replace_flag, img_list],
+             ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+                 outputs=[upload_flag, replace_flag])
+         with gr.Column():
+             gr.Examples(examples=[
+                 ["demo_images/church_183.png", "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
+                  upload_flag, replace_flag, img_list],
+                 ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
+                  replace_flag, img_list],
+             ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+                 outputs=[upload_flag, replace_flag])
+
+     dataset.click(
+         gradio_taskselect,
+         inputs=[dataset],
+         outputs=[text_input, task_inst],
+         show_progress="hidden",
+         postprocess=False,
+         queue=False,
+     )
+
+     text_input.submit(
+         gradio_ask,
+         [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+         [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+     ).success(
+         gradio_stream_answer,
+         [chatbot, chat_state, img_list, temperature],
+         [chatbot, chat_state]
+     ).success(
+         gradio_visualize,
+         [chatbot, image],
+         [chatbot],
+         queue=False,
+     )
+
+     send.click(
+         gradio_ask,
+         [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+         [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+     ).success(
+         gradio_stream_answer,
+         [chatbot, chat_state, img_list, temperature],
+         [chatbot, chat_state]
+     ).success(
+         gradio_visualize,
+         [chatbot, image],
+         [chatbot],
+         queue=False,
+     )
+
+     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
+
+ demo.launch(share=True, enable_queue=True, server_name='0.0.0.0')
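The textual boxes that geochat_demo.py parses and draws, such as {<8><26><22><37>} in the identify example above, are expressed on a 0-100 grid (bounding_box_size = 100) and are scaled against the image after it has been resized to a 500-pixel width. A minimal, self-contained sketch of that scaling, assuming a hypothetical 800x600 input image:

bounding_box_size = 100                          # grid used by the {<x1><y1><x2><y2>} strings
orig_w, orig_h = 800, 600                        # assumed input size, for illustration only
new_w, new_h = 500, int(500 / orig_w * orig_h)   # the demo resizes to width 500 -> 500x375

x0, y0, x1, y1 = 8, 26, 22, 37                   # from "{<8><26><22><37>}"
left = x0 / bounding_box_size * new_w            # 40.0 px
upper = y0 / bounding_box_size * new_h           # 97.5 px
right = x1 / bounding_box_size * new_w           # 110.0 px
lower = y1 / bounding_box_size * new_h           # 138.75 px
print(left, upper, right, lower)                 # pixel-space corners the demo draws (plus an optional rotation angle)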
pyproject.toml ADDED
@@ -0,0 +1,39 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "geochat"
+ version = "1.1.1"
+ description = "Grounded VLM for Remote Sensing"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: Apache Software License",
+ ]
+ dependencies = [
+     "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
+     "requests", "sentencepiece", "tokenizers>=0.12.1",
+     "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
+     "shortuuid", "httpx==0.24.0",
+     "deepspeed==0.9.5",
+     "peft==0.4.0",
+     "transformers==4.31.0",
+     "accelerate==0.21.0",
+     "bitsandbytes==0.41.0",
+     "scikit-learn==1.2.2",
+     "sentencepiece==0.1.99",
+     "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+     "gradio_client==0.2.9"
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/mbzuai-oryx/GeoChat"
+ "Bug Tracker" = "https://github.com/mbzuai-oryx/GeoChat/issues"
+
+ [tool.setuptools.packages.find]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+ [tool.wheel]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
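Assuming this file sits at the root of a GeoChat checkout, the pinned environment can typically be set up with an editable install, i.e. running pip install -e . under Python 3.8 or newer. Note that the demo's sketch-based image input relies on the pinned gradio==3.35.2 (the tool='sketch' argument used in geochat_demo.py is not available in Gradio 4.x).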