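"""Gradio demo for RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning
via Reward-Optimized Captioning.

Given an image and a question, the demo runs a three-stage pipeline:
  1. Qwen2.5-VL produces a tentative answer to the question.
  2. Qwen2.5-VL produces a question-conditioned caption of the image.
  3. DeepSeek-R1-Distill-Qwen reasons over the caption and the tentative answer
     to produce the final response.
Both models are served locally with vLLM.
"""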
import os
import argparse
import traceback
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

import spaces
import gradio as gr
from conversation_public import default_conversation

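# Optional access token read from the Space secret (not used elsewhere in this script).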
auth_token = os.environ.get("TOKEN_FROM_SECRET")

##########################################
# LLM part
##########################################
from transformers import AutoProcessor, AutoTokenizer, TextIteratorStreamer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
from threading import Thread

# === Prompts ===
SYSTEM_PROMPT_LLM = "You are a helpful assistant."
SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in detail. Do not try to answer the question."

CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. Your goal is to answer the question using this information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on), in such a case, please ignore it and give your own solution. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""
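
# Together, these prompts implement the RACRO decoupling: the MLLM supplies a
# question-conditioned caption and a tentative answer, and a text-only reasoner
# combines the two to produce the final response.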

# === Initialize Models ===
MLLM_MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"
LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

mllm = LLM(model=MLLM_MODEL_PATH, tensor_parallel_size=1, gpu_memory_utilization=0.8,
           device='cuda:0', dtype="bfloat16", limit_mm_per_prompt={"image": 1})

llm = LLM(model=LLM_MODEL_PATH, tensor_parallel_size=1, gpu_memory_utilization=0.8,
          device='cuda:0', dtype="bfloat16")

mllm_sampling = SamplingParams(temperature=0, max_tokens=8192)
llm_sampling = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=8192)
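# Greedy decoding for the MLLM stages; sampled decoding (T=0.6, top-p=0.95) for the text-only reasoner.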

# === Build Prompts ===
def build_messages(image_path, question):
    cap_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_CAP},
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": CAPTION_PROMPT.format(question)}]}
    ]
    qa_msgs = [
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": question + " Please think step by step. The final answer MUST BE put in \\boxed{}."}]}
    ]
    return cap_msgs, qa_msgs

# === Run Captioning and QA ===
def run_mllm_tentative(image_tensor, cap_prompt, qa_prompt):
    qa_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": qa_prompt[0]}], sampling_params=mllm_sampling)
    return qa_output[0].outputs[0].text

def run_mllm_caption(image_tensor, cap_prompt, qa_prompt):
    cap_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": cap_prompt[0]}], sampling_params=mllm_sampling)
    return cap_output[0].outputs[0].text

# === Final Reasoning Step ===
def run_llm_reasoning(caption, question, answer):
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_LLM},
        {"role": "user", "content": LLM_PROMPT.format(caption, question, answer)}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate([{"prompt": prompt}], sampling_params=llm_sampling)
    return output[0].outputs[0].text
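
# Illustrative end-to-end usage (mirrors the http_bot handler below; not executed here):
#   cap_msgs, qa_msgs = build_messages(image, question)
#   cap_prompt = processor.apply_chat_template([cap_msgs], tokenize=False, add_generation_prompt=True)
#   qa_prompt = processor.apply_chat_template([qa_msgs], tokenize=False, add_generation_prompt=True)
#   image_tensor, _ = process_vision_info(cap_msgs)
#   tentative = run_mllm_tentative(image_tensor, cap_prompt, qa_prompt)
#   caption = run_mllm_caption(image_tensor, cap_prompt, qa_prompt)
#   final = run_llm_reasoning(caption, question, tentative)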

##########################################
# Gradio part
##########################################
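# Sentinel button components: returned from callbacks to toggle the interactivity of the Regenerate/Clear buttons.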
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"

def load_demo_refresh_model_list():
    logging.info(f"load_demo.")
    state = default_conversation.copy()
    return state

def regenerate(state, image_process_mode):
    logging.info(f"regenerate.")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
    state.skip_next = False
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

def clear_history():
    logging.info(f"clear_history.")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

############
# Show prompt in the chatbot
# Input: [state, textbox, imagebox, image_process_mode]
# Return: [state, chatbot, textbox, imagebox] + btn_list
############
def add_text(state, text, image, image_process_mode):
    # Input legality checking
    logging.info(f"add_text. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2
    
    # Deal with image inputs
    if image is not None:
        text = (text, image, image_process_mode, None)
    
    # Single round only
    state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.skip_next = False
    logging.info(str(state.messages))
    return (state, state.to_gradio_chatbot_public(), "") + (disable_btn,) * 2

############
# Get response
# Input: [state]
# Return: [state, chatbot] + btn_list
############
@spaces.GPU
def http_bot(state):
    logging.info(f"http_bot.")

    if state.skip_next:
        yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
        return

    # Retrieve the user prompt (the text part of the latest user message)
    prompt = state.messages[-1][-1][0]
    all_images = state.get_images(return_pil=True)[0]
    pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
    logging.info(f"==== request ====\n{pload}")
    
    # Construct prompt
    cap_msgs, qa_msgs = build_messages(all_images, prompt)
    cap_prompt = processor.apply_chat_template([cap_msgs], tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template([qa_msgs], tokenize=False, add_generation_prompt=True)

    image_tensor, _ = process_vision_info(cap_msgs)
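    # Stage 1: tentative answer produced directly by the MLLM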
    tentative_answer = run_mllm_tentative(image_tensor, cap_prompt, qa_prompt)
    state.append_message(state.roles[1], "# Tentative Response\n\n" + tentative_answer)
    logging.info("# Tentative Response\n\n" + tentative_answer)
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    
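    # Stage 2: question-conditioned caption from the MLLM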
    caption_text = run_mllm_caption(image_tensor, cap_prompt, qa_prompt)
    state.append_message(state.roles[1], "# Caption\n\n" + caption_text)
    logging.info("# Caption\n\n" + caption_text)
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
        
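    # Stage 3: final answer from the text-only reasoner, given the caption and the tentative answer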
    final_answer = run_llm_reasoning(caption_text, prompt, tentative_answer)
    state.append_message(state.roles[1], "# Final Response\n\n" + final_answer)
    logging.info("# Final Response\n\n" + final_answer)
    yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2

############
# Layout Markdown
############
title_markdown = ("""
<div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
  <div>
    <h1 style="margin: 0;">RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning</h1>
    <h2 style="margin: 10px 0;">πŸ“ƒ <a href="https://www.arxiv.org/abs/2506.04559" style="font-weight: 400;">Paper</a> | πŸ’» <a href="https://github.com/gyhdog99/RACRO2" style="font-weight: 400;">Code</a> | πŸ€— <a href="https://huggingface.co/collections/KaiChen1998/racro-6848ec8c65b3a0bf33d0fbdb" style="font-weight: 400;">HuggingFace</a></h2>
    <p  style="margin: 20px 0;">
      <strong>1. RACRO is designed for multi-modal reasoning, and thus, image inputs are <mark>ALWAYS</mark> necessary!</strong><br/>
      <strong>2. Models are deployed with vLLM, which, unfortunately, does not yet support streaming outputs for MLLMs.</strong>
    </p>
  </div>
</div>
""")

learn_more_markdown = ("""
## Citation
<pre><code>@article{gou2025perceptual,
  author    = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu}, 
  title     = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
  journal   = {arXiv preprint arXiv:2506.04559},
  year      = {2025},
}</code></pre>
""")

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
.message-row img {
    margin: 0px !important;
}
.avatar-container img {
    padding: 0px !important;
}
"""

############
# Layout Demo
############
def build_demo(embed_mode):
    textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click πŸ’¬ Chat to talk with me ^v^", container=False)
    with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.HTML(title_markdown)

        ##############
        # Chatbot
        ##############
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                imagebox = gr.Image(type="pil", label="Image")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                gr.Examples(examples=[
                    ["./examples/demo_example.png", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"],
                ], inputs=[imagebox, textbox], label='Examples')

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="RACRO Chatbot",
                    layout="bubble",
                    avatar_images=["examples/user_avator.png", "examples/icon_256.png"]
                )
                textbox.render()
                with gr.Row(elem_id="buttons") as button_row:
                    submit_btn = gr.Button(value="πŸ’¬  Chat", variant="primary")
                    # stop_btn = gr.Button(value="⏹️  Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="πŸ”„  Regenerate", interactive=False)
                    clear_btn = gr.Button(value="πŸ—‘οΈ  Clear", interactive=False)
        
        if not embed_mode:
            gr.Markdown(learn_more_markdown)

        # Register listeners
        btn_list = [regenerate_btn, clear_btn]
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        # Pressing Enter in the textbox triggers the same flow as the Chat button
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        ##############
        # Demo loading
        ##############
        demo.load(
            load_demo_refresh_model_list,
            None,
            [state],
            queue=False
        )
    return demo


parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()

demo = build_demo(args.embed)
demo.queue(
    max_size=10,
    api_open=False
).launch(
    favicon_path="./examples/icon_256.png",
    allowed_paths=["/"],
    share=args.share
)