# app.py for the SpaceThinker-Qwen2.5VL-3B Space (runs on ZeroGPU hardware).
import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from functools import lru_cache

MODEL_ID = "remyxai/SpaceThinker-Qwen2.5VL-3B"
@lru_cache(maxsize=1)
def load_model():
    print("Loading model and processor...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    ).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor


def process_image(image_path_or_obj):
    """Return an RGB PIL image from a file path or an existing PIL image."""
    if isinstance(image_path_or_obj, str):
        image = Image.open(image_path_or_obj).convert("RGB")
    elif isinstance(image_path_or_obj, Image.Image):
        image = image_path_or_obj.convert("RGB")
    else:
        raise ValueError("process_image expects a file path (str) or PIL.Image")

    # Downscale wide images to cap the vision encoder's token count and latency.
    max_width = 512
    if image.width > max_width:
        aspect_ratio = image.height / image.width
        new_height = int(max_width * aspect_ratio)
        image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
    return image


def get_latest_image(history):
    """Return the most recently uploaded image path, if any."""
    for item in reversed(history):
        # Gradio's messages format stores uploaded files as a (filepath,) tuple.
        if item["role"] == "user" and isinstance(item["content"], tuple):
            return item["content"][0]
    return None


def only_assistant_text(full_text: str) -> str:
    """Strip any chat-template residue, keeping only the assistant's reply."""
    if "assistant" in full_text:
        # Split on the last occurrence: the system prompt itself contains the
        # word "assistant", so splitting on the first would truncate the reply.
        parts = full_text.rsplit("assistant", 1)
        result = parts[-1].strip()
        result = result.lstrip(":").strip()
        return result
    return full_text.strip()
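

# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of the
# decorated call; on other hardware it is a no-op.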
@spaces.GPU
def run_inference(image, prompt):
    model, processor = load_model()
    system_msg = (
        "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability. "
        "You should first think about the reasoning process and then provide the answer. "
        "Use <think>...</think> and <answer>...</answer> tags."
    )
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_msg}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_input = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    # Decode only the newly generated tokens so the echoed prompt is never
    # mistaken for part of the reply.
    trimmed_ids = generated_ids[:, inputs["input_ids"].shape[1]:]
    output_text = processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return only_assistant_text(output_text)


def add_message(history, user_input):
    """Append uploaded files and text from the multimodal textbox to the history."""
    if not isinstance(history, list):
        history = []
    files = user_input.get("files", [])
    text = user_input.get("text", "")
    for f in files:
        history.append({"role": "user", "content": (f,)})
    if text:
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None)


def inference_interface(history):
    """Answer the latest user text using the most recently uploaded image."""
    if not history:
        return history, gr.MultimodalTextbox(value=None)

    # Walk backwards to find the newest plain-text user message.
    user_text = ""
    user_idx = -1
    for idx in range(len(history) - 1, -1, -1):
        msg = history[idx]
        if msg["role"] == "user" and isinstance(msg["content"], str):
            user_text = msg["content"]
            user_idx = idx
            break
    if user_idx == -1:
        return history, gr.MultimodalTextbox(value=None)

    latest_image = get_latest_image(history)
    if not latest_image:
        return history, gr.MultimodalTextbox(value=None)

    pil_image = process_image(latest_image)
    assistant_reply = run_inference(pil_image, user_text)
    history.append({"role": "assistant", "content": assistant_reply})
    return history, gr.MultimodalTextbox(value=None)
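

# UI wiring: each send first appends the message to the chat (add_message),
# then a chained .then() event runs the model (inference_interface).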
def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# SpaceThinker-Qwen2.5VL-3B Image Prompt Chatbot")
        chatbot = gr.Chatbot([], type="messages", line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter text and upload an image.",
            show_label=True,
        )

        submit_event = chat_input.submit(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input],
        )
        submit_event.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input],
        )

        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        # The Send button mirrors the textbox submit behavior.
        send_click = send_button.click(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input],
        )
        send_click.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input],
        )

        gr.Examples(
            examples=[
                {
                    "text": "Give me the height of the man in the red hat in feet.",
                    "files": ["./examples/warehouse_rgb.jpg"],
                }
            ],
            inputs=[chat_input],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    # share=True is ignored when running on Spaces but useful for local testing.
    demo.launch(share=True)