BenkHel committed on
Commit
5dad961
·
verified ·
1 Parent(s): c3c0c8e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import argparse
import datetime
import hashlib
import io
import json
import os
import subprocess
import sys
import time
from threading import Thread

# Third-party
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, TextIteratorStreamer

# Project (CuMo / Spaces)
import spaces
import cumo.serve.gradio_web_server as gws
from cumo.constants import (
    LOGDIR,
    WORKER_HEART_BEAT_INTERVAL,
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from cumo.conversation import default_conversation, conv_templates, SeparatorStyle
from cumo.model.builder import load_pretrained_model
from cumo.model.language_model.llava_mistral import LlavaMistralForCausalLM
from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from cumo.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg,
    pretty_print_semaphore,
)
38
# Header attached to outbound HTTP requests made on behalf of this demo.
headers = {"User-Agent": "CuMo"}

# Reusable gradio button-state updates returned from event handlers.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

# ---------------------------------------------------------------------------
# Model configuration — the fine-tuned CuMo checkpoint on top of Mistral-7B.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = 'BenkHel/CumoThesis'
model_base = 'mistralai/Mistral-7B-Instruct-v0.2'
model_name = 'CuMo-mistral-7b'
conv_mode = 'mistral_instruct_system'
load_8bit = False
load_4bit = False

# Load tokenizer / model / image processor once at import time; the loaded
# objects are shared by every request handled by this process.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False
)
model.config.training = False

# FIXED PROMPT — every request asks this exact question about the uploaded image.
FIXED_PROMPT = "<image>\nWhat material is this item and how to dispose of it?"
59
+
60
def clear_history():
    """Reset the conversation and clear every bound widget.

    Returns a 6-tuple matching the outputs registered on ``clear_btn.click``
    (``[state, chatbot, textbox, imagebox] + [regenerate_btn, clear_btn]``):
    a fresh conversation state, its (empty) chatbot rendering, an empty
    textbox, no image, and both action buttons disabled.
    """
    state = default_conversation.copy()
    # Only 2 buttons are registered as outputs in this trimmed UI; returning
    # the upstream 5-button tuple would not match the output component count.
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
63
+
64
def add_text(state, imagebox, textbox, image_process_mode):
    """Append the user's turn (fixed prompt + image) to the conversation.

    Generator handler: yields one update tuple matching the 6 registered
    outputs ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    When an image is supplied the text is forced to ``FIXED_PROMPT``; the
    message stored in the conversation is ``(text, PIL image, process mode)``.
    """
    if state is None:
        state = conv_templates[conv_mode].copy()

    if imagebox is not None:
        # Prompt is fixed for this demo regardless of what the textbox holds.
        textbox = FIXED_PROMPT
        image = Image.open(imagebox).convert('RGB')
        textbox = (textbox, image, image_process_mode)

    state.append_message(state.roles[0], textbox)
    state.append_message(state.roles[1], None)
    # Enable regenerate/clear (the last two entries of the upstream 5-button
    # tuple); only those 2 buttons are wired as outputs in this UI.
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
75
+
76
def delete_text(state, image_process_mode):
    """Drop the last assistant reply so ``generate`` can produce a new one.

    Wired to ``regenerate_btn.click`` (followed by ``generate``). Clears the
    pending assistant slot and refreshes the stored image-process mode on the
    preceding human message. Yields one 6-tuple matching the registered
    outputs ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    """
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        # Keep (text, image) but swap in the currently selected process mode.
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    # 2 button updates, not 5 — only regenerate/clear are registered outputs.
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
82
+
83
def regenerate(state, image_process_mode):
    """Prepare the conversation for regenerating the last reply.

    NOTE(review): currently not wired to any event — ``regenerate_btn`` uses
    ``delete_text`` + ``generate`` instead. Kept for parity with the upstream
    CuMo server. Returns a 6-tuple matching this UI's registered outputs
    ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    """
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    # Match the 6 registered outputs (4 widgets + 2 buttons), not the
    # upstream 5-button layout.
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
90
+
91
@spaces.GPU
def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
    """Run the model on the fixed prompt + current image(s), streaming the reply.

    Generator handler wired after ``add_text``/``delete_text``; every yield is
    a 6-tuple matching the registered outputs
    ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.

    Args:
        state: conversation object holding messages and images.
        imagebox/textbox/image_process_mode: raw widget values (unused here;
            the prompt is fixed and images come from ``state``).
        temperature, top_p: sampling parameters from the UI sliders.
        max_output_tokens: generation length cap from the UI slider.
    """
    prompt = FIXED_PROMPT  # prompt is fixed for this demo
    images = state.get_images(return_pil=True)

    num_image_tokens = 0

    if images:  # non-None and non-empty
        # One <image> placeholder is required per attached image.
        if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
            raise ValueError("Number of images does not match number of <image> tokens in prompt")
        image_sizes = [image.size for image in images]
        images = process_images(images, image_processor, model.config)

        if type(images) is list:
            images = [image.to(model.device, dtype=torch.float16) for image in images]
        else:
            images = images.to(model.device, dtype=torch.float16)

        replace_token = DEFAULT_IMAGE_TOKEN
        if getattr(model.config, 'mm_use_im_start_end', False):
            replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
        image_args = {"images": images, "image_sizes": image_sizes}
    else:
        images = None
        image_args = {}

    max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
    # Honor the UI slider instead of a hard-coded 512; fall back to 512 when
    # the slider value is 0/None so generation is never silently disabled.
    max_new_tokens = int(max_output_tokens or 512)
    do_sample = temperature > 0.001
    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

    # Leave room in the context window for the prompt and the image patches.
    max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
    if max_new_tokens < 1:
        # Error path: yield a proper gradio update tuple. (The upstream worker
        # yielded JSON bytes here, which is not a valid Blocks handler output.)
        state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation, thanks."
        yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
        return

    # Run generation on a background thread; the streamer feeds this generator.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        **image_args
    ))
    thread.start()

    generated_text = ''
    for new_text in streamer:
        generated_text += new_text
        # Trim the separator if the model emitted it (guard: sep2 may be None).
        if stop_str and generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        state.messages[-1][-1] = generated_text
        yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
    # Final update once the stream is exhausted.
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
    torch.cuda.empty_cache()
158
+
159
# Heading shown at the top of the demo page.
title_markdown = """
# CuMo: Trained for waste management
"""

# Attribution, citation and terms-of-use text rendered below the chat area.
tos_markdown = """
### Please "🗑️ Clear" the output before offering a new picture!
### Source and Terms of use
This demo is based on the original CuMo project by SHI-Labs ([GitHub](https://github.com/SHI-Labs/CuMo)).
If you use this service or build upon this work, please cite the original publication:
Li, Jiachen and Wang, Xinyao and Zhu, Sijie and Kuo, Chia-wen and Xu, Lu and Chen, Fan and Jain, Jitesh and Shi, Humphrey and Wen, Longyin.
CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts. arXiv preprint, 2024.
[[arXiv](https://arxiv.org/abs/2405.05949)]

By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.

For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""
177
+
178
+
179
+
180
# License blurb rendered at the bottom of the page. The upstream text ended in
# a truncated clause ("..., subject to the. Please contact us ..."); the
# dangling fragment is removed so the sentence reads cleanly.
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
"""
184
+
185
# Custom CSS injected into the Blocks layout: keep the action buttons from
# collapsing below a readable width.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
190
+
191
+
192
+
193
# Read-only prompt box: the question is fixed, so the widget only informs the
# user what will be asked. Created outside the Blocks context and rendered
# later via textbox.render().
textbox = gr.Textbox(
    show_label=False,
    placeholder="Prompt is fixed: What material is this item and how to dispose of it?",
    container=False,
    interactive=False,
)
199
+
200
with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
    # Per-session conversation object.
    state = gr.State()

    gr.Markdown(title_markdown)

    with gr.Row():
        # Left column: image input, examples and sampling parameters.
        with gr.Column(scale=3):
            imagebox = gr.Image(label="Input Image", type="filepath")
            image_process_mode = gr.Radio(
                ["Crop", "Resize", "Pad", "Default"],
                value="Default",
                label="Preprocess for non-square image",
                visible=False,
            )

            # cur_dir = os.path.dirname(os.path.abspath(__file__))
            cur_dir = './cumo/serve'
            default_prompt = "<image>\nWhat material is this item and how to dispose of it?"
            example_images = [
                "0165 CB.jpg", "0225 PA.jpg", "0787 GM.jpg", "1396 A.jpg",
                "2001 P.jpg", "2658 PE.jpg", "3113 R.jpg", "3750 RPC.jpg",
                "5033 CC.jpg", "5307 B.jpg",
            ]
            gr.Examples(
                examples=[[f"{cur_dir}/examples/{fname}", default_prompt] for fname in example_images],
                inputs=[imagebox, textbox],
                cache_examples=False,
            )

            with gr.Accordion("Parameters", open=False) as parameter_row:
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")

        # Right column: chat window, send button and action buttons.
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="CuMo Chatbot",
                height=650,
                layout="panel",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary")
            with gr.Row(elem_id="buttons") as button_row:
                stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    url_params = gr.JSON(visible=False)

    # ---- Event wiring -----------------------------------------------------
    btn_list = [regenerate_btn, clear_btn]
    chat_outputs = [state, chatbot, textbox, imagebox] + btn_list
    gen_inputs = [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens]

    clear_btn.click(clear_history, None, chat_outputs, queue=False)

    # Regenerate = drop the last reply, then re-run generation.
    regenerate_btn.click(
        delete_text,
        [state, image_process_mode],
        chat_outputs,
    ).then(generate, gen_inputs, chat_outputs)

    # Both Enter-in-textbox and the Send button add the turn, then generate.
    textbox.submit(
        add_text,
        [state, imagebox, textbox, image_process_mode],
        chat_outputs,
    ).then(generate, gen_inputs, chat_outputs)

    submit_btn.click(
        add_text,
        [state, imagebox, textbox, image_process_mode],
        chat_outputs,
    ).then(generate, gen_inputs, chat_outputs)

demo.queue(
    status_update_rate=10,
    api_open=False
).launch()