Update src/txagent/txagent.py

src/txagent/txagent.py  (+52 −139)  CHANGED
@@ -23,11 +23,11 @@ class TxAgent:
     def __init__(self, model_name,
                  rag_model_name,
                  tool_files_dict=None,
-                 enable_finish=True,
-                 enable_rag=True,
+                 enable_finish=False,  # MODIFIED: Default to False
+                 enable_rag=False,
                  enable_summary=False,
                  init_rag_num=0,
-                 step_rag_num=10,
+                 step_rag_num=0,
                  summary_mode='step',
                  summary_skip_last_k=0,
                  summary_context_length=None,
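
Note (outside the diff): with these defaults flipped off, tool-driven finishing and RAG retrieval become opt-in. A hypothetical construction sketch; the argument names come from the diff above, while the two checkpoint ids are assumed to be the public TxAgent weights on Hugging Face:

    # Hypothetical usage sketch; other constructor arguments keep the defaults in this file.
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # assumed checkpoint id
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",  # assumed checkpoint id
        enable_finish=True,   # opt back in; the new default is False
        enable_rag=True,      # opt back in; the new default is False
        step_rag_num=10,      # restore per-step tool retrieval (the new default 0 disables it)
    )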
@@ -79,7 +79,7 @@ class TxAgent:
         if model_name == self.model_name:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
-        self.model = LLM(model=self.model_name, enforce_eager=True)
+        self.model = LLM(model=self.model_name, enforce_eager=True, max_model_len=4096)  # MODIFIED: Reduce KV cache
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         return f"Model {model_name} loaded successfully."
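
Note (outside the diff): max_model_len caps the sequence length vLLM provisions for, which bounds the KV cache it pre-allocates at startup; that is what the "Reduce KV cache" comment refers to. A minimal standalone sketch, assuming vLLM is installed and using the same (assumed) checkpoint id as above:

    from vllm import LLM, SamplingParams

    # KV-cache memory grows with the maximum sequence length, so capping it at 4096
    # lets the same model load on a smaller GPU.
    llm = LLM(
        model="mims-harvard/TxAgent-T1-Llama-3.1-8B",  # assumed checkpoint id
        enforce_eager=True,   # skip CUDA-graph capture: lower memory, some speed cost
        max_model_len=4096,   # bound on prompt + generated tokens per request
    )
    out = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=16))
    print(out[0].outputs[0].text)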
@@ -165,16 +165,16 @@ class TxAgent:
     def add_special_tools(self, tools, call_agent=False):
         if not self.enable_rag and not self.enable_finish:
             return tools
-        if self.enable_finish and self.tooluniverse:
+        if self.enable_finish and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
             logger.info("Finish tool is added")
-        if call_agent and self.tooluniverse:
+        if call_agent and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
             logger.info("CallAgent tool is added")
-        elif self.enable_rag and self.tooluniverse:
+        elif self.enable_rag and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
             logger.info("Tool_RAG tool is added")
-        if self.additional_default_tools is not None and self.tooluniverse:
+        if self.additional_default_tools is not None and self.tooluniverse:
             for each_tool_name in self.additional_default_tools:
                 tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
                 if tool_prompt is not None:
@@ -183,7 +183,7 @@ class TxAgent:
         return tools

     def add_finish_tools(self, tools):
-        if not self.enable_finish or not self.tooluniverse:
+        if not self.enable_finish or not self.tooluniverse:
             return tools
         tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
         logger.info("Finish tool is added")
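
Note (outside the diff): a consequence of these guards is that both helpers degrade to no-ops when no ToolUniverse is attached, so callers need no None checks of their own. A hypothetical check, assuming an already-constructed agent:

    agent.tooluniverse = None                  # e.g. tool files were never loaded
    assert agent.add_special_tools([]) == []   # no Finish/CallAgent/Tool_RAG appended
    assert agent.add_finish_tools([]) == []    # list returned unchanged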
@@ -346,10 +346,9 @@ class TxAgent:
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
         if conversation[-1]['role'] == 'assistant':
             conversation.append({'role': 'tool', 'content': 'Errors occurred, provide final answer.'})
-        finish_tools_prompt = self.add_finish_tools([]) if self.enable_finish else []
         last_outputs_str = self.llm_infer(messages=conversation,
                                           temperature=temperature,
-                                          tools=finish_tools_prompt,
+                                          tools=[],
                                           output_begin_string='[FinalAnswer]',
                                           skip_special_tokens=True,
                                           max_new_tokens=max_new_tokens, max_token=max_token)
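
Note (outside the diff): passing tools=[] strips the tool schemas from the prompt, while output_begin_string='[FinalAnswer]' seeds the generation so the model must answer instead of planning another tool call. A sketch of that seeding idea under stated assumptions (llm_infer's internals are not shown in this diff; force_final_answer is a hypothetical helper):

    from jinja2 import Template
    from vllm import LLM, SamplingParams

    def force_final_answer(llm: LLM, chat_template: Template, conversation: list,
                           temperature: float, max_new_tokens: int) -> str:
        # Render the chat without any tool schemas, then append the marker so the
        # model continues from '[FinalAnswer]' rather than emitting a tool call.
        prompt = chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
        prompt += "[FinalAnswer]"
        params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens)
        text = llm.generate([prompt], params)[0].outputs[0].text
        return "[FinalAnswer]" + text  # re-attach the marker for downstream parsing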
@@ -654,131 +653,45 @@ Generate one summarized sentence about "function calls' responses" with necessary
             return "Invalid input."
         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
             return ""
-        ...
-                    conversation, status=last_status, enable_summary=enable_summary)
-                if function_call_messages:
-                    conversation.extend(function_call_messages)
-                    yield history
-                else:
-                    next_round = False
-                    conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
-                    return ''.join(last_outputs).replace("</s>", "")
-                if self.enable_checker:
-                    good_status, wrong_info = checker.check_conversation()
-                    if not good_status:
-                        logger.warning(f"Checker flagged reasoning error: {wrong_info}")
-                        break
-                last_outputs = []
-                last_outputs_str, token_overflow = self.llm_infer(
-                    messages=conversation,
-                    temperature=temperature,
-                    tools=picked_tools_prompt,
-                    skip_special_tokens=False,
-                    max_new_tokens=max_new_tokens,
-                    max_token=max_token,
-                    seed=seed,
-                    check_token_status=True)
-                logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}")
-                if last_outputs_str is None:
-                    logger.warning("Token overflow")
-                    if self.force_finish:
-                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                            conversation, temperature, max_new_tokens, max_token)
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                        return last_outputs_str
-                    error_msg = "Token limit exceeded."
-                    history.append(ChatMessage(role="assistant", content=error_msg))
-                    yield history
-                    return error_msg
-                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-                for msg in history:
-                    if msg.metadata is not None:
-                        msg.metadata['status'] = 'done'
-                if '[FinalAnswer]' in last_thought:
-                    parts = last_thought.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_thought))
-                    yield history
-                last_outputs.append(last_outputs_str)
-            if next_round:
-                if self.force_finish:
-                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                        conversation, temperature, max_new_tokens, max_token)
-                    if '[FinalAnswer]' in last_outputs_str:
-                        parts = last_outputs_str.split('[FinalAnswer]', 1)
-                        final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                        history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                        yield history
-                        history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                        yield history
-                    else:
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                else:
-                    yield "Reasoning rounds exceeded."
-        except Exception as e:
-            logger.error(f"Exception in run_gradio_chat: {e}")
-            error_msg = f"An error occurred: {e}"
-            history.append(ChatMessage(role="assistant", content=error_msg))
-            yield history
-            if self.force_finish:
-                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                    conversation, temperature, max_new_tokens, max_token)
-                if '[FinalAnswer]' in last_outputs_str:
-                    parts = last_outputs_str.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                    yield history
-            return error_msg
+        conversation = self.initialize_conversation(message, conversation, history=[])
+        sampling_params = SamplingParams(
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            seed=seed if seed is not None else self.seed,
+        )
+        prompt = self.chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
+        output = self.model.generate([prompt], sampling_params)[0].outputs[0].text  # MODIFIED: Direct inference
+        cleaned = clean_response(output)  # MODIFIED: Use clean_response
+        if '[FinalAnswer]' in cleaned:
+            parts = cleaned.split('[FinalAnswer]', 1)
+            final_answer = parts[1] if len(parts) > 1 else cleaned
+            history.append(ChatMessage(role="assistant", content=final_answer.strip()))
+        else:
+            history.append(ChatMessage(role="assistant", content=cleaned.strip()))
+        yield history
+        return cleaned
+
+
+def clean_response(text: str) -> str:  # MODIFIED: Add clean_response for compatibility
+    text = sanitize_utf8(text)
+    text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
+    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
+    text = re.sub(
+        r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
+        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
+        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
+        r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
+        "", text, flags=re.DOTALL
+    )
+    text = re.sub(r"\n{2,}", "\n", text).strip()
+    lines = []
+    valid_heading = False
+    for line in text.split("\n"):
+        line = line.strip()
+        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
+            valid_heading = True
+            lines.append(f"**{line[:-1]}**:")
+        elif valid_heading and line.startswith("-"):
+            lines.append(line)
+        else:
+            valid_heading = False
+    return "\n".join(lines).strip()
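
Note (outside the diff): a quick illustration of what the new clean_response keeps and drops, assuming the module-level re and sanitize_utf8 imports are in place; the input text is invented:

    sample = (
        "Analyze the labs first.\n"
        "Missed Diagnoses:\n"
        "- Elevated TSH of 6.2 suggests hypothyroidism\n"
    )
    print(clean_response(sample))
    # The reasoning line is stripped and the recognized heading is bolded:
    # **Missed Diagnoses**:
    # - Elevated TSH of 6.2 suggests hypothyroidism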