import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = "remyxai/SpaceQwen2.5-VL-3B-Instruct"

@spaces.GPU
def load_model():
    """Load the Qwen2.5-VL model and processor, using bfloat16 on GPU when available."""
    print("Loading model and processor...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

# Load once at startup so every request reuses the same weights.
model, processor = load_model()

def process_image(image_path_or_obj):
    """Load a file path or PIL image, convert it to RGB, and cap its width at 512px."""
    if isinstance(image_path_or_obj, str):
        # Path on disk or from history
        image = Image.open(image_path_or_obj).convert("RGB")
    elif isinstance(image_path_or_obj, Image.Image):
        image = image_path_or_obj.convert("RGB")
    else:
        raise ValueError("process_image expects a file path (str) or PIL.Image")

    max_width = 512
    if image.width > max_width:
        aspect_ratio = image.height / image.width
        new_height = int(max_width * aspect_ratio)
        image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resized image to: {max_width}x{new_height}")
    return image
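
# A quick check of the resize math above (sizes are illustrative): a 2048x1536
# input has aspect_ratio 1536/2048 = 0.75, so it comes out 512x384, while a
# 400x300 input is already under max_width and passes through unchanged.
#
#   thumb = process_image(Image.new("RGB", (2048, 1536)))
#   assert thumb.size == (512, 384)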

def get_latest_image(history):
    """
    Scan history from the end for the most recent user-uploaded image,
    which is stored as a (file_path,) tuple. Return None if no image exists.
    """
    for user_msg, _assistant_msg in reversed(history):
        if isinstance(user_msg, tuple) and len(user_msg) > 0:
            return user_msg[0]
    return None
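
# Sketch of the history layout this helper walks (paths are illustrative):
#
#   history = [
#       [("/tmp/gradio/warehouse.png",), None],     # image upload -> tuple entry
#       ["How tall is the man?", "About 6 feet."],  # text turn with its reply
#   ]
#   get_latest_image(history)  # -> "/tmp/gradio/warehouse.png"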

def only_assistant_text(full_text: str) -> str:
    """
    Return only the final assistant answer from the decoded transcript.
    The decoded output contains the full chat (system, user, and assistant
    turns separated by role markers), so we split on the 'assistant' marker
    and keep everything after it. Adjust this parsing if your model's chat
    template formats output differently.
    """
    if "assistant" in full_text:
        parts = full_text.split("assistant", 1)
        result = parts[-1].strip()
        # Remove any leading punctuation (such as a colon) left by the marker
        result = result.lstrip(":").strip()
        return result
    return full_text.strip()
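
# Example of the parsing above on a typical decoded transcript:
#
#   only_assistant_text("system\n...\nuser\n...\nassistant\nThe box is red.")
#   # -> "The box is red."
#
# Caveat: the split fires on the first occurrence of "assistant", so a user
# prompt containing that word would leak the tail of the prompt into the
# result. The fixed role order emitted by the chat template makes this rare.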

def run_inference(image, prompt):
    """Runs Qwen2.5-VL inference on a single image and text prompt."""
    system_msg = (
        "You are a Vision Language Model specialized in interpreting visual data from images. "
        "Your task is to analyze the provided image and respond to queries with concise answers."
    )
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_msg}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_input = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Parse out only the final assistant text
    return only_assistant_text(output_text)
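
# An alternative to post-hoc string parsing (a sketch, not what this app uses):
# slice off the prompt tokens before decoding so only newly generated text
# remains, which avoids splitting on role markers entirely.
#
#   trimmed = generated_ids[:, inputs.input_ids.shape[1]:]
#   reply = processor.batch_decode(trimmed, skip_special_tokens=True)[0]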

def add_message(history, user_input):
    """
    Step 1 (triggered by user's 'Submit' or 'Send'):
    - Save new text or images into `history`.
    - The Chatbot display uses pairs: [user_text_or_image, assistant_reply].
    """
    if not isinstance(history, list):
        history = []

    files = user_input.get("files", [])
    text = user_input.get("text", "")

    # Store images
    for f in files:
        # Each image is stored as `[(file_path,), None]`
        history.append([(f,), None])

    # Store text
    if text:
        history.append([text, None])

    return history, gr.MultimodalTextbox(value=None)
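
# The MultimodalTextbox submits a dict payload, roughly (values illustrative):
#
#   {"text": "Give me the height of the man.", "files": ["/tmp/gradio/img.png"]}
#
# so one submission with both an image and text yields two history rows: the
# image tuple first, then the text entry awaiting its assistant reply.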

def inference_interface(history):
    """
    Step 2: Use the most recent text + the most recent image to run Qwen2.5-VL.
    Instead of adding another entry, we fill the assistant's answer into
    the last user text entry.
    """
    if not history:
        return history, gr.MultimodalTextbox(value=None)

    # 1) Get the user's most recent text
    user_text = ""
    # We'll search from the end for the first str we find
    for idx in range(len(history) - 1, -1, -1):
        user_msg, assistant_msg = history[idx]
        if isinstance(user_msg, str):
            user_text = user_msg
            # We'll also keep track of this index so we can fill in the assistant reply
            user_idx = idx
            break
    else:
        # No user text found
        print("No user text found in history. Skipping inference.")
        return history, gr.MultimodalTextbox(value=None)

    # 2) Get the latest image from the entire conversation
    latest_image = get_latest_image(history)
    if not latest_image:
        # No image found => can't run the model
        print("No image found in history. Skipping inference.")
        return history, gr.MultimodalTextbox(value=None)

    # 3) Process the image
    pil_image = process_image(latest_image)

    # 4) Run inference
    assistant_reply = run_inference(pil_image, user_text)

    # 5) Fill that assistant reply back into the last user text entry
    history[user_idx][1] = assistant_reply
    return history, gr.MultimodalTextbox(value=None)
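
# Tracing one full turn through add_message + inference_interface
# (values illustrative):
#
#   history = [[("img.png",), None], ["What color is the crate?", None]]
#   history, _ = inference_interface(history)
#   # history[-1] == ["What color is the crate?", "<model's answer>"]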

def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# SpaceQwen2.5-VL Image Prompt Chatbot")

        chatbot = gr.Chatbot([], line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter text or upload an image (or both).",
            show_label=True
        )

        # When the user presses Enter in the MultimodalTextbox:
        submit_event = chat_input.submit(
            fn=add_message,  # Step 1: store user data
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        # After storing, run inference
        submit_event.then(
            fn=inference_interface,  # Step 2: run Qwen2.5-VL
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        # Same logic for a "Send" button
        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        send_click = send_button.click(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        send_click.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        # Example
        gr.Examples(
            examples=[
                {
                    "text": "Give me the height of the man in the red hat in feet.",
                    "files": ["./examples/warehouse_rgb.jpg"]
                }
            ],
            inputs=[chat_input],
        )

    return demo

if __name__ == "__main__":
    demo = build_demo()
    demo.launch(share=True)