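"""Gradio chat demo for remyxai/SpaceThinker-Qwen2.5VL-3B.

Upload an image and ask a question; the model replies with its reasoning
wrapped in <think>...</think> tags and the final answer in <answer>...</answer> tags.
"""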
import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from functools import lru_cache

MODEL_ID = "remyxai/SpaceThinker-Qwen2.5VL-3B"

@lru_cache(maxsize=1)
def load_model():
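    """Load the model and processor once and cache them for the lifetime of the process."""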
    print("Loading model and processor...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

def process_image(image_path_or_obj):
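    """Return an RGB PIL image, downscaled to at most 512 px wide while preserving aspect ratio."""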
    if isinstance(image_path_or_obj, str):
        image = Image.open(image_path_or_obj).convert("RGB")
    elif isinstance(image_path_or_obj, Image.Image):
        image = image_path_or_obj.convert("RGB")
    else:
        raise ValueError("process_image expects a file path (str) or PIL.Image")

    max_width = 512
    if image.width > max_width:
        aspect_ratio = image.height / image.width
        new_height = int(max_width * aspect_ratio)
        image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
    return image

def get_latest_image(history):
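    """Return the most recently uploaded image path from the chat history, or None if there is none."""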
    for item in reversed(history):
        if item["role"] == "user" and isinstance(item["content"], tuple):
            return item["content"][0]
    return None

def only_assistant_text(full_text: str) -> str:
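    """Strip the echoed prompt from the decoded output and keep only the assistant's reply."""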
    if "assistant" in full_text:
        parts = full_text.split("assistant", 1)
        result = parts[-1].strip()
        result = result.lstrip(":").strip()
        return result
    return full_text.strip()

@spaces.GPU
def run_inference(image, prompt):
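    """Run one image+text generation pass; @spaces.GPU requests a GPU for the duration of this call on ZeroGPU Spaces."""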
    model, processor = load_model()
    system_msg = (
        "You are VL-Thinking 🤔, a helpful assistant with excellent reasoning ability. "
        "You should first think about the reasoning process and then provide the answer. "
        "Use <think>...</think> and <answer>...</answer> tags."
    )
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_msg}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_input = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
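    # Allow up to 1024 new tokens so the <think> reasoning block has room before the <answer>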
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return only_assistant_text(output_text)

def add_message(history, user_input):
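    """Append uploaded files and typed text from the multimodal textbox to the chat history."""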
    if not isinstance(history, list):
        history = []

    files = user_input.get("files", [])
    text = user_input.get("text", "")

    for f in files:
        history.append({"role": "user", "content": (f,)})

    if text:
        history.append({"role": "user", "content": text})

    return history, gr.MultimodalTextbox(value=None)

def inference_interface(history):
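    """Take the latest user text and image from the history, run the model, and append the assistant reply."""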
    if not history:
        return history, gr.MultimodalTextbox(value=None)

    user_text = ""
    user_idx = -1
    for idx in range(len(history) - 1, -1, -1):
        msg = history[idx]
        if msg["role"] == "user" and isinstance(msg["content"], str):
            user_text = msg["content"]
            user_idx = idx
            break

    if user_idx == -1:
        return history, gr.MultimodalTextbox(value=None)

    latest_image = get_latest_image(history)
    if not latest_image:
        return history, gr.MultimodalTextbox(value=None)

    pil_image = process_image(latest_image)
    assistant_reply = run_inference(pil_image, user_text)

    history.append({"role": "assistant", "content": assistant_reply})
    return history, gr.MultimodalTextbox(value=None)

def build_demo():
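    """Assemble the Gradio Blocks UI: chatbot, multimodal input, send/clear buttons, and examples."""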
    with gr.Blocks() as demo:
        gr.Markdown("# SpaceThinker-Qwen2.5VL-3B Image Prompt Chatbot")

        chatbot = gr.Chatbot([], type="messages", line_breaks=True)

        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter text and upload an image.",
            show_label=True
        )

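        # Submitting the textbox first records the message, then runs inference on the updated history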
        submit_event = chat_input.submit(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        submit_event.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        send_click = send_button.click(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        send_click.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        gr.Examples(
            examples=[
                {
                    "text": "Give me the height of the man in the red hat in feet.",
                    "files": ["./examples/warehouse_rgb.jpg"]
                }
            ],
            inputs=[chat_input],
        )

    return demo

if __name__ == "__main__":
    demo = build_demo()
    demo.launch(share=True)