Update src/txagent/txagent.py
src/txagent/txagent.py  CHANGED  +52 -139
@@ -23,11 +23,11 @@ class TxAgent:
     def __init__(self, model_name,
                  rag_model_name,
                  tool_files_dict=None,
-                 enable_finish=True,
-                 enable_rag=True,
+                 enable_finish=False,  # MODIFIED: Default to False
+                 enable_rag=False,
                  enable_summary=False,
                  init_rag_num=0,
-                 step_rag_num=
+                 step_rag_num=0,
                  summary_mode='step',
                  summary_skip_last_k=0,
                  summary_context_length=None,
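With these defaults flipped, constructing the agent without flags now runs with RAG and the Finish tool disabled; callers who want the previous behavior must opt in explicitly. A minimal sketch (the import path and model names are assumptions, not taken from this diff):

    from txagent.txagent import TxAgent  # assuming this module path

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # placeholder
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",  # placeholder
        enable_finish=True,  # opt back in to the pre-change behavior
        enable_rag=True,
    )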
@@ -79,7 +79,7 @@ class TxAgent:
         if model_name == self.model_name:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
-        self.model = LLM(model=self.model_name, enforce_eager=True)
+        self.model = LLM(model=self.model_name, enforce_eager=True, max_model_len=4096)  # MODIFIED: Reduce KV cache
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         return f"Model {model_name} loaded successfully."
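For reference, `max_model_len` caps vLLM's context window, and with it the KV-cache blocks vLLM pre-allocates at startup, so the change trades long-prompt support for a smaller GPU footprint. A standalone sketch of the trade-off (the model name is a placeholder):

    from vllm import LLM, SamplingParams

    # Capping the context length shrinks the pre-allocated KV cache; prompts
    # longer than 4096 tokens will be rejected rather than truncated.
    llm = LLM(model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
              enforce_eager=True, max_model_len=4096)
    out = llm.generate(["What is aspirin used for?"], SamplingParams(max_tokens=64))
    print(out[0].outputs[0].text)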
@@ -165,16 +165,16 @@ class TxAgent:
     def add_special_tools(self, tools, call_agent=False):
         if not self.enable_rag and not self.enable_finish:
             return tools
-        if self.enable_finish:
+        if self.enable_finish and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
             logger.info("Finish tool is added")
-        if call_agent:
+        if call_agent and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
             logger.info("CallAgent tool is added")
-        elif self.enable_rag:
+        elif self.enable_rag and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
             logger.info("Tool_RAG tool is added")
-        if self.additional_default_tools is not None:
+        if self.additional_default_tools is not None and self.tooluniverse:
             for each_tool_name in self.additional_default_tools:
                 tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
                 if tool_prompt is not None:
@@ -183,7 +183,7 @@ class TxAgent:
         return tools
 
     def add_finish_tools(self, tools):
-        if not self.enable_finish:
+        if not self.enable_finish or not self.tooluniverse:
             return tools
         tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
         logger.info("Finish tool is added")
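The pattern in both hunks above is the same: never touch `self.tooluniverse` unless it was actually constructed. A minimal self-contained sketch of that guard (the stub class is hypothetical and mirrors only the shape of the real one):

    class _AgentStub:
        """Hypothetical stand-in for TxAgent with all tools disabled."""
        def __init__(self):
            self.enable_finish = False
            self.tooluniverse = None  # never built when tools are off

        def add_finish_tools(self, tools):
            if not self.enable_finish or not self.tooluniverse:
                return tools  # guard: no attribute access on None
            tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
            return tools

    assert _AgentStub().add_finish_tools([]) == []  # no AttributeError on None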
@@ -346,10 +346,9 @@ class TxAgent:
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
         if conversation[-1]['role'] == 'assistant':
             conversation.append({'role': 'tool', 'content': 'Errors occurred, provide final answer.'})
-        finish_tools_prompt = self.add_finish_tools([]) if self.enable_finish else []
         last_outputs_str = self.llm_infer(messages=conversation,
                                           temperature=temperature,
-                                          tools=finish_tools_prompt,
+                                          tools=[],
                                           output_begin_string='[FinalAnswer]',
                                           skip_special_tokens=True,
                                           max_new_tokens=max_new_tokens, max_token=max_token)
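One plausible reading of `output_begin_string` (a sketch, not necessarily `llm_infer`'s actual implementation): append the marker to the rendered prompt so decoding continues from "[FinalAnswer]" and the model cannot open another tool call. The names follow the surrounding code; `sampling_params` is assumed to exist:

    prompt = self.chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
    prompt += '[FinalAnswer]'  # force the model to start its final answer here
    completion = self.model.generate([prompt], sampling_params)[0].outputs[0].text
    last_outputs_str = '[FinalAnswer]' + completion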
@@ -654,131 +653,45 @@ Generate one summarized sentence about "function calls' responses" with necessary
             return "Invalid input."
         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
             return ""
-            [removed lines 657-698 not captured in this view]
-                conversation, status=last_status, enable_summary=enable_summary)
-            if function_call_messages:
-                conversation.extend(function_call_messages)
-                yield history
-            else:
-                next_round = False
-                conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
-                return ''.join(last_outputs).replace("</s>", "")
-            if self.enable_checker:
-                good_status, wrong_info = checker.check_conversation()
-                if not good_status:
-                    logger.warning(f"Checker flagged reasoning error: {wrong_info}")
-                    break
-            last_outputs = []
-            last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation,
-                temperature=temperature,
-                tools=picked_tools_prompt,
-                skip_special_tokens=False,
-                max_new_tokens=max_new_tokens,
-                max_token=max_token,
-                seed=seed,
-                check_token_status=True)
-            logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}")
-            if last_outputs_str is None:
-                logger.warning("Token overflow")
-                if self.force_finish:
-                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                        conversation, temperature, max_new_tokens, max_token)
-                    history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                    yield history
-                    return last_outputs_str
-                error_msg = "Token limit exceeded."
-                history.append(ChatMessage(role="assistant", content=error_msg))
-                yield history
-                return error_msg
-            last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-            for msg in history:
-                if msg.metadata is not None:
-                    msg.metadata['status'] = 'done'
-            if '[FinalAnswer]' in last_thought:
-                parts = last_thought.split('[FinalAnswer]', 1)
-                final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
-                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                yield history
-                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                yield history
-            else:
-                history.append(ChatMessage(role="assistant", content=last_thought))
-                yield history
-            last_outputs.append(last_outputs_str)
-        if next_round:
-            if self.force_finish:
-                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                    conversation, temperature, max_new_tokens, max_token)
-                if '[FinalAnswer]' in last_outputs_str:
-                    parts = last_outputs_str.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                    yield history
-            else:
-                yield "Reasoning rounds exceeded."
-    except Exception as e:
-        logger.error(f"Exception in run_gradio_chat: {e}")
-        error_msg = f"An error occurred: {e}"
-        history.append(ChatMessage(role="assistant", content=error_msg))
-        yield history
-        if self.force_finish:
-            last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                conversation, temperature, max_new_tokens, max_token)
-            if '[FinalAnswer]' in last_outputs_str:
-                parts = last_outputs_str.split('[FinalAnswer]', 1)
-                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                yield history
-                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                yield history
-            else:
-                history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                yield history
-        return error_msg
+        conversation = self.initialize_conversation(message, conversation, history=[])
+        sampling_params = SamplingParams(
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            seed=seed if seed is not None else self.seed,
+        )
+        prompt = self.chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
+        output = self.model.generate([prompt], sampling_params)[0].outputs[0].text  # MODIFIED: Direct inference
+        cleaned = clean_response(output)  # MODIFIED: Use clean_response
+        if '[FinalAnswer]' in cleaned:
+            parts = cleaned.split('[FinalAnswer]', 1)
+            final_answer = parts[1] if len(parts) > 1 else cleaned
+            history.append(ChatMessage(role="assistant", content=final_answer.strip()))
+        else:
+            history.append(ChatMessage(role="assistant", content=cleaned.strip()))
+        yield history
+        return cleaned
+
+def clean_response(text: str) -> str:  # MODIFIED: Add clean_response for compatibility
+    text = sanitize_utf8(text)
+    text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
+    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
+    text = re.sub(
+        r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
+        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
+        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
+        r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
+        "", text, flags=re.DOTALL
+    )
+    text = re.sub(r"\n{2,}", "\n", text).strip()
+    lines = []
+    valid_heading = False
+    for line in text.split("\n"):
+        line = line.strip()
+        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
+            valid_heading = True
+            lines.append(f"**{line[:-1]}**:")
+        elif valid_heading and line.startswith("-"):
+            lines.append(line)
+        else:
+            valid_heading = False
+    return "\n".join(lines).strip()
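End to end, the new single-shot path replaces the multi-round tool loop with one templated generation. A runnable sketch of the same flow outside the class (the model name is a placeholder; `LLM.get_tokenizer`, `SamplingParams.seed`, and Jinja2 rendering of the tokenizer's chat template are used exactly as the diff does):

    from jinja2 import Template
    from vllm import LLM, SamplingParams

    llm = LLM(model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
              enforce_eager=True, max_model_len=4096)
    template = Template(llm.get_tokenizer().chat_template)
    conversation = [{"role": "user", "content": "Review this medication list for conflicts."}]
    prompt = template.render(messages=conversation, tools=[], add_generation_prompt=True)
    params = SamplingParams(temperature=0.3, max_tokens=1024, seed=42)
    text = llm.generate([prompt], params)[0].outputs[0].text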
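Tracing `clean_response` on a small input shows its intent: drop tool-call lines and free-form reasoning, keep only the four whitelisted headings and their bullets. A worked example, assuming `sanitize_utf8` passes plain ASCII through unchanged (the output comment is derived by hand from the regexes above):

    raw = (
        "[TOOL_CALLS] [{'name': 'Finish', 'arguments': {}}]\n"
        "Missed Diagnoses:\n"
        "- Possible hypothyroidism given elevated TSH\n"
    )
    print(clean_response(raw))
    # **Missed Diagnoses**:
    # - Possible hypothyroidism given elevated TSH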