import os
import argparse
import traceback
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logging.getLogger("http").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

import spaces
import gradio as gr
from conversation_public import default_conversation

auth_token = os.environ.get("TOKEN_FROM_SECRET")

##########################################
# LLM part
##########################################
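# RACRO pipeline: the multi-modal model (MLLM) produces a tentative answer and a
# query-conditioned caption for the input image; a text-only reasoner (LLM) then
# answers the question from the caption and the tentative answer alone.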
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread

# === Prompts ===
SYSTEM_PROMPT_LLM = "You are a helpful assistant."
SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in detail. Do not try to answer the question."

CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. Your goal is to answer the question using this information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on); in such a case, please ignore it and give your own solution. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""

# === Initialize Models ===
MLLM_MODEL_PATH = "KaiChen1998/RACRO-7B-CRO-GRPO"
LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(MLLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")
llm = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")

mllm_sampling = dict(do_sample=False, max_new_tokens=8192)
llm_sampling = dict(do_sample=True, temperature=0.6, top_p=0.95, max_new_tokens=8192)

# === Build Prompts ===
def build_messages(image_path, question):
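    """Build the chat messages for the two MLLM passes: a query-conditioned
    caption request (cap_msgs) and a direct QA request (qa_msgs)."""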
    cap_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_CAP},
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": CAPTION_PROMPT.format(question)}]}
    ]
    qa_msgs = [
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": question + " Please think step by step. The final answer MUST BE put in \\boxed{}."}]}
    ]
    return cap_msgs, qa_msgs

##########################################
# Streaming
##########################################
mllm_streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
llm_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

def stream_response(model, inputs, streamer, prompt, gen_kwargs):
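    """Run model.generate in a background thread and yield the accumulated text
    (prompt plus newly generated tokens) as the streamer produces them."""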
    thread = Thread(target=model.generate, kwargs=dict(
        streamer=streamer,
        **inputs,
        **gen_kwargs
        )
    )
    thread.start()

    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

##########################################
# Gradio part
##########################################
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"

def load_demo_refresh_model_list():
    logging.info(f"load_demo.")
    state = default_conversation.copy()
    return state

def regenerate(state, image_process_mode):
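    """Drop the three assistant messages from the last turn so http_bot can
    re-run on the remaining user message."""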
    logging.info(f"regenerate.")
    state.messages = state.messages[:-3]
    prev_human_msg = state.messages[-1]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
    state.skip_next = False
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

def clear_history():
    logging.info(f"clear_history.")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

############
# Show prompt in the chatbot
# Input: [state, textbox, imagebox, image_process_mode]
# Return: [state, chatbot, textbox, imagebox] + btn_list
############
def add_text(state, text, image, image_process_mode):
    # Input legality checking
    logging.info(f"add_text. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2
    
    # Deal with image inputs
    if image is not None:
        text = (text, image, image_process_mode, None)
    
    # Single round only
    state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.skip_next = False
    logging.info(str(state.messages))
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

############
# Get response
# Input: [state]
# Return: [state, chatbot] + btn_list
############
@spaces.GPU
def http_bot(state):
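    """Three-step RACRO inference: (1) stream a tentative answer from the MLLM,
    (2) stream a query-conditioned caption from the MLLM, (3) stream the final
    text-only reasoning from the LLM over the caption and tentative answer."""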
    logging.info(f"http_bot.")

    if state.skip_next:
        yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
        return

    # Retrieve prompt
    prompt = state.messages[-1][-1][0]
    all_images = state.get_images(return_pil=True)[0]
    pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
    logging.info(f"==== request ====\n{pload}")
    
    # Construct prompt
    cap_msgs, qa_msgs = build_messages(all_images, prompt)
    cap_prompt = processor.apply_chat_template(cap_msgs, tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template(qa_msgs, tokenize=False, add_generation_prompt=True)
    image_tensor, _ = process_vision_info(cap_msgs)
    cap_inputs = processor(text=cap_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)
    qa_inputs = processor(text=qa_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)    
    
    # Step 1: Tentative Response
    state.append_message(state.roles[1], "# Tentative Response\n\nβ–Œ")
    try:
        for generated_text in stream_response(mllm, qa_inputs, mllm_streamer, qa_prompt, mllm_sampling):
            output = generated_text[len(qa_prompt):].strip()
            state.messages[-1][-1] = "# Tentative Response\n\n" + output + "β–Œ"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    tentative_answer = output
    logging.info(f"Tentative Response: {tentative_answer}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    
    # Step 2: Query-conditioned Caption
    state.append_message(state.roles[1], "# Query-conditioned Caption\n\nβ–Œ")
    try:
        for generated_text in stream_response(mllm, cap_inputs, mllm_streamer, cap_prompt, mllm_sampling):
            output = generated_text[len(cap_prompt):].strip()
            state.messages[-1][-1] = "# Query-conditioned Caption\n\n" + output + "β–Œ"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    caption_text = output
    logging.info(f"Query-conditioned Caption: {caption_text}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    
    # Step 3: Text-only Reasoning
    reason_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_LLM},
        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)}
    ]
    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)
    
    state.append_message(state.roles[1], "# Text-only Reasoning\n\nβ–Œ")
    try:
        for generated_text in stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling):
            output = generated_text[len(reason_prompt):].strip()
            state.messages[-1][-1] = "# Text-only Reasoning\n\n" + output + "β–Œ"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    final_response = output
    logging.info(f"Text-only Reasoning: {final_response}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2

############
# Layout Markdown
############
title_markdown = ("""
<div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
  <div>
    <h1 style="margin: 0;">RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning</h1>
    <h2 style="margin: 10px 0;">πŸ“ƒ <a href="https://www.arxiv.org/abs/2506.04559" style="font-weight: 400;">Paper</a> | πŸ’» <a href="https://github.com/gyhdog99/RACRO2" style="font-weight: 400;">Code</a> | πŸ€— <a href="https://huggingface.co/collections/KaiChen1998/racro-6848ec8c65b3a0bf33d0fbdb" style="font-weight: 400;">HuggingFace</a></h2>
    <p  style="margin: 20px 0;">
      <strong>1. RACRO is designed for multi-modal reasoning, and thus, image inputs are <mark>ALWAYS</mark> necessary!</strong>
    </p>
  </div>
</div>
""")

learn_more_markdown = ("""
## Citation
<pre><code>@article{gou2025perceptual,
  author    = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu}, 
  title     = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
  journal   = {arXiv preprint arXiv:2506.04559},
  year      = {2025},
}</code></pre>
""")

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
.message-row img {
    margin: 0px !important;
}
.avatar-container img {
    padding: 0px !important;
}
"""

############
# Layout Demo
############
def build_demo(embed_mode):
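    """Assemble the Gradio Blocks UI: image and text inputs, the chatbot, and
    the Chat / Regenerate / Clear controls wired to add_text and http_bot."""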
    textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click πŸ’¬ Chat to talk with me ^v^", container=False)
    with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.HTML(title_markdown)

        ##############
        # Chatbot
        ##############
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                imagebox = gr.Image(type="pil", label="Image")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                gr.Examples(examples=[
                    ["./examples/image-text/demo_example.jpg", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"],
                ], inputs=[imagebox, textbox], label='Examples')

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    type="messages",
                    elem_id="chatbot",
                    label="RACRO Chatbot",
                    layout="bubble",
                    avatar_images=["examples/user_avator.png", "examples/icon_256.png"]
                )
                textbox.render()
                with gr.Row(elem_id="buttons") as button_row:
                    submit_btn = gr.Button(value="πŸ’¬  Chat", variant="primary")
                    # stop_btn = gr.Button(value="⏹️  Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="πŸ”„  Regenerate", interactive=False)
                    clear_btn = gr.Button(value="πŸ—‘οΈ  Clear", interactive=False)
        
        if not embed_mode:
            gr.Markdown(learn_more_markdown)

        # Register listeners
        btn_list = [regenerate_btn, clear_btn]
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        # Pressing Enter in the textbox triggers the same pipeline as the Chat button
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state],
            [state, chatbot] + btn_list,
        )

        ##############
        # Demo loading
        ##############
        demo.load(
            load_demo_refresh_model_list,
            None,
            [state],
            queue=False
        )
    return demo


parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()

demo = build_demo(args.embed)
demo.queue(
    max_size=10,
    api_open=False
).launch(
    favicon_path="./examples/icon_256.png",
    allowed_paths=["/"],
    share=args.share
)