KaiChen1998 committed
Commit 9571b72 · 1 Parent(s): 021d6bd

add support for all three steps

Files changed (1)
  1. app.py +28 -34
app.py CHANGED
@@ -16,8 +16,8 @@ auth_token = os.environ.get("TOKEN_FROM_SECRET")
 # LLM part
 ##########################################
 import torch
-from transformers import AutoProcessor, AutoTokenizer
-from transformers import Qwen2ForCausalLM, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
+from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
 from qwen_vl_utils import process_vision_info
 from threading import Thread
 
@@ -36,7 +36,7 @@ processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
 tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
 
 mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(MLLM_MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
-llm = Qwen2ForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
+llm = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
 
 mllm_sampling = dict(do_sample=False, temperature=0, max_new_tokens=8192)
 llm_sampling = dict(temperature=0.6, top_p=0.95, max_new_tokens=8192)
@@ -52,25 +52,6 @@ def build_messages(image_path, question):
     ]
     return cap_msgs, qa_msgs
 
-# === Run Captioning and QA ===
-def run_mllm_tentative(image_tensor, cap_prompt, qa_prompt):
-    qa_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": qa_prompt[0]}], sampling_params=mllm_sampling)
-    return qa_output[0].outputs[0].text
-
-def run_mllm_caption(image_tensor, cap_prompt, qa_prompt):
-    cap_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": cap_prompt[0]}], sampling_params=mllm_sampling)
-    return cap_output[0].outputs[0].text
-
-# === Final Reasoning Step ===
-def run_llm_reasoning(caption, question, answer):
-    messages = [
-        {"role": "system", "content": SYSTEM_PROMPT_LLM},
-        {"role": "user", "content": LLM_PROMPT.format(caption, question, answer)}
-    ]
-    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    output = llm.generate([{"prompt": prompt}], sampling_params=llm_sampling)
-    return output[0].outputs[0].text
-
 ##########################################
 # Streaming
 ##########################################
@@ -204,18 +185,31 @@ def http_bot(state):
     logging.info(f"Query-conditioned Caption: {caption_text}")
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
-
-
-
-    # caption_text = run_mllm_caption(image_tensor, cap_prompt, qa_prompt)
-    # state.append_message(state.roles[1], "# Caption\n\n" + caption_text)
-    # logging.info("# Caption\n\n" + caption_text)
-    # yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
-
-    # final_answer = run_llm_reasoning(caption_text, QUESTION, tentative_answer)
-    # state.append_message(state.roles[1], "# Final Response\n\n" + final_answer)
-    # logging.info("# Final Response\n\n" + final_answer)
-    # yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
+
+    # Step 3: Text-only Reasoning
+    reason_msgs = [
+        {"role": "system", "content": SYSTEM_PROMPT_LLM},
+        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)}
+    ]
+    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
+    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)
+
+    state.append_message(state.roles[1], "# Text-only Reasoning\n\n")
+    try:
+        for generated_text in stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling):
+            output = generated_text[len(reason_prompt):].strip()
+            state.messages[-1][-1] = "# Text-only Reasoning\n\n" + output + "▌"
+            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
+    except Exception as e:
+        os.system("nvidia-smi")
+        logging.info(traceback.print_exc())
+        state.messages[-1][-1] = server_error_msg
+        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
+        return
+    final_response = output
+    logging.info(f"Text-only Reasoning: {final_response}")
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
 
 ############
 # Layout Markdown
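
The new Step 3 block streams tokens through a `stream_response` helper and an `llm_streamer` that live elsewhere in app.py and are not part of this diff. A minimal sketch of what such a helper could look like with the `TextIteratorStreamer` imported above; the name, signature, and `skip_prompt` handling are assumptions chosen to match the call site, not the file's actual implementation:

```python
# Rough sketch of a stream_response helper matching the call
#   stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling)
# in the diff above. Assumptions (not shown in this commit): `streamer` is a
# transformers.TextIteratorStreamer built with skip_prompt=True, and
# `sampling_kwargs` is the llm_sampling dict (temperature, top_p, max_new_tokens).
from threading import Thread

def stream_response(model, inputs, streamer, prompt, sampling_kwargs):
    # Run generate() in a background thread; the streamer yields decoded text
    # chunks as they are produced.
    generation_kwargs = dict(**inputs, streamer=streamer, **sampling_kwargs)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate prompt + continuation so the caller can slice off the first
    # len(reason_prompt) characters, as http_bot does in the diff.
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()
```

Compared with the removed `run_llm_reasoning`, whose `generate(..., sampling_params=...)` call returned the full completion at once, routing generation through a streamer lets `http_bot` yield partial "Text-only Reasoning" output to the Gradio chatbot while the model is still decoding.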