BenkHel committed
Commit b0dba11 · verified · 1 Parent(s): 8df865e

Update app.py

Files changed (1)
  1. app.py +32 -74
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import subprocess
 import sys
 import os
@@ -10,7 +9,7 @@ import subprocess
 import spaces
 import cumo.serve.gradio_web_server as gws
 
-from transformers import AutoProcessor,AutoTokenizer, AutoImageProcessor
+from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
 
 import datetime
 import json
@@ -36,55 +35,6 @@ from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 from transformers import TextIteratorStreamer
 from threading import Thread
 
-# Execute the pip install command with additional options
-#subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
-
-headers = {"User-Agent": "CuMo"}
-
-no_change_btn = gr.Button()
-enable_btn = gr.Button(interactive=True)
-disable_btn = gr.Button(interactive=False)
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_path = 'BenkHel/CumoThesis'
-conv_mode = 'mistral_instruct_system'  # this variable is still needed for the conversation templates
-load_8bit = False
-load_4bit = False
-
-import sys
-import os
-import argparse
-import time
-import subprocess
-import spaces
-import cumo.serve.gradio_web_server as gws
-
-import datetime
-import json
-
-import gradio as gr
-import requests
-from PIL import Image
-
-from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle)
-from cumo.constants import LOGDIR
-from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
-import hashlib
-
-import torch
-import io
-from cumo.constants import WORKER_HEART_BEAT_INTERVAL
-from cumo.utils import (build_logger, server_error_msg,
-                        pretty_print_semaphore)
-from cumo.model.builder import load_pretrained_model
-from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
-from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-from transformers import TextIteratorStreamer
-from threading import Thread
-
-# Execute the pip install command with additional options
-#subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
-
 headers = {"User-Agent": "CuMo"}
 
 no_change_btn = gr.Button()
@@ -98,17 +48,21 @@ model_name = 'CuMo-mistral-7b'
 conv_mode = 'mistral_instruct_system'
 load_8bit = False
 load_4bit = False
-tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False)
+
+tokenizer, model, image_processor, context_len = load_pretrained_model(
+    model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False
+)
 model.config.training = False
-
+
+# FIXED PROMPT
+FIXED_PROMPT = "What material is this item and how to dispose of it?"
+
 def upvote_last_response(state):
     return ("",) + (disable_btn,) * 3
 
-
 def downvote_last_response(state):
     return ("",) + (disable_btn,) * 3
 
-
 def flag_last_response(state):
     return ("",) + (disable_btn,) * 3
 
@@ -121,15 +75,12 @@ def add_text(state, imagebox, textbox, image_process_mode):
     state = conv_templates[conv_mode].copy()
 
     if imagebox is not None:
-        textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
+        textbox = DEFAULT_IMAGE_TOKEN + '\n' + FIXED_PROMPT
         image = Image.open(imagebox).convert('RGB')
-
     if imagebox is not None:
         textbox = (textbox, image, image_process_mode)
-
     state.append_message(state.roles[0], textbox)
     state.append_message(state.roles[1], None)
-
     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
 def delete_text(state, image_process_mode):
@@ -149,9 +100,8 @@ def regenerate(state, image_process_mode):
 
 @spaces.GPU
 def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
-    prompt = state.get_prompt()
+    prompt = FIXED_PROMPT  # <-- fixed here!
     images = state.get_images(return_pil=True)
-    #prompt, image_args = process_image(prompt, images)
 
     ori_prompt = prompt
     num_image_tokens = 0
@@ -160,8 +110,6 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
     if len(images) > 0:
         if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
             raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-        #images = [load_image_from_base64(image) for image in images]
         image_sizes = [image.size for image in images]
         images = process_images(images, image_processor, model.config)
 
@@ -174,7 +122,6 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
         if getattr(model.config, 'mm_use_im_start_end', False):
             replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
         prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
         num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
     else:
         images = None
@@ -193,7 +140,6 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
     max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
     if max_new_tokens < 1:
         yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
         return
@@ -217,25 +163,29 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
             generated_text = generated_text[:-len(stop_str)]
         state.messages[-1][-1] = generated_text
         yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-
     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
-
     torch.cuda.empty_cache()
 
 title_markdown = ("""
-# CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts
-[[Project Page](https://chrisjuniorli.github.io/project/CuMo/)] [[Code](https://github.com/SHI-Labs/CuMo)] [[Model](https://huggingface.co/shi-labs/CuMo-mistral-7b)] | 📚 [[Arxiv](https://arxiv.org/pdf/2405.05949)]
+# CuMo: Trained for waste management
 """)
 
 tos_markdown = ("""
-### Terms of use
-By using this service, users are required to agree to the following terms:
-The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
-Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+### Source and Terms of use
+This demo is based on the original CuMo project by SHI-Labs ([GitHub](https://github.com/SHI-Labs/CuMo)).
+If you use this service or build upon this work, please cite the original publication:
+Li, Jiachen and Wang, Xinyao and Zhu, Sijie and Kuo, Chia-wen and Xu, Lu and Chen, Fan and Jain, Jitesh and Shi, Humphrey and Wen, Longyin.
+CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts. arXiv preprint, 2024.
+[[arXiv](https://arxiv.org/abs/2405.05949)]
+
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
+
 For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
 """)
 
 
+
 learn_more_markdown = ("""
 ### License
 The service is a research preview intended for non-commercial use only, subject to the. Please contact us if you find any potential violation.
@@ -247,7 +197,15 @@ block_css = """
 }
 """
 
-textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+
+
+textbox = gr.Textbox(
+    show_label=False,
+    placeholder="Prompt is fixed: What material is this item and how to dispose of it?",
+    container=False,
+    interactive=False
+)
+
 with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
     state = gr.State()
 
 
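The unchanged core of `generate` follows the usual transformers streaming pattern visible in the context lines: start `model.generate` on a worker thread, iterate a `TextIteratorStreamer`, and clamp `max_new_tokens` to the remaining context before generating. A minimal sketch of that pattern under assumed values; `gpt2` and the 48-token budget are placeholders, whereas app.py uses the CuMo checkpoint, its image-token accounting, and its own limits.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "What material is this item and how to dispose of it?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Same clamping idea as app.py: never request more new tokens than the
# context window can still hold.
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
max_new_tokens = min(48, max_context_length - input_ids.shape[-1])

# skip_prompt drops the echoed input; timeout guards the consuming loop.
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
)
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer),
)
thread.start()

# Consume partial text as it arrives, like the yield loop in generate().
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```

Running generation on a separate thread is what lets the Gradio handler yield partial chatbot updates while tokens are still being produced.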