Ali2206 committed
Commit 5707e8d · verified · 1 parent: d313543

Update src/txagent/txagent.py

Files changed (1): src/txagent/txagent.py (+52 -139)
src/txagent/txagent.py CHANGED
@@ -23,11 +23,11 @@ class TxAgent:
     def __init__(self, model_name,
                  rag_model_name,
                  tool_files_dict=None,
-                 enable_finish=True,
-                 enable_rag=True,
+                 enable_finish=False,  # MODIFIED: Default to False
+                 enable_rag=False,
                  enable_summary=False,
                  init_rag_num=0,
-                 step_rag_num=10,
+                 step_rag_num=0,
                  summary_mode='step',
                  summary_skip_last_k=0,
                  summary_context_length=None,
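With this hunk, RAG and the Finish tool are off by default and `step_rag_num=0` disables per-step retrieval, so callers must opt in explicitly. A hypothetical construction call restoring the old behavior (the model ids and the idea of passing these kwargs at construction are assumptions, not part of this commit):

```python
# Hypothetical: after this commit, tool retrieval must be enabled explicitly.
agent = TxAgent(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # assumed model id
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",  # assumed model id
    enable_finish=True,   # was the default before this commit
    enable_rag=True,      # was the default before this commit
    step_rag_num=10,      # restores the old per-step retrieval depth
)
```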
@@ -79,7 +79,7 @@ class TxAgent:
         if model_name == self.model_name:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
-        self.model = LLM(model=self.model_name, enforce_eager=True)
+        self.model = LLM(model=self.model_name, enforce_eager=True, max_model_len=4096)  # MODIFIED: Reduce KV cache
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         return f"Model {model_name} loaded successfully."
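Capping `max_model_len` bounds the KV cache that vLLM pre-allocates at startup, which is what makes the model fit on smaller GPUs; the trade-off is a hard 4096-token context window. A minimal standalone sketch of the same loading pattern (the model id is a placeholder):

```python
from vllm import LLM, SamplingParams

# enforce_eager=True skips CUDA-graph capture (lower memory, slower decode);
# max_model_len=4096 caps the context window and shrinks the KV-cache allocation.
llm = LLM(model="your-org/your-model", enforce_eager=True, max_model_len=4096)
out = llm.generate(["Hello"], SamplingParams(max_tokens=32))[0].outputs[0].text
print(out)
```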
@@ -165,16 +165,16 @@ class TxAgent:
     def add_special_tools(self, tools, call_agent=False):
         if not self.enable_rag and not self.enable_finish:
             return tools
-        if self.enable_finish and self.tooluniverse:  # MODIFIED: Check tooluniverse
+        if self.enable_finish and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
             logger.info("Finish tool is added")
-        if call_agent and self.tooluniverse:  # MODIFIED: Check tooluniverse
+        if call_agent and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
             logger.info("CallAgent tool is added")
-        elif self.enable_rag and self.tooluniverse:  # MODIFIED: Check tooluniverse
+        elif self.enable_rag and self.tooluniverse:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
             logger.info("Tool_RAG tool is added")
-        if self.additional_default_tools is not None and self.tooluniverse:  # MODIFIED: Check tooluniverse
+        if self.additional_default_tools is not None and self.tooluniverse:
             for each_tool_name in self.additional_default_tools:
                 tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
                 if tool_prompt is not None:
@@ -183,7 +183,7 @@ class TxAgent:
         return tools
 
     def add_finish_tools(self, tools):
-        if not self.enable_finish or not self.tooluniverse:  # MODIFIED: Check tooluniverse
+        if not self.enable_finish or not self.tooluniverse:
             return tools
         tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
         logger.info("Finish tool is added")
@@ -346,10 +346,9 @@ class TxAgent:
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
         if conversation[-1]['role'] == 'assistant':
             conversation.append({'role': 'tool', 'content': 'Errors occurred, provide final answer.'})
-        finish_tools_prompt = self.add_finish_tools([]) if self.enable_finish else []
         last_outputs_str = self.llm_infer(messages=conversation,
                                           temperature=temperature,
-                                          tools=finish_tools_prompt,
+                                          tools=[],
                                           output_begin_string='[FinalAnswer]',
                                           skip_special_tokens=True,
                                           max_new_tokens=max_new_tokens, max_token=max_token)
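With `tools=[]`, the fallback final-answer call no longer renders any tool schema into the prompt; `output_begin_string='[FinalAnswer]'` presumably pre-seeds the marker so the model writes an answer body instead of another tool call. A sketch of that prefill idea (how `llm_infer` implements this internally is an assumption; `llm` is the vLLM handle from the loading sketch above):

```python
from jinja2 import Template

# Sketch: force generation to continue after a fixed prefix.
chat_template = Template(llm.get_tokenizer().chat_template)
conversation = [{"role": "user", "content": "Summarize the findings."}]
prompt = chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
prompt += "[FinalAnswer]"  # pre-seed the marker; the model completes the answer body
text = llm.generate([prompt], SamplingParams(temperature=0.1, max_tokens=256))[0].outputs[0].text
final_answer = "[FinalAnswer]" + text
```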
@@ -654,131 +653,45 @@ Generate one summarized sentence about "function calls' responses" with necessary
             return "Invalid input."
         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
             return ""
-        outputs = []
-        last_outputs = []
-        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
-            call_agent, call_agent_level, message)
-        conversation = self.initialize_conversation(
-            message, conversation=conversation, history=history)
-        history = [] if not history else history  # MODIFIED: Simplify history
-        next_round = True
-        function_call_messages = []
-        current_round = 0
-        enable_summary = False
-        last_status = {}
-        token_overflow = False
-        if self.enable_checker:
-            checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
-        try:
-            while next_round and current_round < max_round:
-                current_round += 1
-                logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
-                if last_outputs and self.enable_rag:
-                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
-                        last_outputs, return_message=True,
-                        existing_tools_prompt=picked_tools_prompt,
-                        message_for_call_agent=message,
-                        call_agent=call_agent,
-                        call_agent_level=call_agent_level,
-                        temperature=temperature)
-                    history.extend(current_gradio_history)
-                    if special_tool_call == 'Finish' and function_call_messages:
-                        yield history
-                        next_round = False
-                        conversation.extend(function_call_messages)
-                        return function_call_messages[0]['content']
-                    elif special_tool_call in ['RequireClarification', 'DirectResponse']:
-                        last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
-                        history.append(ChatMessage(role="assistant", content=last_msg.content))
-                        yield history
-                        next_round = False
-                        return last_msg.content
-                    if (self.enable_summary or token_overflow) and not call_agent:
-                        enable_summary = True
-                    last_status = self.function_result_summary(
-                        conversation, status=last_status, enable_summary=enable_summary)
-                    if function_call_messages:
-                        conversation.extend(function_call_messages)
-                        yield history
-                    else:
-                        next_round = False
-                        conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
-                        return ''.join(last_outputs).replace("</s>", "")
-                if self.enable_checker:
-                    good_status, wrong_info = checker.check_conversation()
-                    if not good_status:
-                        logger.warning(f"Checker flagged reasoning error: {wrong_info}")
-                        break
-                last_outputs = []
-                last_outputs_str, token_overflow = self.llm_infer(
-                    messages=conversation,
-                    temperature=temperature,
-                    tools=picked_tools_prompt,
-                    skip_special_tokens=False,
-                    max_new_tokens=max_new_tokens,
-                    max_token=max_token,
-                    seed=seed,
-                    check_token_status=True)
-                logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}")
-                if last_outputs_str is None:
-                    logger.warning("Token overflow")
-                    if self.force_finish:
-                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                            conversation, temperature, max_new_tokens, max_token)
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                        return last_outputs_str
-                    error_msg = "Token limit exceeded."
-                    history.append(ChatMessage(role="assistant", content=error_msg))
-                    yield history
-                    return error_msg
-                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-                for msg in history:
-                    if msg.metadata is not None:
-                        msg.metadata['status'] = 'done'
-                if '[FinalAnswer]' in last_thought:
-                    parts = last_thought.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_thought))
-                    yield history
-                last_outputs.append(last_outputs_str)
-            if next_round:
-                if self.force_finish:
-                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                        conversation, temperature, max_new_tokens, max_token)
-                    if '[FinalAnswer]' in last_outputs_str:
-                        parts = last_outputs_str.split('[FinalAnswer]', 1)
-                        final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                        history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                        yield history
-                        history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                        yield history
-                    else:
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                else:
-                    yield "Reasoning rounds exceeded."
-        except Exception as e:
-            logger.error(f"Exception in run_gradio_chat: {e}")
-            error_msg = f"An error occurred: {e}"
-            history.append(ChatMessage(role="assistant", content=error_msg))
-            yield history
-            if self.force_finish:
-                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                    conversation, temperature, max_new_tokens, max_token)
-                if '[FinalAnswer]' in last_outputs_str:
-                    parts = last_outputs_str.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                    yield history
-            return error_msg
+        conversation = self.initialize_conversation(message, conversation, history=[])
+        sampling_params = SamplingParams(
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            seed=seed if seed is not None else self.seed,
+        )
+        prompt = self.chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
+        output = self.model.generate([prompt], sampling_params)[0].outputs[0].text  # MODIFIED: Direct inference
+        cleaned = clean_response(output)  # MODIFIED: Use clean_response
+        if '[FinalAnswer]' in cleaned:
+            parts = cleaned.split('[FinalAnswer]', 1)
+            final_answer = parts[1] if len(parts) > 1 else cleaned
+            history.append(ChatMessage(role="assistant", content=final_answer.strip()))
+        else:
+            history.append(ChatMessage(role="assistant", content=cleaned.strip()))
+        yield history
+        return cleaned
+
+    def clean_response(text: str) -> str:  # MODIFIED: Add clean_response for compatibility
+        text = sanitize_utf8(text)
+        text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
+        text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
+        text = re.sub(
+            r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
+            r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
+            r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
+            r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
+            "", text, flags=re.DOTALL
+        )
+        text = re.sub(r"\n{2,}", "\n", text).strip()
+        lines = []
+        valid_heading = False
+        for line in text.split("\n"):
+            line = line.strip()
+            if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
+                valid_heading = True
+                lines.append(f"**{line[:-1]}**:")
+            elif valid_heading and line.startswith("-"):
+                lines.append(line)
+            else:
+                valid_heading = False
+        return "\n".join(lines).strip()
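This hunk collapses the multi-round tool-calling loop into a single render → generate → clean pass. A self-contained sketch of the new path under stated assumptions (placeholder model id; an invented two-message conversation; whether the chat template uses the `tools` variable depends on the model's template):

```python
from jinja2 import Template
from vllm import LLM, SamplingParams

# Single-shot inference, mirroring the new run_gradio_chat body.
llm = LLM(model="your-org/your-model", enforce_eager=True, max_model_len=4096)
chat_template = Template(llm.get_tokenizer().chat_template)

conversation = [
    {"role": "system", "content": "You are a clinical assistant."},
    {"role": "user", "content": "Review this patient's medications for conflicts."},
]
# tools=[] mirrors the commit: no tool schema is rendered into the prompt.
prompt = chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
params = SamplingParams(temperature=0.3, max_tokens=1024, seed=100)
raw = llm.generate([prompt], params)[0].outputs[0].text

# Everything after the marker is treated as the final answer, as in the diff.
answer = raw.split('[FinalAnswer]', 1)[1] if '[FinalAnswer]' in raw else raw
print(answer.strip())
```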
 
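The whitelist pass at the end of `clean_response` keeps only the four known report headings and the bullet lines directly under them. A small trace of just that loop (the sample text is invented):

```python
# Final whitelist pass from clean_response, run on a toy string.
text = ("Missed Diagnoses:\n- Sepsis not ruled out\n"
        "stray reasoning line\n"
        "Urgent Follow-up:\n- Repeat lactate in 2 hours")
lines, valid_heading = [], False
for line in text.split("\n"):
    line = line.strip()
    if line.lower() in ["missed diagnoses:", "medication conflicts:",
                        "incomplete assessments:", "urgent follow-up:"]:
        valid_heading = True               # a recognized heading opens a section
        lines.append(f"**{line[:-1]}**:")  # bold it, re-attach the colon
    elif valid_heading and line.startswith("-"):
        lines.append(line)                 # keep bullets under a valid heading
    else:
        valid_heading = False              # anything else closes the section
print("\n".join(lines))
# **Missed Diagnoses**:
# - Sepsis not ruled out
# **Urgent Follow-up**:
# - Repeat lactate in 2 hours
```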