salma-remyx committed on
Commit 80b7578 · 1 Parent(s): 0c890d5

add SpaceThinker

README.md CHANGED
@@ -1,15 +1,12 @@
  ---
- title: VQASynth
- emoji: 🎹
- colorFrom: blue
- colorTo: purple
  sdk: gradio
- sdk_version: 5.5.0
  app_file: app.py
  pinned: false
- license: apache-2.0
- short_description: VQASynth Scene Reconstruction Pipeline
- startup_duration_timeout: 4h
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: SpaceThinker-Qwen2.5VL-3B
+ emoji: 🌌
+ colorFrom: indigo
+ colorTo: pink
  sdk: gradio
+ sdk_version: 5.15.0
  app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,196 +1,226 @@
  import spaces
- import subprocess
- import sys
-
- # Ensure the package is installed from the Git repository
- package_name = "vqasynth"  # Replace with the actual package name if different
- git_repo_url = "git+https://github.com/remyxai/VQASynth.git"
-
- try:
-     __import__(package_name)
- except ImportError:
-     print(f"{package_name} not found. Installing from {git_repo_url}...")
-     subprocess.check_call([sys.executable, "-m", "pip", "install", git_repo_url])
-
- import os
- import uuid
- import tempfile
-
- import cv2
- import open3d as o3d
- import PIL
- from PIL import Image
-
- from vqasynth.depth import DepthEstimator
- from vqasynth.localize import Localizer
- from vqasynth.scene_fusion import SpatialSceneConstructor
- from vqasynth.prompts import PromptGenerator
-
- import numpy as np
  import gradio as gr

- import spacy
-
- try:
-     nlp = spacy.load("en_core_web_sm")
- except OSError:
-     # Download the model if it's not already available
-     from spacy.cli import download
-
-     download("en_core_web_sm")
-     nlp = spacy.load("en_core_web_sm")
-
- depth = DepthEstimator(from_onnx=False)
- localizer = Localizer()
- spatial_scene_constructor = SpatialSceneConstructor()
- prompt_generator = PromptGenerator()
-

- def combine_segmented_pointclouds(
-     pointcloud_ply_files: list, captions: list, prompts: list, cache_dir: str
- ):
      """
-     Process a list of segmented point clouds to combine two based on captions and return the resulting 3D point cloud and the identified prompt.
-
-     Args:
-         pointcloud_ply_files (list): List of file paths to `.pcd` files representing segmented point clouds.
-         captions (list): List of captions corresponding to the segmented point clouds.
-         prompts (list): List of prompts containing questions and answers about the captions.
-         cache_dir (str): Directory to save the final `.ply` and `.obj` files.
-
-     Returns:
-         tuple: The path to the generated `.obj` file and the identified prompt text.
      """
-     selected_prompt = None
-     selected_indices = None
-     for i, caption1 in enumerate(captions):
-         for j, caption2 in enumerate(captions):
-             if i != j:
-                 for prompt in prompts:
-                     if caption1 in prompt and caption2 in prompt:
-                         selected_prompt = prompt
-                         selected_indices = (i, j)
-                         break
-                 if selected_prompt:
-                     break
-         if selected_prompt:
-             break
-
-     if not selected_prompt or not selected_indices:
-         raise ValueError("No prompt found containing two captions.")
-
-     idx1, idx2 = selected_indices
-     pointcloud_files = [pointcloud_ply_files[idx1], pointcloud_ply_files[idx2]]
-     captions = [captions[idx1], captions[idx2]]

-     combined_point_cloud = o3d.geometry.PointCloud()
-     for idx, pointcloud_file in enumerate(pointcloud_files):
-         pcd = o3d.io.read_point_cloud(pointcloud_file)
-         if pcd.is_empty():
-             continue

-         combined_point_cloud += pcd

-     if combined_point_cloud.is_empty():
-         raise ValueError(
-             "Combined point cloud is empty after loading the selected segments."
-         )

-     uuid_out = str(uuid.uuid4())
-     ply_file = os.path.join(cache_dir, f"combined_output_{uuid_out}.ply")
-     obj_file = os.path.join(cache_dir, f"combined_output_{uuid_out}.obj")

-     o3d.io.write_point_cloud(ply_file, combined_point_cloud)

-     mesh = o3d.io.read_triangle_mesh(ply_file)
-     o3d.io.write_triangle_mesh(obj_file, mesh)

-     return obj_file, selected_prompt


- @spaces.GPU
- def run_vqasynth_pipeline(image: PIL.Image, cache_dir: str):
-     depth_map, focal_length = depth.run(image)
-     masks, bounding_boxes, captions = localizer.run(image)
-     pointcloud_data, cannonicalized = spatial_scene_constructor.run(
-         str(0), image, depth_map, focal_length, masks, cache_dir
-     )
-     prompts = prompt_generator.run(captions, pointcloud_data, cannonicalized)
-     obj_file, selected_prompt = combine_segmented_pointclouds(
-         pointcloud_data, captions, prompts, cache_dir
-     )
-     return obj_file, selected_prompt


- def process_image(image: str):
-     # Use a persistent temporary directory to keep the .obj file accessible by Gradio
-     temp_dir = tempfile.mkdtemp()
-     image = Image.open(image).convert("RGB")
-     obj_file, prompt = run_vqasynth_pipeline(image, temp_dir)
-     return obj_file, prompt


  def build_demo():
      with gr.Blocks() as demo:
-         gr.Markdown(
-             """
-             # Synthesizing SpatialVQA Samples with VQASynth
-             This space helps test the full [VQASynth](https://github.com/remyxai/VQASynth) scene reconstruction pipeline on a single image with visualizations.
-             ### [Github](https://github.com/remyxai/VQASynth) | [Collection](https://huggingface.co/collections/remyxai/spacevlms-66a3dbb924756d98e7aec678)
-             """
          )

-         gr.Markdown(
-             """
-             ## Instructions
-             Upload an image, and the tool will generate a corresponding 3D point cloud visualization of the objects found and an example prompt and response describing a spatial relationship between the objects.
-             """
          )

          with gr.Row():
-             with gr.Column():
-                 image_input = gr.Image(type="filepath", label="Upload an Image")
-                 generate_button = gr.Button("Generate")

-             with gr.Column():
-                 model_output = gr.Model3D(label="3D Point Cloud")  # Only used as output
-                 caption_output = gr.Text(label="Caption")
-
-         generate_button.click(
-             process_image, inputs=image_input, outputs=[model_output, caption_output]
          )

          gr.Examples(
              examples=[
-                 ["./examples/warehouse_rgb.jpg"],
-                 ["./examples/spooky_doggy.png"],
-                 ["./examples/bee_and_flower.jpg"],
-                 ["./examples/gears.png"],
-                 ["./examples/road-through-dense-forest.jpg"],
              ],
-             inputs=image_input,
-             label="Example Images",
-             examples_per_page=5,
-         )
-
-         gr.Markdown(
-             """
-             ## Citation
-             ```
-             @article{chen2024spatialvlm,
-               title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
-               author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
-               journal = {arXiv preprint arXiv:2401.12168},
-               year = {2024},
-               url = {https://arxiv.org/abs/2401.12168},
-             }
-             ```
-             """
          )

      return demo

-
  if __name__ == "__main__":
      demo = build_demo()
      demo.launch(share=True)

  import spaces
+ import torch
+ import time
  import gradio as gr
+ from PIL import Image
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from typing import List

+ MODEL_ID = "remyxai/SpaceThinker-Qwen2.5VL-3B"

+ @spaces.GPU
+ def load_model():
+     print("Loading model and processor...")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+     ).to(device)
+     processor = AutoProcessor.from_pretrained(MODEL_ID)
+     return model, processor
+
+ model, processor = load_model()
+
+ def process_image(image_path_or_obj):
+     """Loads, resizes, and preprocesses an image path or Pillow Image."""
+     if isinstance(image_path_or_obj, str):
+         # Path on disk or from history
+         image = Image.open(image_path_or_obj).convert("RGB")
+     elif isinstance(image_path_or_obj, Image.Image):
+         image = image_path_or_obj.convert("RGB")
+     else:
+         raise ValueError("process_image expects a file path (str) or PIL.Image")
+
+     max_width = 512
+     if image.width > max_width:
+         aspect_ratio = image.height / image.width
+         new_height = int(max_width * aspect_ratio)
+         image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
+         print(f"Resized image to: {max_width}x{new_height}")
+     return image
+
+ def get_latest_image(history):
      """
+     Look from the end to find the last user-uploaded image (stored as (file_path,) ).
+     Return None if not found.
      """
+     for user_msg, _assistant_msg in reversed(history):
+         if isinstance(user_msg, tuple) and len(user_msg) > 0:
+             return user_msg[0]
+     return None

+ def only_assistant_text(full_text: str) -> str:
+     """
+     Helper to strip out any lines containing 'system', 'user', etc.,
+     and return only the final assistant answer.
+     Adjust this parsing if your model's output format differs.
+     """
+     # Example output might look like:
+     # system
+     # ...
+     # user
+     # ...
+     # assistant
+     # The final answer
+     #
+     # We'll just split on 'assistant' and return everything after it.
+     if "assistant" in full_text:
+         parts = full_text.split("assistant", 1)
+         result = parts[-1].strip()
+         # Remove any leading punctuation (like a colon)
+         result = result.lstrip(":").strip()
+         return result
+     return full_text.strip()
+
+ def run_inference(image, prompt):
+     """Runs Qwen2.5-VL inference on a single image and text prompt."""
+     system_msg = (
+         "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability. "
+         "You should first think about the reasoning process and then provide the answer. "
+         "Use <think>...</think> and <answer>...</answer> tags."
+     )
+     conversation = [
+         {
+             "role": "system",
+             "content": [{"type": "text", "text": system_msg}],
+         },
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         },
+     ]
+     text_input = processor.apply_chat_template(
+         conversation, tokenize=False, add_generation_prompt=True
+     )

+     inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
+     generated_ids = model.generate(**inputs, max_new_tokens=1024)
+     output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     # Parse out only the final assistant text
+     return only_assistant_text(output_text)

+ def add_message(history, user_input):
+     """
+     Step 1 (triggered by user's 'Submit' or 'Send'):
+       - Save new text or images into `history`.
+       - The Chatbot display uses pairs: [user_text_or_image, assistant_reply].
+     """
+     if not isinstance(history, list):
+         history = []

+     files = user_input.get("files", [])
+     text = user_input.get("text", "")

+     # Store images
+     for f in files:
+         # Each image is stored as `[(file_path,), None]`
+         history.append([(f,), None])

+     # Store text
+     if text:
+         history.append([text, None])

+     return history, gr.MultimodalTextbox(value=None)

+ def inference_interface(history):
+     """
+     Step 2: Use the most recent text + the most recent image to run Qwen2.5-VL.
+     Instead of adding another entry, we fill the assistant's answer into
+     the last user text entry.
+     """
+     if not history:
+         return history, gr.MultimodalTextbox(value=None)
+
+     # 1) Get the user's most recent text
+     user_text = ""
+     # We'll search from the end for the first str we find
+     for idx in range(len(history) - 1, -1, -1):
+         user_msg, assistant_msg = history[idx]
+         if isinstance(user_msg, str):
+             user_text = user_msg
+             # We'll also keep track of this index so we can fill in the assistant reply
+             user_idx = idx
+             break
+     else:
+         # No user text found
+         print("No user text found in history. Skipping inference.")
+         return history, gr.MultimodalTextbox(value=None)

+     # 2) Get the latest image from the entire conversation
+     latest_image = get_latest_image(history)
+     if not latest_image:
+         # No image found => can't run the model
+         print("No image found in history. Skipping inference.")
+         return history, gr.MultimodalTextbox(value=None)

+     # 3) Process the image
+     pil_image = process_image(latest_image)

+     # 4) Run inference
+     assistant_reply = run_inference(pil_image, user_text)

+     # 5) Fill that assistant reply back into the last user text entry
+     history[user_idx][1] = assistant_reply
+     return history, gr.MultimodalTextbox(value=None)

  def build_demo():
      with gr.Blocks() as demo:
+         gr.Markdown("# SpaceThinker-Qwen2.5VL-3B Image Prompt Chatbot")
+
+         chatbot = gr.Chatbot([], line_breaks=True)
+         chat_input = gr.MultimodalTextbox(
+             interactive=True,
+             file_types=["image"],
+             placeholder="Enter text and upload an image.",
+             show_label=True
          )

+         # When the user presses Enter in the MultimodalTextbox:
+         submit_event = chat_input.submit(
+             fn=add_message,  # Step 1: store user data
+             inputs=[chatbot, chat_input],
+             outputs=[chatbot, chat_input]
+         )
+         # After storing, run inference
+         submit_event.then(
+             fn=inference_interface,  # Step 2: run Qwen2.5-VL
+             inputs=[chatbot],
+             outputs=[chatbot, chat_input]
          )

+         # Same logic for a "Send" button
          with gr.Row():
+             send_button = gr.Button("Send")
+             clear_button = gr.ClearButton([chatbot, chat_input])

+         send_click = send_button.click(
+             fn=add_message,
+             inputs=[chatbot, chat_input],
+             outputs=[chatbot, chat_input]
+         )
+         send_click.then(
+             fn=inference_interface,
+             inputs=[chatbot],
+             outputs=[chatbot, chat_input]
          )

+         # Example
          gr.Examples(
              examples=[
+                 {
+                     "text": "Give me the height of the man in the red hat in feet.",
+                     "files": ["./examples/warehouse_rgb.jpg"]
+                 }
              ],
+             inputs=[chat_input],
          )

      return demo

  if __name__ == "__main__":
      demo = build_demo()
      demo.launch(share=True)
+
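
For reference, the inference path added above can be exercised outside the Gradio UI. The sketch below condenses the `load_model` and `run_inference` logic from the new `app.py` into one script; it assumes the installed `transformers` build exposes `Qwen2_5_VLForConditionalGeneration` (hence the Git pin in `requirements.txt`), and the image path and question are simply the Space's bundled example, used here as placeholders.

```python
# Minimal standalone sketch of the Space's inference path (not part of the commit).
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = "remyxai/SpaceThinker-Qwen2.5VL-3B"

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Placeholders: the Space's example image and prompt.
image = Image.open("./examples/warehouse_rgb.jpg").convert("RGB")
question = "Give me the height of the man in the red hat in feet."

conversation = [
    {"role": "system", "content": [{"type": "text", "text": (
        "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability. "
        "You should first think about the reasoning process and then provide the answer. "
        "Use <think>...</think> and <answer>...</answer> tags."
    )}]},
    {"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": question},
    ]},
]

# Build the chat prompt, run generation, and decode, as in run_inference().
text_input = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=1024)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

As in the Space, the raw decode still contains the system and user turns, so an `only_assistant_text`-style split on "assistant" applies if only the final answer is wanted.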
checkpoints/depth_pro.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3eb35ca68168ad3d14cb150f8947a4edf85589941661fdb2686259c80685c0ce
- size 1904446787

examples/bee_and_flower.jpg DELETED
Binary file (18.2 kB)
 
examples/gears.png DELETED
Binary file (525 kB)
 
examples/road-through-dense-forest.jpg DELETED
Binary file (292 kB)
 
examples/spooky_doggy.png DELETED
Binary file (892 kB)
 
requirements.txt CHANGED
@@ -1,36 +1,9 @@
- --extra-index-url https://download.pytorch.org/whl/cu121
-
- torch==2.4.0
- torchvision==0.19.0
- torchaudio==2.4.0
-
- transformers==4.44.2
- pillow==11.0.0
- gradio==5.5.0
- accelerate==0.34.2
- numpy==1.26.4
- timm==1.0.9
- einops==0.7.0
- open3d==0.18.0
- opencv-python==4.7.0.72
- tqdm==4.66.3
- torchprofile==0.0.4
- matplotlib==3.6.2
- huggingface-hub==0.25.1
- onnx==1.13.1
- onnxruntime==1.14.1
- onnxsim==0.4.35
- scipy==1.12.0
- litellm==1.25.2
- pycocotools==2.0.6
- datasets==3.1.0
- spacy==3.7.5
- pandas==2.2.3
- html5lib==1.1
- spaces==0.30.4
-
- #git+https://github.com/remyxai/VQASynth.git
- git+https://github.com/apple/ml-depth-pro.git
- git+https://github.com/facebookresearch/sam2.git
- git+https://github.com/openai/CLIP.git
- flash-attn @ https://remyx.ai/assets/spatialvlm/flash_attn-2.7.0.post2-cp310-cp310-linux_x86_64.whl

+ torch
+ transformers>=4.41.0
+ Pillow
+ gradio==5.15.0
+ spaces
+ multiprocess
+ requests
+ accelerate>=0.26.0
+ git+https://github.com/huggingface/transformers.git
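
The new pin set is much slimmer than the VQASynth stack it replaces. Note that `transformers` appears twice (a `>=4.41.0` floor plus the Git source, presumably so Qwen2.5-VL support is available even if the PyPI release lags) and that `gradio==5.15.0` matches the `sdk_version` in the updated README. A small, hypothetical sanity check for a local environment (not part of the Space) could look like this:

```python
# Hypothetical local check that the installed versions line up with the pins above.
import torch
import gradio
import transformers

print("gradio:", gradio.__version__)                 # expected 5.15.0 (matches README sdk_version)
print("transformers:", transformers.__version__)     # should come from the Git source
print("cuda available:", torch.cuda.is_available())  # app.py falls back to float32 on CPU

# app.py imports this class directly, so the installed build must provide it:
from transformers import Qwen2_5_VLForConditionalGeneration  # noqa: F401
```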