wenhu committed
Commit d8b1b85 · verified · 1 Parent(s): 3e0e0e0

Update app_test.py

Files changed (1): app_test.py (+515, -6)
app_test.py CHANGED
@@ -63,10 +63,399 @@ external_log_dir = "./logs"
  LOGDIR = external_log_dir
  VOTEDIR = "./votes"

+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+     return name
+
+ def get_conv_vote_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
+     if not os.path.isfile(name):
+         os.makedirs(os.path.dirname(name), exist_ok=True)
+     return name
+
+ def vote_last_response(state, vote_type, model_selector):
+     with open(get_conv_vote_filename(), "a") as fout:
+         data = {
+             "type": vote_type,
+             "model": model_selector,
+             "state": state,
+         }
+         fout.write(json.dumps(data) + "\n")
+     api.upload_file(
+         path_or_fileobj=get_conv_vote_filename(),
+         path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
+         repo_id=repo_name,
+         repo_type="dataset")
+
+
+ def upvote_last_response(state):
+     vote_last_response(state, "upvote", "MAmmoTH-VL2")
+     gr.Info("Thank you for voting!")
+     return state
+
+ def downvote_last_response(state):
+     vote_last_response(state, "downvote", "MAmmoTH-VL2")
+     gr.Info("Thank you for voting!")
+     return state
+
+ class InferenceDemo(object):
+     def __init__(
+         self, args, model_path, tokenizer, model, image_processor, context_len
+     ) -> None:
+         disable_torch_init()
+
+         self.tokenizer, self.model, self.image_processor, self.context_len = (
+             tokenizer,
+             model,
+             image_processor,
+             context_len,
+         )
+
+         if "llama-2" in model_name.lower():
+             conv_mode = "llava_llama_2"
+         elif "v1" in model_name.lower():
+             conv_mode = "llava_v1"
+         elif "mpt" in model_name.lower():
+             conv_mode = "mpt"
+         elif "qwen" in model_name.lower():
+             conv_mode = "qwen_1_5"
+         elif "pangea" in model_name.lower():
+             conv_mode = "qwen_1_5"
+         elif "mammoth-vl" in model_name.lower():
+             conv_mode = "qwen_2_5"
+         else:
+             conv_mode = "llava_v0"
+
+         if args.conv_mode is not None and conv_mode != args.conv_mode:
+             print(
+                 "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}; using {}".format(
+                     conv_mode, args.conv_mode, args.conv_mode
+                 )
+             )
+         else:
+             args.conv_mode = conv_mode
+         # Keep conv_mode consistent with the template actually used below.
+         self.conv_mode = args.conv_mode
+         self.conversation = conv_templates[args.conv_mode].copy()
+         self.num_frames = args.num_frames
+
+ class ChatSessionManager:
+     def __init__(self):
+         self.chatbot_instance = None
+
+     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
+         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
+
+     def reset_chatbot(self):
+         self.chatbot_instance = None
+
+     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         if self.chatbot_instance is None:
+             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+         return self.chatbot_instance
+
+
+ def is_valid_video_filename(name):
+     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
+     ext = name.split(".")[-1].lower()
+     return ext in video_extensions
+
+ def is_valid_image_filename(name):
+     image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
+     ext = name.split(".")[-1].lower()
+     return ext in image_extensions
+
+ def sample_frames_v1(video_file, num_frames):
+     video = cv2.VideoCapture(video_file)
+     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+     # Guard against num_frames > total_frames, which would make the
+     # interval 0 and crash the modulo below.
+     interval = max(total_frames // num_frames, 1)
+     frames = []
+     for i in range(total_frames):
+         ret, frame = video.read()
+         # Check the read succeeded before converting; a failed read returns None.
+         if not ret:
+             continue
+         if i % interval == 0:
+             frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+     video.release()
+     return frames
+
+ def sample_frames_v2(video_path, frame_count=32):
+     video_frames = []
+     vr = VideoReader(video_path, ctx=cpu(0))
+     total_frames = len(vr)
+     frame_interval = max(total_frames // frame_count, 1)
+
+     for i in range(0, total_frames, frame_interval):
+         frame = vr[i].asnumpy()
+         frame_image = Image.fromarray(frame)  # Convert to PIL.Image
+         video_frames.append(frame_image)
+         if len(video_frames) >= frame_count:
+             break
+
+     # Ensure at least one frame is returned if total frames are less than required
+     if len(video_frames) < frame_count and total_frames > 0:
+         for i in range(total_frames):
+             frame = vr[i].asnumpy()
+             frame_image = Image.fromarray(frame)  # Convert to PIL.Image
+             video_frames.append(frame_image)
+             if len(video_frames) >= frame_count:
+                 break
+
+     return video_frames
+
+ def sample_frames(video_path, num_frames=8):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+     for i in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         ret, frame = cap.read()
+         if ret:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(Image.fromarray(frame))
+
+     cap.release()
+     return frames
+
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         if response.status_code == 200:
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         else:
+             # Fail loudly; returning here would leave `image` unbound.
+             raise ValueError(f"Failed to load the image from {image_file}")
+     else:
+         print("Load image from local file", image_file)
+         image = Image.open(image_file).convert("RGB")
+
+     return image
+
+
+ def clear_response(history):
+     # Walk backwards until the last turn that holds a text message.
+     for index_conv in range(1, len(history)):
+         conv = history[-index_conv]
+         if conv[0] is not None:
+             break
+     question = history[-index_conv][0]
+     history = history[:-index_conv]
+     return history, question
+
+ chat_manager = ChatSessionManager()
+
+
+ def clear_history(history):
+     chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
+     return None
+
+
+ def add_message(history, message):
+     global chat_image_num
+     print("#### len(history)", len(history))
+     if not history:
+         history = []
+         our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+         chat_image_num = 0
+     for x in message["files"]:
+         # The video example is listed with a .jpg thumbnail; map it back to the real .mp4.
+         if "realcase_video.jpg" in x:
+             x = x.replace("realcase_video.jpg", "realcase_video.mp4")
+         history.append(((x,), None))
+     if message["text"] is not None:
+         history.append((message["text"], None))
+     # print(f"### Chatbot instance ID: {id(our_chatbot)}")
+     return history, gr.MultimodalTextbox(value=None, interactive=False)
+

  @spaces.GPU
- def bot():
-     print(f"### Chatbot instance ID")
+ def bot(history, temperature, top_p, max_output_tokens):
+     our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     print(f"### Chatbot instance ID: {id(our_chatbot)}")
+     text = history[-1][0]
+     images_this_term = []
+     text_this_term = ""
+
+     is_video = False
+     num_new_images = 0
+     # previous_image = False
+     for i, message in enumerate(history[:-1]):
+         if type(message[0]) is tuple:
+             images_this_term.append(message[0][0])
+             if is_valid_video_filename(message[0][0]):
+                 num_new_images += 1
+                 is_video = True
+             elif is_valid_image_filename(message[0][0]):
+                 print("#### Load image from local file", message[0][0])
+                 num_new_images += 1
+             else:
+                 raise ValueError("Invalid file format")
+         else:
+             num_new_images = 0
+
+     image_list = []
+     for f in images_this_term:
+         if is_valid_video_filename(f):
+             image_list += sample_frames(f, our_chatbot.num_frames)
+         elif is_valid_image_filename(f):
+             image_list.append(load_image(f))
+         else:
+             raise ValueError("Invalid image file")
+
+     all_image_hash = []
+     all_image_path = []
+     for file_path in images_this_term:
+         with open(file_path, "rb") as file:
+             file_data = file.read()
+             file_hash = hashlib.md5(file_data).hexdigest()
+             all_image_hash.append(file_hash)
+
+         t = datetime.datetime.now()
+         output_dir = os.path.join(
+             LOGDIR,
+             "serve_files",
+             f"{t.year}-{t.month:02d}-{t.day:02d}"
+         )
+         os.makedirs(output_dir, exist_ok=True)
+
+         if is_valid_image_filename(file_path):
+             # Process and save images
+             image = Image.open(file_path).convert("RGB")
+             filename = os.path.join(output_dir, f"{file_hash}.jpg")
+             all_image_path.append(filename)
+             if not os.path.isfile(filename):
+                 print("Image saved to", filename)
+                 image.save(filename)
+         elif is_valid_video_filename(file_path):
+             # Simplified video saving
+             filename = os.path.join(output_dir, f"{file_hash}.mp4")
+             all_image_path.append(filename)
+             if not os.path.isfile(filename):
+                 print("Video saved to", filename)
+                 os.makedirs(os.path.dirname(filename), exist_ok=True)
+                 # Directly copy the video file
+                 with open(file_path, "rb") as src, open(filename, "wb") as dst:
+                     dst.write(src.read())
+
+     image_tensor = []
+     if is_video:
+         image_tensor = our_chatbot.image_processor.preprocess(image_list, return_tensors="pt")["pixel_values"].half().to(our_chatbot.model.device)
+     elif num_new_images > 0:
+         image_tensor = [
+             our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
+             .half()
+             .to(our_chatbot.model.device)
+             for f in image_list
+         ]
+         image_tensor = torch.stack(image_tensor)
+
+     image_token = DEFAULT_IMAGE_TOKEN * num_new_images + "\n"
+
+     inp = text
+     inp = image_token + inp
+     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
+     # image = None
+     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
+     prompt = our_chatbot.conversation.get_prompt()
+
+     input_ids = tokenizer_image_token(
+         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+     ).unsqueeze(0).to(our_chatbot.model.device)
+     # print("### input_id", input_ids)
+     stop_str = (
+         our_chatbot.conversation.sep
+         if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
+         else our_chatbot.conversation.sep2
+     )
+     keywords = [stop_str]
+     stopping_criteria = KeywordsStoppingCriteria(
+         keywords, our_chatbot.tokenizer, input_ids
+     )
+
+     streamer = TextIteratorStreamer(
+         our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+
+     if is_video:
+         input_image_tensor = [image_tensor]
+     elif num_new_images > 0:
+         input_image_tensor = image_tensor
+     else:
+         input_image_tensor = None
+
+     generate_kwargs = dict(
+         inputs=input_ids,
+         streamer=streamer,
+         images=input_image_tensor,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         max_new_tokens=max_output_tokens,
+         use_cache=False,
+         stopping_criteria=[stopping_criteria],
+         modalities=["video"] if is_video else ["image"]
+     )
+
+     # generate() blocks, so it runs on a background thread while the
+     # streamer feeds partial output back to the UI below.
+     t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for stream_token in streamer:
+         outputs.append(stream_token)
+
+         history[-1] = [text, "".join(outputs)]
+         yield history
+     our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "type": "chat",
+             "model": "MAmmoTH-VL2",
+             "state": history,
+             "images": all_image_hash,
+             "images_path": all_image_path
+         }
+         print("#### conv log", data)
+         fout.write(json.dumps(data) + "\n")
+     for upload_img in all_image_path:
+         api.upload_file(
+             path_or_fileobj=upload_img,
+             path_in_repo=upload_img.replace("./logs/", ""),
+             repo_id=repo_name,
+             repo_type="dataset",
+             # revision=revision,
+             # ignore_patterns=["data*"]
+         )
+     # upload json
+     api.upload_file(
+         path_or_fileobj=get_conv_log_filename(),
+         path_in_repo=get_conv_log_filename().replace("./logs/", ""),
+         repo_id=repo_name,
+         repo_type="dataset")
+


  with gr.Blocks(
@@ -112,7 +501,127 @@ with gr.Blocks(
      regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
      clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

-     bot()
+     chat_input = gr.MultimodalTextbox(
+         interactive=True,
+         file_types=["image", "video"],
+         placeholder="Enter message or upload file...",
+         show_label=False,
+         submit_btn="🚀"
+     )
+
+     gr.Examples(
+         examples_per_page=20,
+         examples=[
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/172197131626056_P7966202.png",
+                     ],
+                     "text": "Why is this image funny?",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_doc.png",
+                     ],
+                     "text": "Read the text in the image",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_weather.jpg",
+                     ],
+                     "text": "List the weather for Monday to Friday",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_knowledge.jpg",
+                     ],
+                     "text": "Answer the following question based on the provided image: What country do these planes belong to?",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_math.jpg",
+                     ],
+                     "text": "Find the measure of angle 3. Please provide a step by step solution.",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_interact.jpg",
+                     ],
+                     "text": "Please perfectly describe this cartoon illustration in as much detail as possible",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_perfer.jpg",
+                     ],
+                     "text": "This is an image of a room. It could either be a real image captured in the room or a rendered image from a 3D scene reconstruction technique that is trained using real images of the room. A rendered image usually contains some visible artifacts (e.g., blurred regions due to under-reconstructed areas) that do not faithfully represent the actual scene. You need to decide if it's a real image or a rendered image by giving each image a photorealism score between 1 and 5.",
+                 }
+             ],
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_multi1.png",
+                         f"{cur_dir}/examples/realcase_multi2.png",
+                         f"{cur_dir}/examples/realcase_multi3.png",
+                         f"{cur_dir}/examples/realcase_multi4.png",
+                         f"{cur_dir}/examples/realcase_multi5.png",
+                     ],
+                     "text": "Based on the five species in the images, draw a food chain. Explain the role of each species in the food chain.",
+                 }
+             ],
+         ],
+         inputs=[chat_input],
+         label="Real World Image Cases",
+     )
+     gr.Examples(
+         examples=[
+             [
+                 {
+                     "files": [
+                         f"{cur_dir}/examples/realcase_video.mp4",
+                     ],
+                     "text": "Please describe the video in detail.",
+                 },
+             ]
+         ],
+         inputs=[chat_input],
+         label="Real World Video Case"
+     )
+
+     gr.Markdown(tos_markdown)
+     gr.Markdown(learn_more_markdown)
+     gr.Markdown(bibtext)
+
+     chat_input.submit(
+         add_message, [chatbot, chat_input], [chatbot, chat_input]
+     ).then(
+         bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response"
+     ).then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+
+     # chatbot.like(print_like_dislike, None, None)
+     clear_btn.click(
+         fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
+     )
+
+     upvote_btn.click(
+         fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
+     )
+
+     downvote_btn.click(
+         fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="downvote_last_response"
+     )
+

  demo.queue()

@@ -136,8 +645,8 @@ if __name__ == "__main__":

      model_path = args.model_path
      filt_invalid = "cut"
-     #model_name = get_model_name_from_path(args.model_path)
-     #tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
-     #model=model.to(torch.device('cuda'))
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+     model = model.to(torch.device('cuda'))
      chat_image_num = 0
      demo.launch()
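
Note on the streaming change: the rewritten bot() uses the standard transformers thread-plus-streamer pattern, where model.generate() runs on a worker thread while the Gradio generator drains a TextIteratorStreamer and yields partial history to the UI. Below is a minimal, self-contained sketch of that pattern; the "gpt2" checkpoint and the prompt are illustrative assumptions for the sketch, not taken from this Space.

# Minimal sketch of the thread-plus-streamer pattern used by bot() above.
# Assumption: any causal LM checkpoint works here; "gpt2" is only a stand-in.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs on a worker thread while the
# caller consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=30))
thread.start()

partial = []
for chunk in streamer:
    partial.append(chunk)
    print("".join(partial))  # a Gradio generator would `yield` the joined text instead
thread.join()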