wenhu committed on
Commit 7b3e7e6 · verified · 1 Parent(s): 442a515

Update app_test.py

Files changed (1): app_test.py +90 -0
app_test.py CHANGED
@@ -0,0 +1,90 @@
+ # from .demo_modelpart import InferenceDemo
+ import gradio as gr
+ import os
+ from threading import Thread
+
+ # import time
+ import cv2
+
+ import datetime
+ # import copy
+ import torch
+
+ import spaces
+ import numpy as np
+
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+ )
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import (
+     tokenizer_image_token,
+     get_model_name_from_path,
+     KeywordsStoppingCriteria,
+ )
+
+ from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
+
+ from decord import VideoReader, cpu
+
+ import requests
+ from PIL import Image
+ import io
+ from io import BytesIO
+ from transformers import TextStreamer, TextIteratorStreamer
+
+ import hashlib
+ import PIL
+ import base64
+ import json
+
+ import gradio_client
+ import subprocess
+ import sys
+
+ from huggingface_hub import HfApi
+ from huggingface_hub import login
+ from huggingface_hub import revision_exists
+
+ # Authenticate against the Hugging Face Hub so logs can be pushed later.
+ login(token=os.environ["HF_TOKEN"],
+       write_permission=True)
+
+ api = HfApi()
+ repo_name = os.environ["LOG_REPO"]
+
+ # Local directories for conversation logs and user votes.
+ external_log_dir = "./logs"
+ LOGDIR = external_log_dir
+ VOTEDIR = "./votes"
+
+ if __name__ == "__main__":
+     import argparse
+
+     argparser = argparse.ArgumentParser()
+     argparser.add_argument("--server_name", default="0.0.0.0", type=str)
+     argparser.add_argument("--model_path", default="TIGER-Lab/MAmmoTH-VL2", type=str)
+     argparser.add_argument("--model-base", type=str, default=None)
+     argparser.add_argument("--num-gpus", type=int, default=1)
+     argparser.add_argument("--conv-mode", type=str, default=None)
+     argparser.add_argument("--temperature", type=float, default=0.7)
+     argparser.add_argument("--max-new-tokens", type=int, default=4096)
+     argparser.add_argument("--num_frames", type=int, default=32)
+     argparser.add_argument("--load-8bit", action="store_true")
+     argparser.add_argument("--load-4bit", action="store_true")
+     argparser.add_argument("--debug", action="store_true")
+
+     args = argparser.parse_args()
+
+     # Load the MAmmoTH-VL2 checkpoint and move it to the GPU.
+     model_path = args.model_path
+     filt_invalid = "cut"
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+     model = model.to(torch.device('cuda'))
+     chat_image_num = 0
+     # NOTE: `demo` is never defined in this file; a Gradio Blocks UI must be
+     # constructed before this call or it raises a NameError (see sketch below).
+     demo.launch()
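The committed script ends with `demo.launch()` but never constructs `demo`, so it cannot run as-is. Below is a minimal, self-contained sketch of the missing piece, assuming nothing beyond a stock Gradio install; the `respond` placeholder and the single-textbox layout are illustrative stand-ins, not the actual MAmmoTH-VL2 chat interface, which would wire the loaded `tokenizer`, `model`, and `image_processor` into the callback.

import gradio as gr

def respond(message: str) -> str:
    # Placeholder callback; the real demo would run the model's generation loop here.
    return f"echo: {message}"

# Build the Blocks object under the name `demo` that demo.launch() expects.
with gr.Blocks(title="MAmmoTH-VL2 demo (placeholder)") as demo:
    inp = gr.Textbox(label="Message")
    out = gr.Textbox(label="Response")
    inp.submit(respond, inputs=inp, outputs=out)

demo.launch(server_name="0.0.0.0")

Inside the script itself, calling `demo.launch(server_name=args.server_name)` would also make use of the `--server_name` flag that is parsed but otherwise unused.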