Vision-CAIR committed
Commit dd78ee5 · 1 parent: 0dea8d9

Delete demo.py

Files changed (1):
  demo.py  +0 -171
demo.py DELETED
@@ -1,171 +0,0 @@
- import argparse
- import os
- import random
-
- import numpy as np
- import torch
- import torch.backends.cudnn as cudnn
- import gradio as gr
-
- from transformers import StoppingCriteriaList
-
- from minigpt4.common.config import Config
- from minigpt4.common.dist_utils import get_rank
- from minigpt4.common.registry import registry
- from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
-
- # imports modules for registration
- from minigpt4.datasets.builders import *
- from minigpt4.models import *
- from minigpt4.processors import *
- from minigpt4.runners import *
- from minigpt4.tasks import *
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Demo")
-     parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
-     parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
-     parser.add_argument(
-         "--options",
-         nargs="+",
-         help="override some settings in the used config, the key-value pair "
-              "in xxx=yyy format will be merged into config file (deprecate), "
-              "change to --cfg-options instead.",
-     )
-     args = parser.parse_args()
-     return args
-
-
- def setup_seeds(config):
-     seed = config.run_cfg.seed + get_rank()
-
-     random.seed(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-
-     cudnn.benchmark = False
-     cudnn.deterministic = True
-
-
- # ========================================
- # Model Initialization
- # ========================================
-
- conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
-              'pretrain_llama2': CONV_VISION_LLama2}
-
- print('Initializing Chat')
- args = parse_args()
- cfg = Config(args)
-
- model_config = cfg.model_cfg
- model_config.device_8bit = args.gpu_id
- model_cls = registry.get_model_class(model_config.arch)
- model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
-
- CONV_VISION = conv_dict[model_config.model_type]
-
- vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
- vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-
- stop_words_ids = [[835], [2277, 29937]]
- stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
- stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
- chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
- print('Initialization Finished')
-
-
- # ========================================
- # Gradio Setting
- # ========================================
-
-
- def gradio_reset(chat_state, img_list):
-     if chat_state is not None:
-         chat_state.messages = []
-     if img_list is not None:
-         img_list = []
-     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
-
-
- def upload_img(gr_img, text_input, chat_state):
-     if gr_img is None:
-         return None, None, gr.update(interactive=True), chat_state, None
-     chat_state = CONV_VISION.copy()
-     img_list = []
-     llm_message = chat.upload_img(gr_img, chat_state, img_list)
-     chat.encode_img(img_list)
-     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
-
-
- def gradio_ask(user_message, chatbot, chat_state):
-     if len(user_message) == 0:
-         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
-     chat.ask(user_message, chat_state)
-     chatbot = chatbot + [[user_message, None]]
-     return '', chatbot, chat_state
-
-
- def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-     llm_message = chat.answer(conv=chat_state,
-                               img_list=img_list,
-                               num_beams=num_beams,
-                               temperature=temperature,
-                               max_new_tokens=300,
-                               max_length=2000)[0]
-     chatbot[-1][1] = llm_message
-     return chatbot, chat_state, img_list
-
-
- title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
- description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
- article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
- """
-
- #TODO show examples below
-
- with gr.Blocks() as demo:
-     gr.Markdown(title)
-     gr.Markdown(description)
-     gr.Markdown(article)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             image = gr.Image(type="pil")
-             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
-             clear = gr.Button("Restart")
-
-             num_beams = gr.Slider(
-                 minimum=1,
-                 maximum=10,
-                 value=1,
-                 step=1,
-                 interactive=True,
-                 label="beam search numbers)",
-             )
-
-             temperature = gr.Slider(
-                 minimum=0.1,
-                 maximum=2.0,
-                 value=1.0,
-                 step=0.1,
-                 interactive=True,
-                 label="Temperature",
-             )
-
-         with gr.Column(scale=2):
-             chat_state = gr.State()
-             img_list = gr.State()
-             chatbot = gr.Chatbot(label='MiniGPT-4')
-             text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
-
-     upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
-
-     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-         gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
-     )
-     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
-
- demo.launch(share=True, enable_queue=True)