salma-remyx committed
Commit e19b349
1 Parent(s): 476e594

update app

Files changed (1)
  1. app.py +69 -118
app.py CHANGED
@@ -1,176 +1,127 @@
 import spaces
 import torch
-import time
 import gradio as gr
 from PIL import Image
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from typing import List
 from functools import lru_cache
 
 MODEL_ID = "remyxai/SpaceThinker-Qwen2.5VL-3B"
 
-@spaces.GPU
 @lru_cache(maxsize=1)
-def load_model():
-    print("Loading model and processor...")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+def _load_model():
+    """Load and cache the model and processor inside GPU worker."""
     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         MODEL_ID,
-        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-    ).to(device)
+        torch_dtype=torch.bfloat16
+    ).to("cuda")
     processor = AutoProcessor.from_pretrained(MODEL_ID)
     return model, processor
 
-def process_image(image_path_or_obj):
-    if isinstance(image_path_or_obj, str):
-        image = Image.open(image_path_or_obj).convert("RGB")
-    elif isinstance(image_path_or_obj, Image.Image):
-        image = image_path_or_obj.convert("RGB")
-    else:
-        raise ValueError("process_image expects a file path (str) or PIL.Image")
-
-    max_width = 512
-    if image.width > max_width:
-        aspect_ratio = image.height / image.width
-        new_height = int(max_width * aspect_ratio)
-        image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
-    return image
-
-def get_latest_image(history):
-    for item in reversed(history):
-        if item["role"] == "user" and isinstance(item["content"], tuple):
-            return item["content"][0]
-    return None
-
-def only_assistant_text(full_text: str) -> str:
-    if "assistant" in full_text:
-        parts = full_text.split("assistant", 1)
-        result = parts[-1].strip()
-        result = result.lstrip(":").strip()
-        return result
-    return full_text.strip()
-
-def run_inference(image, prompt):
-    model, processor = load_model()
+@spaces.GPU
+def gpu_inference(image_path: str, prompt: str) -> str:
+    """Perform inference entirely in GPU subprocess."""
+    model, processor = _load_model()
+
+    # Load and preprocess image
+    image = Image.open(image_path).convert("RGB")
+    if image.width > 512:
+        ratio = image.height / image.width
+        image = image.resize((512, int(512 * ratio)), Image.Resampling.LANCZOS)
+
+    # Build conversation
     system_msg = (
-        "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability. "
-        "You should first think about the reasoning process and then provide the answer. "
-        "Use <think>...</think> and <answer>...</answer> tags."
+        "You are VL-Thinking 🤔, a helpful assistant. "
+        "Think through your reasoning then provide the answer. "
+        "Wrap reasoning in <think>...</think> and final in <answer>...</answer>."
    )
     conversation = [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": system_msg}],
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": prompt},
-            ],
-        },
+        {"role": "system", "content": [{"type": "text", "text": system_msg}]},
+        {"role": "user", "content": [
+            {"type": "image", "image": image},
+            {"type": "text", "text": prompt}
+        ]}
     ]
-    text_input = processor.apply_chat_template(
+
+    # Tokenize, generate, decode
+    chat_input = processor.apply_chat_template(
         conversation, tokenize=False, add_generation_prompt=True
     )
-
-    inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
-    generated_ids = model.generate(**inputs, max_new_tokens=1024)
-    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    return only_assistant_text(output_text)
+    inputs = processor(text=[chat_input], images=[image], return_tensors="pt").to("cuda")
+    output_ids = model.generate(**inputs, max_new_tokens=1024)
+    decoded = processor.batch_decode(
+        output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
+    # Extract assistant portion
+    return decoded.split("assistant", 1)[-1].strip().lstrip(":").strip()
 
+# Message handling
+
 def add_message(history, user_input):
-    if not isinstance(history, list):
+    if history is None:
         history = []
-
-    files = user_input.get("files", [])
-    text = user_input.get("text", "")
-
-    for f in files:
+    for f in user_input.get("files", []):
         history.append({"role": "user", "content": (f,)})
-
+    text = user_input.get("text", "")
     if text:
         history.append({"role": "user", "content": text})
-
     return history, gr.MultimodalTextbox(value=None)
 
+
 def inference_interface(history):
     if not history:
         return history, gr.MultimodalTextbox(value=None)
-
-    user_text = ""
-    user_idx = -1
-    for idx in range(len(history) - 1, -1, -1):
-        msg = history[idx]
-        if msg["role"] == "user" and isinstance(msg["content"], str):
-            user_text = msg["content"]
-            user_idx = idx
-            break
-
-    if user_idx == -1:
+    # Last user text
+    user_text = next(
+        (m["content"] for m in reversed(history)
+         if m["role"] == "user" and isinstance(m["content"], str)),
+        None
+    )
+    if user_text is None:
         return history, gr.MultimodalTextbox(value=None)
-
-    latest_image = get_latest_image(history)
-    if not latest_image:
+    # Last user image
+    image_path = next(
+        (m["content"][0] for m in reversed(history)
+         if m["role"] == "user" and isinstance(m["content"], tuple)),
+        None
+    )
+    if image_path is None:
         return history, gr.MultimodalTextbox(value=None)
 
-    pil_image = process_image(latest_image)
-    assistant_reply = run_inference(pil_image, user_text)
-
-    history.append({"role": "assistant", "content": assistant_reply})
+    # GPU inference
+    reply = gpu_inference(image_path, user_text)
+    history.append({"role": "assistant", "content": reply})
     return history, gr.MultimodalTextbox(value=None)
 
+
 def build_demo():
     with gr.Blocks() as demo:
         gr.Markdown("# SpaceThinker-Qwen2.5VL-3B Image Prompt Chatbot")
-
-        chatbot = gr.Chatbot([], type="messages", line_breaks=True)
-
+        chatbot = gr.Chatbot([], type="messages", label="Conversation")
         chat_input = gr.MultimodalTextbox(
             interactive=True,
             file_types=["image"],
             placeholder="Enter text and upload an image.",
             show_label=True
         )
-
-        submit_event = chat_input.submit(
-            fn=add_message,
-            inputs=[chatbot, chat_input],
-            outputs=[chatbot, chat_input]
+        submit_evt = chat_input.submit(
+            add_message, [chatbot, chat_input], [chatbot, chat_input]
         )
-        submit_event.then(
-            fn=inference_interface,
-            inputs=[chatbot],
-            outputs=[chatbot, chat_input]
+        submit_evt.then(
+            inference_interface, [chatbot], [chatbot, chat_input]
         )
-
         with gr.Row():
-            send_button = gr.Button("Send")
-            clear_button = gr.ClearButton([chatbot, chat_input])
-
-        send_click = send_button.click(
-            fn=add_message,
-            inputs=[chatbot, chat_input],
-            outputs=[chatbot, chat_input]
+            send_btn = gr.Button("Send")
+            clear_btn = gr.ClearButton([chatbot, chat_input])
+        send_click = send_btn.click(
+            add_message, [chatbot, chat_input], [chatbot, chat_input]
         )
         send_click.then(
-            fn=inference_interface,
-            inputs=[chatbot],
-            outputs=[chatbot, chat_input]
+            inference_interface, [chatbot], [chatbot, chat_input]
        )
-
-        gr.Examples(
-            examples=[
-                {
-                    "text": "Give me the height of the man in the red hat in feet.",
-                    "files": ["./examples/warehouse_rgb.jpg"]
-                }
-            ],
-            inputs=[chat_input],
-        )
-
+
     return demo
 
+
 if __name__ == "__main__":
     demo = build_demo()
     demo.launch(share=True)
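
Because the refactor collapses model loading, image preprocessing, and generation into the single gpu_inference(image_path, prompt) entry point, it can be smoke-tested without launching the Gradio UI. The sketch below is not part of this commit: the smoke_test.py file name is hypothetical, it assumes a CUDA device is available (the new loader moves the model to "cuda" unconditionally) and that the spaces.GPU decorator falls back to a no-op outside a Space, and it reuses the image path and prompt that the previous revision shipped in its gr.Examples block.

# smoke_test.py (hypothetical helper, not included in this commit)
from app import gpu_inference

if __name__ == "__main__":
    # Image path and prompt are borrowed from the gr.Examples block of the
    # previous revision; substitute any local image and question.
    reply = gpu_inference(
        "./examples/warehouse_rgb.jpg",
        "Give me the height of the man in the red hat in feet.",
    )
    print(reply)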