wenhu committed
Commit be2d2b2 · verified · 1 Parent(s): df7f4b3

Create app.py

Files changed (1)
  1. app.py +723 -0
app.py ADDED
@@ -0,0 +1,723 @@
+ # from .demo_modelpart import InferenceDemo
+ import gradio as gr
+ import os
+ from threading import Thread
+
+ # import time
+ import cv2
+
+ import datetime
+ # import copy
+ import torch
+
+ import spaces
+ import numpy as np
+
+ from llava import conversation as conversation_lib
+ from llava.constants import DEFAULT_IMAGE_TOKEN
+
+
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+ )
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import (
+     tokenizer_image_token,
+     get_model_name_from_path,
+     KeywordsStoppingCriteria,
+ )
+
+ from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
+
+ from decord import VideoReader, cpu
+
+ import requests
+ from PIL import Image
+ import io
+ from io import BytesIO
+ from transformers import TextStreamer, TextIteratorStreamer
+
+ import hashlib
+ import PIL
+ import base64
+ import json
+
+ import datetime
+ import gradio as gr
+ import gradio_client
+ import subprocess
+ import sys
+
+ from huggingface_hub import HfApi
+ from huggingface_hub import login
+ from huggingface_hub import revision_exists
+
+ login(token=os.environ["HF_TOKEN"],
+       write_permission=True)
+
+ api = HfApi()
+ repo_name = os.environ["LOG_REPO"]
+
+ external_log_dir = "./logs"
+ LOGDIR = external_log_dir
+ VOTEDIR = "./votes"
+
+
+ def install_gradio_4_35_0():
+     current_version = gr.__version__
+     if current_version != "4.35.0":
+         print(f"Current Gradio version: {current_version}")
+         print("Installing Gradio 4.35.0...")
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"])
+         print("Gradio 4.35.0 installed successfully.")
+     else:
+         print("Gradio 4.35.0 is already installed.")
+
+ # Call the function to install Gradio 4.35.0 if needed
+ install_gradio_4_35_0()
+
+ import gradio as gr
+ import gradio_client
+ print(f"Gradio version: {gr.__version__}")
+ print(f"Gradio-client version: {gradio_client.__version__}")
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+     return name
+
+ def get_conv_vote_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
+     if not os.path.isfile(name):
+         os.makedirs(os.path.dirname(name), exist_ok=True)
+     return name
+
+ def vote_last_response(state, vote_type, model_selector):
+     with open(get_conv_vote_filename(), "a") as fout:
+         data = {
+             "type": vote_type,
+             "model": model_selector,
+             "state": state,
+         }
+         fout.write(json.dumps(data) + "\n")
+     api.upload_file(
+         path_or_fileobj=get_conv_vote_filename(),
+         path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
+         repo_id=repo_name,
+         repo_type="dataset")
+
+
+ def upvote_last_response(state):
+     vote_last_response(state, "upvote", "MAmmoTH-VL2")
+     gr.Info("Thank you for voting!")
+     return state
+
+ def downvote_last_response(state):
+     vote_last_response(state, "downvote", "MAmmoTH-VL2")
+     gr.Info("Thank you for voting!")
+     return state
+
+ class InferenceDemo(object):
+     def __init__(
+         self, args, model_path, tokenizer, model, image_processor, context_len
+     ) -> None:
+         disable_torch_init()
+
+         self.tokenizer, self.model, self.image_processor, self.context_len = (
+             tokenizer,
+             model,
+             image_processor,
+             context_len,
+         )
+
+         if "llama-2" in model_name.lower():
+             conv_mode = "llava_llama_2"
+         elif "v1" in model_name.lower():
+             conv_mode = "llava_v1"
+         elif "mpt" in model_name.lower():
+             conv_mode = "mpt"
+         elif "qwen" in model_name.lower():
+             conv_mode = "qwen_1_5"
+         elif "pangea" in model_name.lower():
+             conv_mode = "qwen_1_5"
+         elif "mammoth-vl" in model_name.lower():
+             conv_mode = "qwen_2_5"
+         else:
+             conv_mode = "llava_v0"
+
+         if args.conv_mode is not None and conv_mode != args.conv_mode:
+             print(
+                 "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+                     conv_mode, args.conv_mode, args.conv_mode
+                 )
+             )
+         else:
+             args.conv_mode = conv_mode
+         self.conv_mode = conv_mode
+         self.conversation = conv_templates[args.conv_mode].copy()
+         self.num_frames = args.num_frames
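+
+         # Note on the template choice above (descriptive only, nothing new is
+         # executed): the selected name picks the chat format (system prompt,
+         # role tags, separators) from conv_templates. With the default
+         # "TIGER-Lab/MAmmoTH-VL2" checkpoint the "mammoth-vl" branch applies,
+         # so prompts are built with the "qwen_2_5" template; the llava_*
+         # branches remain for other LLaVA-family checkpoints.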
+
+ class ChatSessionManager:
+     def __init__(self):
+         self.chatbot_instance = None
+
+     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
+         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
+
+     def reset_chatbot(self):
+         self.chatbot_instance = None
+
+     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         if self.chatbot_instance is None:
+             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+         return self.chatbot_instance
+
+
+ def is_valid_video_filename(name):
+     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
+
+     ext = name.split(".")[-1].lower()
+
+     if ext in video_extensions:
+         return True
+     else:
+         return False
+
+ def is_valid_image_filename(name):
+     image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
+
+     ext = name.split(".")[-1].lower()
+
+     if ext in image_extensions:
+         return True
+     else:
+         return False
+
+
+ def sample_frames_v1(video_file, num_frames):
+     video = cv2.VideoCapture(video_file)
+     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+     interval = max(total_frames // num_frames, 1)
+     frames = []
+     for i in range(total_frames):
+         ret, frame = video.read()
+         if not ret:
+             continue
+         if i % interval == 0:
+             frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+     video.release()
+     return frames
+
+ def sample_frames_v2(video_path, frame_count=32):
+     video_frames = []
+     vr = VideoReader(video_path, ctx=cpu(0))
+     total_frames = len(vr)
+     frame_interval = max(total_frames // frame_count, 1)
+
+     for i in range(0, total_frames, frame_interval):
+         frame = vr[i].asnumpy()
+         frame_image = Image.fromarray(frame)  # Convert to PIL.Image
+         video_frames.append(frame_image)
+         if len(video_frames) >= frame_count:
+             break
+
+     # Ensure at least one frame is returned if total frames are less than required
+     if len(video_frames) < frame_count and total_frames > 0:
+         for i in range(total_frames):
+             frame = vr[i].asnumpy()
+             frame_image = Image.fromarray(frame)  # Convert to PIL.Image
+             video_frames.append(frame_image)
+             if len(video_frames) >= frame_count:
+                 break
+
+     return video_frames
+
+ def sample_frames(video_path, num_frames=8):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+     for i in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         ret, frame = cap.read()
+         if ret:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(Image.fromarray(frame))
+
+     cap.release()
+     return frames
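+
+ # Illustrative usage of the sampler above (a sketch, not called elsewhere in
+ # this file; the path stands in for any local clip such as the bundled
+ # examples/realcase_video.mp4): sample_frames() picks num_frames indices
+ # uniformly with np.linspace, seeks to each with CAP_PROP_POS_FRAMES, and
+ # returns RGB PIL images, e.g.
+ #     frames = sample_frames("examples/realcase_video.mp4", num_frames=32)
+ #     assert all(isinstance(f, Image.Image) for f in frames)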
+
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         if response.status_code == 200:
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         else:
+             raise ValueError(f"Failed to load the image from {image_file}")
+     else:
+         print("Load image from local file")
+         print(image_file)
+         image = Image.open(image_file).convert("RGB")
+
+     return image
+
+
+ def clear_response(history):
+     for index_conv in range(1, len(history)):
+         # loop until we get a text response from our model.
+         conv = history[-index_conv]
+         if not (conv[0] is None):
+             break
+     question = history[-index_conv][0]
+     history = history[:-index_conv]
+     return history, question
+
+ chat_manager = ChatSessionManager()
+
+
+ def clear_history(history):
+     chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
+     return None
+
+
+
+ def add_message(history, message):
+     global chat_image_num
+     print("#### len(history)", len(history))
+     if not history:
+         history = []
+         our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+         chat_image_num = 0
+
+     # if len(message["files"]) <= 1:
+     #     for x in message["files"]:
+     #         history.append(((x,), None))
+     #         chat_image_num += 1
+     #     if chat_image_num > 1:
+     #         history = []
+     #         chat_manager.reset_chatbot()
+     #         our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     #         chat_image_num = 0
+     #         for x in message["files"]:
+     #             history.append(((x,), None))
+     #             chat_image_num += 1
+
+     #     if message["text"] is not None:
+     #         history.append((message["text"], None))
+
+     #     print(f"### Chatbot instance ID: {id(our_chatbot)}")
+     #     return history, gr.MultimodalTextbox(value=None, interactive=False)
+     # else:
+     for x in message["files"]:
+         if "realcase_video.jpg" in x:
+             x = x.replace("realcase_video.jpg", "realcase_video.mp4")
+         history.append(((x,), None))
+     if message["text"] is not None:
+         history.append((message["text"], None))
+     # print(f"### Chatbot instance ID: {id(our_chatbot)}")
+     return history, gr.MultimodalTextbox(value=None, interactive=False)
+
+
+ @spaces.GPU
+ def bot(history, temperature, top_p, max_output_tokens):
+     our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     print(f"### Chatbot instance ID: {id(our_chatbot)}")
+     text = history[-1][0]
+     images_this_term = []
+     text_this_term = ""
+
+     is_video = False
+     num_new_images = 0
+     # previous_image = False
+     for i, message in enumerate(history[:-1]):
+         if type(message[0]) is tuple:
+             # if previous_image:
+             #     gr.Warning("Only one image can be uploaded in a conversation. Please reduce the number of images and start a new conversation.")
+             #     our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
+             #     return None
+
+             images_this_term.append(message[0][0])
+             if is_valid_video_filename(message[0][0]):
+                 # raise ValueError("Video is not supported")
+                 # num_new_images += our_chatbot.num_frames
+                 # num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
+                 num_new_images += 1
+                 is_video = True
+             elif is_valid_image_filename(message[0][0]):
+                 print("#### Load image from local file", message[0][0])
+                 num_new_images += 1
+             else:
+                 raise ValueError("Invalid file format")
+             # previous_image = True
+         else:
+             num_new_images = 0
+             # previous_image = False
+
+
+     image_list = []
+     for f in images_this_term:
+         if is_valid_video_filename(f):
+             image_list += sample_frames(f, our_chatbot.num_frames)
+         elif is_valid_image_filename(f):
+             image_list.append(load_image(f))
+         else:
+             raise ValueError("Invalid image file")
+
+     all_image_hash = []
+     all_image_path = []
+     for file_path in images_this_term:
+         with open(file_path, "rb") as file:
+             file_data = file.read()
+             file_hash = hashlib.md5(file_data).hexdigest()
+             all_image_hash.append(file_hash)
+
+         t = datetime.datetime.now()
+         output_dir = os.path.join(
+             LOGDIR,
+             "serve_files",
+             f"{t.year}-{t.month:02d}-{t.day:02d}"
+         )
+         os.makedirs(output_dir, exist_ok=True)
+
+         if is_valid_image_filename(file_path):
+             # Process and save images
+             image = Image.open(file_path).convert("RGB")
+             filename = os.path.join(output_dir, f"{file_hash}.jpg")
+             all_image_path.append(filename)
+             if not os.path.isfile(filename):
+                 print("Image saved to", filename)
+                 image.save(filename)
+
+         elif is_valid_video_filename(file_path):
+             # Simplified video saving
+             filename = os.path.join(output_dir, f"{file_hash}.mp4")
+             all_image_path.append(filename)
+             if not os.path.isfile(filename):
+                 print("Video saved to", filename)
+                 os.makedirs(os.path.dirname(filename), exist_ok=True)
+                 # Directly copy the video file
+                 with open(file_path, "rb") as src, open(filename, "wb") as dst:
+                     dst.write(src.read())
+
+     image_tensor = []
+     if is_video:
+         image_tensor = our_chatbot.image_processor.preprocess(image_list, return_tensors="pt")["pixel_values"].half().to(our_chatbot.model.device)
+     elif num_new_images > 0:
+         image_tensor = [
+             our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
+                 0
+             ]
+             .half()
+             .to(our_chatbot.model.device)
+             for f in image_list
+         ]
+         image_tensor = torch.stack(image_tensor)
+
+     image_token = DEFAULT_IMAGE_TOKEN * num_new_images + "\n"
+
+     inp = text
+     inp = image_token + inp
+     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
+     # image = None
+     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
+     prompt = our_chatbot.conversation.get_prompt()
+
+     input_ids = tokenizer_image_token(
+         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+     ).unsqueeze(0).to(our_chatbot.model.device)
+     # print("### input_id", input_ids)
+     stop_str = (
+         our_chatbot.conversation.sep
+         if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
+         else our_chatbot.conversation.sep2
+     )
+     keywords = [stop_str]
+     stopping_criteria = KeywordsStoppingCriteria(
+         keywords, our_chatbot.tokenizer, input_ids
+     )
+
+     streamer = TextIteratorStreamer(
+         our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+     print(our_chatbot.model.device)
+     print(input_ids.device)
+     # print(image_tensor.device)
+
+
+     if is_video:
+         input_image_tensor = [image_tensor]
+     elif num_new_images > 0:
+         input_image_tensor = image_tensor
+     else:
+         input_image_tensor = None
+
+     generate_kwargs = dict(
+         inputs=input_ids,
+         streamer=streamer,
+         images=input_image_tensor,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         max_new_tokens=max_output_tokens,
+         use_cache=False,
+         stopping_criteria=[stopping_criteria],
+         modalities=["video"] if is_video else ["image"]
+     )
+
+     t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for stream_token in streamer:
+         outputs.append(stream_token)
+
+         history[-1] = [text, "".join(outputs)]
+         yield history
+     our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
+     # print("### turn end history", history)
+     # print("### turn end conv", our_chatbot.conversation)
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "type": "chat",
+             "model": "MAmmoTH-VL2",
+             "state": history,
+             "images": all_image_hash,
+             "images_path": all_image_path
+         }
+         print("#### conv log", data)
+         fout.write(json.dumps(data) + "\n")
+         for upload_img in all_image_path:
+             api.upload_file(
+                 path_or_fileobj=upload_img,
+                 path_in_repo=upload_img.replace("./logs/", ""),
+                 repo_id=repo_name,
+                 repo_type="dataset",
+                 # revision=revision,
+                 # ignore_patterns=["data*"]
+             )
+         # upload json
+         api.upload_file(
+             path_or_fileobj=get_conv_log_filename(),
+             path_in_repo=get_conv_log_filename().replace("./logs/", ""),
+             repo_id=repo_name,
+             repo_type="dataset")
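+
+ # How the streaming in bot() works (descriptive note, nothing new is
+ # executed): model.generate runs on a background Thread and pushes decoded
+ # text into the TextIteratorStreamer; the loop over the streamer yields a
+ # partially updated `history` on every chunk, which Gradio renders
+ # incrementally because bot() is wired below as a generator callback on
+ # chat_input.submit.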
+
+
+
+ txt = gr.Textbox(
+     scale=4,
+     show_label=False,
+     placeholder="Enter text and press enter.",
+     container=False,
+ )
+
+ with gr.Blocks(
+     css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
+ ) as demo:
+
+     cur_dir = os.path.dirname(os.path.abspath(__file__))
+     # gr.Markdown(title_markdown)
+     gr.HTML(html_header)
+
+     with gr.Column():
+         with gr.Accordion("Parameters", open=False) as parameter_row:
+             temperature = gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=0.7,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+             top_p = gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=1,
+                 step=0.1,
+                 interactive=True,
+                 label="Top P",
+             )
+             max_output_tokens = gr.Slider(
+                 minimum=0,
+                 maximum=8192,
+                 value=4096,
+                 step=256,
+                 interactive=True,
+                 label="Max output tokens",
+             )
+         with gr.Row():
+             chatbot = gr.Chatbot([], elem_id="MAmmoTH-VL-8B", bubble_full_width=False, height=750)
+
+         with gr.Row():
+             upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
+             downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
+             flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
+             # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
+             regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
+             clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
+
+
+         chat_input = gr.MultimodalTextbox(
+             interactive=True,
+             file_types=["image", "video"],
+             placeholder="Enter message or upload file...",
+             show_label=False,
+             submit_btn="🚀"
+         )
+
+         print(cur_dir)
+         gr.Examples(
+             examples_per_page=20,
+             examples=[
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/172197131626056_P7966202.png",
+                         ],
+                         "text": "Why is this image funny?",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_doc.png",
+                         ],
+                         "text": "Read the text in the image",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_weather.jpg",
+                         ],
+                         "text": "List the weather for Monday to Friday",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_knowledge.jpg",
+                         ],
+                         "text": "Answer the following question based on the provided image: What country do these planes belong to?",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_math.jpg",
+                         ],
+                         "text": "Find the measure of angle 3. Please provide a step by step solution.",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_interact.jpg",
+                         ],
+                         "text": "Please perfectly describe this cartoon illustration in as much detail as possible",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_perfer.jpg",
+                         ],
+                         "text": "This is an image of a room. It could either be a real image captured in the room or a rendered image from a 3D scene reconstruction technique that is trained using real images of the room. A rendered image usually contains some visible artifacts (e.g. blurred regions due to under-reconstructed areas) that do not faithfully represent the actual scene. You need to decide if it is a real image or a rendered image by giving each image a photorealism score between 1 and 5.",
+                     }
+                 ],
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_multi1.png",
+                             f"{cur_dir}/examples/realcase_multi2.png",
+                             f"{cur_dir}/examples/realcase_multi3.png",
+                             f"{cur_dir}/examples/realcase_multi4.png",
+                             f"{cur_dir}/examples/realcase_multi5.png",
+                         ],
+                         "text": "Based on the five species in the images, draw a food chain. Explain the role of each species in the food chain.",
+                     }
+                 ],
+             ],
+             inputs=[chat_input],
+             label="Real World Image Cases",
+         )
+         gr.Examples(
+             examples=[
+                 [
+                     {
+                         "files": [
+                             f"{cur_dir}/examples/realcase_video.mp4",
+                         ],
+                         "text": "Please describe the video in detail.",
+                     },
+                 ]
+             ],
+             inputs=[chat_input],
+             label="Real World Video Case"
+         )
+
+     gr.Markdown(tos_markdown)
+     gr.Markdown(learn_more_markdown)
+     gr.Markdown(bibtext)
+
+     chat_input.submit(
+         add_message, [chatbot, chat_input], [chatbot, chat_input]
+     ).then(bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response").then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+
+
+     # chatbot.like(print_like_dislike, None, None)
+     clear_btn.click(
+         fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
+     )
+
+     upvote_btn.click(
+         fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
+     )
+
+
+     downvote_btn.click(
+         fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="downvote_last_response"
+     )
+
+
+ demo.queue()
+
+ if __name__ == "__main__":
+     import argparse
+
+     argparser = argparse.ArgumentParser()
+     argparser.add_argument("--server_name", default="0.0.0.0", type=str)
+     argparser.add_argument("--port", default="6123", type=str)
+     argparser.add_argument(
+         "--model_path", default="TIGER-Lab/MAmmoTH-VL2", type=str
+     )
+     # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+     argparser.add_argument("--model-base", type=str, default=None)
+     argparser.add_argument("--num-gpus", type=int, default=1)
+     argparser.add_argument("--conv-mode", type=str, default=None)
+     argparser.add_argument("--temperature", type=float, default=0.7)
+     argparser.add_argument("--max-new-tokens", type=int, default=4096)
+     argparser.add_argument("--num_frames", type=int, default=32)
+     argparser.add_argument("--load-8bit", action="store_true")
+     argparser.add_argument("--load-4bit", action="store_true")
+     argparser.add_argument("--debug", action="store_true")
+
+     args = argparser.parse_args()
+
+     model_path = args.model_path
+     filt_invalid = "cut"
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+     model = model.to(torch.device("cuda"))
+     chat_image_num = 0
+     demo.launch()
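+
+     # Illustrative local launch (a sketch, assuming the llava package and the
+     # example assets are installed; the hosted Space simply runs this script):
+     #     python app.py --model_path TIGER-Lab/MAmmoTH-VL2 --num_frames 32 --load-4bit
+     # Note that demo.launch() above is called without arguments, so the
+     # --server_name and --port flags parsed earlier are not forwarded to Gradio.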