wenhu committed on
Commit 2c52937 · verified · 1 Parent(s): d8b1b85

Delete app.py

Files changed (1)
  app.py +0 -697
app.py DELETED
@@ -1,697 +0,0 @@
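- # Gradio demo for MAmmoTH-VL2: LLaVA-style multimodal chat over images and video,
- # with conversation and vote logging mirrored to a Hugging Face dataset repo.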
- # from .demo_modelpart import InferenceDemo
- import gradio as gr
- import os
- from threading import Thread
-
- import cv2
- import datetime
- import torch
-
- import spaces
- import numpy as np
-
- from llava.constants import (
-     IMAGE_TOKEN_INDEX,
-     DEFAULT_IMAGE_TOKEN,
- )
- from llava.conversation import conv_templates, SeparatorStyle
- from llava.model.builder import load_pretrained_model
- from llava.utils import disable_torch_init
- from llava.mm_utils import (
-     tokenizer_image_token,
-     get_model_name_from_path,
-     KeywordsStoppingCriteria,
- )
-
- from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
-
- from decord import VideoReader, cpu
-
- import requests
- from PIL import Image
- from io import BytesIO
- from transformers import TextIteratorStreamer
-
- import hashlib
- import json
-
- import gradio_client
- import subprocess
- import sys
-
- from huggingface_hub import HfApi
- from huggingface_hub import login
-
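- # Authenticate to the Hugging Face Hub; conversation logs and votes are
- # uploaded to the dataset repo named by the LOG_REPO environment variable.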
- login(token=os.environ["HF_TOKEN"], write_permission=True)
-
- api = HfApi()
- repo_name = os.environ["LOG_REPO"]
-
- external_log_dir = "./logs"
- LOGDIR = external_log_dir
- VOTEDIR = "./votes"
-
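- # Pin Gradio to 4.35.0, force-reinstalling at startup if another version is found.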
- def install_gradio_4_35_0():
-     current_version = gr.__version__
-     if current_version != "4.35.0":
-         print(f"Current Gradio version: {current_version}")
-         print("Installing Gradio 4.35.0...")
-         subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"])
-         print("Gradio 4.35.0 installed successfully.")
-     else:
-         print("Gradio 4.35.0 is already installed.")
-
- # NOTE: re-importing gradio after the reinstall would not reload the
- # already-imported package; the new version only takes effect on restart.
- install_gradio_4_35_0()
-
- print(f"Gradio version: {gr.__version__}")
- print(f"Gradio-client version: {gradio_client.__version__}")
-
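- # Daily log and vote files, named YYYY-MM-DD-user_conv.json / YYYY-MM-DD-user_vote.json.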
- def get_conv_log_filename():
-     t = datetime.datetime.now()
-     return os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
-
- def get_conv_vote_filename():
-     t = datetime.datetime.now()
-     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
-     if not os.path.isfile(name):
-         os.makedirs(os.path.dirname(name), exist_ok=True)
-     return name
-
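- # Voting: append the vote to the local daily file, then mirror it to the Hub dataset repo.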
- def vote_last_response(state, vote_type, model_selector):
-     with open(get_conv_vote_filename(), "a") as fout:
-         data = {
-             "type": vote_type,
-             "model": model_selector,
-             "state": state,
-         }
-         fout.write(json.dumps(data) + "\n")
-     api.upload_file(
-         path_or_fileobj=get_conv_vote_filename(),
-         path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
-         repo_id=repo_name,
-         repo_type="dataset")
-
-
- def upvote_last_response(state):
-     vote_last_response(state, "upvote", "MAmmoTH-VL2")
-     gr.Info("Thank you for voting!")
-     return state
-
- def downvote_last_response(state):
-     vote_last_response(state, "downvote", "MAmmoTH-VL2")
-     gr.Info("Thank you for voting!")
-     return state
-
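- # Holds the loaded model components plus the conversation template inferred from the model name.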
- class InferenceDemo(object):
-     def __init__(
-         self, args, model_path, tokenizer, model, image_processor, context_len
-     ) -> None:
-         disable_torch_init()
-
-         self.tokenizer, self.model, self.image_processor, self.context_len = (
-             tokenizer,
-             model,
-             image_processor,
-             context_len,
-         )
-
-         # Pick the conversation template from the model name, derived from
-         # model_path rather than relying on the module-level global.
-         model_name = get_model_name_from_path(model_path)
-         if "llama-2" in model_name.lower():
-             conv_mode = "llava_llama_2"
-         elif "v1" in model_name.lower():
-             conv_mode = "llava_v1"
-         elif "mpt" in model_name.lower():
-             conv_mode = "mpt"
-         elif "qwen" in model_name.lower():
-             conv_mode = "qwen_1_5"
-         elif "pangea" in model_name.lower():
-             conv_mode = "qwen_1_5"
-         elif "mammoth-vl" in model_name.lower():
-             conv_mode = "qwen_2_5"
-         else:
-             conv_mode = "llava_v0"
-
-         if args.conv_mode is not None and conv_mode != args.conv_mode:
-             print(
-                 "[WARNING] the auto-inferred conversation mode is {}, but `--conv-mode` is {}; using {}".format(
-                     conv_mode, args.conv_mode, args.conv_mode
-                 )
-             )
-         else:
-             args.conv_mode = conv_mode
-         self.conv_mode = args.conv_mode
-         self.conversation = conv_templates[args.conv_mode].copy()
-         self.num_frames = args.num_frames
-
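- # Lazily creates and caches a single InferenceDemo shared across requests.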
- class ChatSessionManager:
-     def __init__(self):
-         self.chatbot_instance = None
-
-     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
-         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
-         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
-
-     def reset_chatbot(self):
-         self.chatbot_instance = None
-
-     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
-         if self.chatbot_instance is None:
-             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
-         return self.chatbot_instance
-
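- # Extension-based routing of uploads: videos are frame-sampled, images loaded whole.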
- def is_valid_video_filename(name):
-     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
-     ext = name.split(".")[-1].lower()
-     return ext in video_extensions
-
- def is_valid_image_filename(name):
-     image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
-     ext = name.split(".")[-1].lower()
-     return ext in image_extensions
-
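- # Frame samplers. Only sample_frames() (uniform indices via np.linspace) is
- # called by bot(); the v1/v2 variants are kept but unused.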
- def sample_frames_v1(video_file, num_frames):
-     video = cv2.VideoCapture(video_file)
-     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-     interval = max(total_frames // num_frames, 1)
-     frames = []
-     for i in range(total_frames):
-         ret, frame = video.read()
-         if not ret:
-             continue
-         if i % interval == 0:
-             frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
-     video.release()
-     return frames
-
- def sample_frames_v2(video_path, frame_count=32):
-     video_frames = []
-     vr = VideoReader(video_path, ctx=cpu(0))
-     total_frames = len(vr)
-     frame_interval = max(total_frames // frame_count, 1)
-
-     for i in range(0, total_frames, frame_interval):
-         video_frames.append(Image.fromarray(vr[i].asnumpy()))  # Convert to PIL.Image
-         if len(video_frames) >= frame_count:
-             break
-
-     # Pad with frames from the start if the video is shorter than required
-     if len(video_frames) < frame_count and total_frames > 0:
-         for i in range(total_frames):
-             video_frames.append(Image.fromarray(vr[i].asnumpy()))  # Convert to PIL.Image
-             if len(video_frames) >= frame_count:
-                 break
-
-     return video_frames
-
- def sample_frames(video_path, num_frames=8):
-     cap = cv2.VideoCapture(video_path)
-     frames = []
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-     indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
-
-     for i in indices:
-         cap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         ret, frame = cap.read()
-         if ret:
-             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             frames.append(Image.fromarray(frame))
-
-     cap.release()
-     return frames
-
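- # Image loading accepts either an http(s) URL or a local path.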
- def load_image(image_file):
-     if image_file.startswith(("http://", "https://")):
-         response = requests.get(image_file)
-         if response.status_code != 200:
-             raise ValueError(f"Failed to load image from {image_file}")
-         image = Image.open(BytesIO(response.content)).convert("RGB")
-     else:
-         print("Load image from local file:", image_file)
-         image = Image.open(image_file).convert("RGB")
-
-     return image
-
-
- def clear_response(history):
-     # Walk back through history to the most recent turn with a text message.
-     for index_conv in range(1, len(history)):
-         conv = history[-index_conv]
-         if conv[0] is not None:
-             break
-     question = history[-index_conv][0]
-     history = history[:-index_conv]
-     return history, question
-
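- # Global session state and the Gradio callbacks that build the chat history.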
- chat_manager = ChatSessionManager()
-
-
- def clear_history(history):
-     chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
-     chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
-     return None
-
-
- def add_message(history, message):
-     global chat_image_num
-     print("#### len(history)", len(history))
-     if not history:
-         history = []
-         our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
-         chat_image_num = 0
-     for x in message["files"]:
-         # The bundled video example uses a .jpg thumbnail; swap in the real .mp4.
-         if "realcase_video.jpg" in x:
-             x = x.replace("realcase_video.jpg", "realcase_video.mp4")
-         history.append(((x,), None))
-     if message["text"] is not None:
-         history.append((message["text"], None))
-     return history, gr.MultimodalTextbox(value=None, interactive=False)
-
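- # Main inference callback: assembles the multimodal prompt, streams tokens from
- # a background generate() thread, and logs the finished turn to the Hub.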
- @spaces.GPU
- def bot(history, temperature, top_p, max_output_tokens):
-     our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
-     print(f"### Chatbot instance ID: {id(our_chatbot)}")
-     text = history[-1][0]
-     images_this_term = []
-
-     is_video = False
-     num_new_images = 0
-     for i, message in enumerate(history[:-1]):
-         if type(message[0]) is tuple:
-             images_this_term.append(message[0][0])
-             if is_valid_video_filename(message[0][0]):
-                 # A video counts as one image token; its sampled frames are
-                 # batched into a single tensor below.
-                 num_new_images += 1
-                 is_video = True
-             elif is_valid_image_filename(message[0][0]):
-                 print("#### Load image from local file", message[0][0])
-                 num_new_images += 1
-             else:
-                 raise ValueError("Invalid file format")
-         else:
-             # A text turn resets the counter: only files uploaded after the
-             # last text message count as new inputs.
-             num_new_images = 0
-
-     image_list = []
-     for f in images_this_term:
-         if is_valid_video_filename(f):
-             image_list += sample_frames(f, our_chatbot.num_frames)
-         elif is_valid_image_filename(f):
-             image_list.append(load_image(f))
-         else:
-             raise ValueError("Invalid image file")
-
-     # Hash each upload and archive it under logs/serve_files/YYYY-MM-DD/.
-     all_image_hash = []
-     all_image_path = []
-     for file_path in images_this_term:
-         with open(file_path, "rb") as file:
-             file_hash = hashlib.md5(file.read()).hexdigest()
-         all_image_hash.append(file_hash)
-
-         t = datetime.datetime.now()
-         output_dir = os.path.join(
-             LOGDIR,
-             "serve_files",
-             f"{t.year}-{t.month:02d}-{t.day:02d}"
-         )
-         os.makedirs(output_dir, exist_ok=True)
-
-         if is_valid_image_filename(file_path):
-             # Process and save images
-             image = Image.open(file_path).convert("RGB")
-             filename = os.path.join(output_dir, f"{file_hash}.jpg")
-             all_image_path.append(filename)
-             if not os.path.isfile(filename):
-                 print("Image saved to", filename)
-                 image.save(filename)
-         elif is_valid_video_filename(file_path):
-             filename = os.path.join(output_dir, f"{file_hash}.mp4")
-             all_image_path.append(filename)
-             if not os.path.isfile(filename):
-                 print("Video saved to", filename)
-                 # Directly copy the video file
-                 with open(file_path, "rb") as src, open(filename, "wb") as dst:
-                     dst.write(src.read())
-
-     image_tensor = []
-     if is_video:
-         image_tensor = our_chatbot.image_processor.preprocess(image_list, return_tensors="pt")["pixel_values"].half().to(our_chatbot.model.device)
-     elif num_new_images > 0:
-         image_tensor = [
-             our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
-             .half()
-             .to(our_chatbot.model.device)
-             for f in image_list
-         ]
-         image_tensor = torch.stack(image_tensor)
-
-     image_token = DEFAULT_IMAGE_TOKEN * num_new_images + "\n"
-
-     inp = image_token + text
-     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
-     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
-     prompt = our_chatbot.conversation.get_prompt()
-
-     input_ids = tokenizer_image_token(
-         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
-     ).unsqueeze(0).to(our_chatbot.model.device)
-     stop_str = (
-         our_chatbot.conversation.sep
-         if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
-         else our_chatbot.conversation.sep2
-     )
-     keywords = [stop_str]
-     stopping_criteria = KeywordsStoppingCriteria(
-         keywords, our_chatbot.tokenizer, input_ids
-     )
-
-     streamer = TextIteratorStreamer(
-         our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
-
-     if is_video:
-         input_image_tensor = [image_tensor]
-     elif num_new_images > 0:
-         input_image_tensor = image_tensor
-     else:
-         input_image_tensor = None
-
-     generate_kwargs = dict(
-         inputs=input_ids,
-         streamer=streamer,
-         images=input_image_tensor,
-         do_sample=True,
-         temperature=temperature,
-         top_p=top_p,
-         max_new_tokens=max_output_tokens,
-         use_cache=False,
-         stopping_criteria=[stopping_criteria],
-         modalities=["video"] if is_video else ["image"],
-     )
-
-     # Generate on a background thread so tokens can be streamed to the UI.
-     t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for stream_token in streamer:
-         outputs.append(stream_token)
-         history[-1] = [text, "".join(outputs)]
-         yield history
-     our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
-
-     # Log the completed turn and mirror it, plus any uploads, to the Hub.
-     with open(get_conv_log_filename(), "a") as fout:
-         data = {
-             "type": "chat",
-             "model": "MAmmoTH-VL2",
-             "state": history,
-             "images": all_image_hash,
-             "images_path": all_image_path,
-         }
-         fout.write(json.dumps(data) + "\n")
-     for upload_img in all_image_path:
-         api.upload_file(
-             path_or_fileobj=upload_img,
-             path_in_repo=upload_img.replace("./logs/", ""),
-             repo_id=repo_name,
-             repo_type="dataset",
-         )
-     api.upload_file(
-         path_or_fileobj=get_conv_log_filename(),
-         path_in_repo=get_conv_log_filename().replace("./logs/", ""),
-         repo_id=repo_name,
-         repo_type="dataset",
-     )
-
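- # --- Gradio UI layout ---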
- # NOTE: `txt` is defined but not used by the Blocks layout below.
- txt = gr.Textbox(
-     scale=4,
-     show_label=False,
-     placeholder="Enter text and press enter.",
-     container=False,
- )
-
- with gr.Blocks(
-     css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
- ) as demo:
-     cur_dir = os.path.dirname(os.path.abspath(__file__))
-     gr.HTML(html_header)
-
-     with gr.Column():
-         with gr.Accordion("Parameters", open=False) as parameter_row:
-             temperature = gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 value=0.7,
-                 step=0.1,
-                 interactive=True,
-                 label="Temperature",
-             )
-             top_p = gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 value=1.0,
-                 step=0.1,
-                 interactive=True,
-                 label="Top P",
-             )
-             max_output_tokens = gr.Slider(
-                 minimum=0,
-                 maximum=8192,
-                 value=4096,
-                 step=256,
-                 interactive=True,
-                 label="Max output tokens",
-             )
-         with gr.Row():
-             chatbot = gr.Chatbot([], elem_id="MAmmoTH-VL-8B", bubble_full_width=False, height=750)
-
-         with gr.Row():
-             upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
-             downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
-             # NOTE: flag_btn and regenerate_btn are rendered but never wired
-             # to handlers below.
-             flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
-             regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
-             clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
-
-     chat_input = gr.MultimodalTextbox(
-         interactive=True,
-         file_types=["image", "video"],
-         placeholder="Enter message or upload file...",
-         show_label=False,
-         submit_btn="🚀",
-     )
-
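-     # Bundled example prompts: eight image cases and one video case.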
-     print(cur_dir)
-     gr.Examples(
-         examples_per_page=20,
-         examples=[
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/172197131626056_P7966202.png",
-                     ],
-                     "text": "Why is this image funny?",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_doc.png",
-                     ],
-                     "text": "Read the text in the image.",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_weather.jpg",
-                     ],
-                     "text": "List the weather for Monday to Friday.",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_knowledge.jpg",
-                     ],
-                     "text": "Answer the following question based on the provided image: What country do these planes belong to?",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_math.jpg",
-                     ],
-                     "text": "Find the measure of angle 3. Please provide a step-by-step solution.",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_interact.jpg",
-                     ],
-                     "text": "Please describe this cartoon illustration in as much detail as possible.",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_perfer.jpg",
-                     ],
-                     "text": "This is an image of a room. It could either be a real image captured in the room or a rendered image from a 3D scene reconstruction technique that is trained using real images of the room. A rendered image usually contains some visible artifacts (e.g. blurred regions due to under-reconstructed areas) that do not faithfully represent the actual scene. You need to decide if it is a real image or a rendered image by giving each image a photorealism score between 1 and 5.",
-                 }
-             ],
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_multi1.png",
-                         f"{cur_dir}/examples/realcase_multi2.png",
-                         f"{cur_dir}/examples/realcase_multi3.png",
-                         f"{cur_dir}/examples/realcase_multi4.png",
-                         f"{cur_dir}/examples/realcase_multi5.png",
-                     ],
-                     "text": "Based on the five species in the images, draw a food chain. Explain the role of each species in the food chain.",
-                 }
-             ],
-         ],
-         inputs=[chat_input],
-         label="Real World Image Cases",
-     )
-     gr.Examples(
-         examples=[
-             [
-                 {
-                     "files": [
-                         f"{cur_dir}/examples/realcase_video.mp4",
-                     ],
-                     "text": "Please describe the video in detail.",
-                 },
-             ]
-         ],
-         inputs=[chat_input],
-         label="Real World Video Case",
-     )
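-
-     # Footer markdown (ToS, learn-more, citation) and event wiring:
-     # submit -> add_message -> bot -> re-enable the input box.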
-     gr.Markdown(tos_markdown)
-     gr.Markdown(learn_more_markdown)
-     gr.Markdown(bibtext)
-
-     chat_input.submit(
-         add_message, [chatbot, chat_input], [chatbot, chat_input]
-     ).then(
-         bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response"
-     ).then(
-         lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
-     )
-
-     clear_btn.click(
-         fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
-     )
-
-     upvote_btn.click(
-         fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
-     )
-
-     downvote_btn.click(
-         fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="downvote_last_response"
-     )
-
- demo.queue()
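-
- # Entry point: parse CLI arguments, load the model, and launch the demo.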
- if __name__ == "__main__":
-     import argparse
-
-     argparser = argparse.ArgumentParser()
-     argparser.add_argument("--server_name", default="0.0.0.0", type=str)
-     argparser.add_argument("--model_path", default="TIGER-Lab/MAmmoTH-VL2", type=str)
-     argparser.add_argument("--model-base", type=str, default=None)
-     argparser.add_argument("--num-gpus", type=int, default=1)
-     argparser.add_argument("--conv-mode", type=str, default=None)
-     argparser.add_argument("--temperature", type=float, default=0.7)
-     argparser.add_argument("--max-new-tokens", type=int, default=4096)
-     argparser.add_argument("--num_frames", type=int, default=32)
-     argparser.add_argument("--load-8bit", action="store_true")
-     argparser.add_argument("--load-4bit", action="store_true")
-     argparser.add_argument("--debug", action="store_true")
-
-     args = argparser.parse_args()
-
-     model_path = args.model_path
-     filt_invalid = "cut"
-     model_name = get_model_name_from_path(args.model_path)
-     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
-     model = model.to(torch.device("cuda"))
-     chat_image_num = 0
-     demo.launch()
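-
- # A typical invocation (assuming a CUDA device and the HF_TOKEN / LOG_REPO
- # environment variables are set; the model path defaults to TIGER-Lab/MAmmoTH-VL2):
- #   python app.py --num_frames 32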