Commit · 9571b72
1 Parent(s): 021d6bd
add support for all three steps
app.py CHANGED
@@ -16,8 +16,8 @@ auth_token = os.environ.get("TOKEN_FROM_SECRET")
 # LLM part
 ##########################################
 import torch
-from transformers import AutoProcessor, AutoTokenizer
-from transformers import
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
+from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
 from qwen_vl_utils import process_vision_info
 from threading import Thread

@@ -36,7 +36,7 @@ processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
 tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

 mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(MLLM_MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
-llm =
+llm = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")

 mllm_sampling = dict(do_sample=False, temperature=0, max_new_tokens=8192)
 llm_sampling = dict(temperature=0.6, top_p=0.95, max_new_tokens=8192)
@@ -52,25 +52,6 @@ def build_messages(image_path, question):
     ]
     return cap_msgs, qa_msgs

-# === Run Captioning and QA ===
-def run_mllm_tentative(image_tensor, cap_prompt, qa_prompt):
-    qa_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": qa_prompt[0]}], sampling_params=mllm_sampling)
-    return qa_output[0].outputs[0].text
-
-def run_mllm_caption(image_tensor, cap_prompt, qa_prompt):
-    cap_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": cap_prompt[0]}], sampling_params=mllm_sampling)
-    return cap_output[0].outputs[0].text
-
-# === Final Reasoning Step ===
-def run_llm_reasoning(caption, question, answer):
-    messages = [
-        {"role": "system", "content": SYSTEM_PROMPT_LLM},
-        {"role": "user", "content": LLM_PROMPT.format(caption, question, answer)}
-    ]
-    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    output = llm.generate([{"prompt": prompt}], sampling_params=llm_sampling)
-    return output[0].outputs[0].text
-
 ##########################################
 # Streaming
 ##########################################
@@ -204,18 +185,31 @@ def http_bot(state):
     logging.info(f"Query-conditioned Caption: {caption_text}")
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
-
-
-
-
-
-
-
-
-
-
-
-
+
+    # Step 3: Text-only Reasoning
+    reason_msgs = [
+        {"role": "system", "content": SYSTEM_PROMPT_LLM},
+        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)}
+    ]
+    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
+    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)
+
+    state.append_message(state.roles[1], "# Text-only Reasoning\n\n▌")
+    try:
+        for generated_text in stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling):
+            output = generated_text[len(reason_prompt):].strip()
+            state.messages[-1][-1] = "# Text-only Reasoning\n\n" + output + "▌"
+            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
+    except Exception as e:
+        os.system("nvidia-smi")
+        logging.info(traceback.print_exc())
+        state.messages[-1][-1] = server_error_msg
+        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
+        return
+    final_response = output
+    logging.info(f"Text-only Reasoning: {final_response}")
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2

 ############
 # Layout Markdown