Update src/txagent/txagent.py

src/txagent/txagent.py (+101 -83)
```diff
@@ -46,9 +46,11 @@ class TxAgent:
         self.model = None
         self.rag_model = ToolRAGModel(rag_model_name)
         self.tooluniverse = None
-        self.prompt_multi_step = ("You are a …
-                                  "…
-                                  "…
+        self.prompt_multi_step = ("You are a highly skilled medical assistant tasked with analyzing medical records in detail. "
+                                  "Provide comprehensive, step-by-step reasoning to identify oversights, including specific diagnoses, "
+                                  "medication conflicts, incomplete assessments, and abnormal results. For each point, include clinical "
+                                  "rationale, standardized screening tools (e.g., PCL-5, SCID-5-PD), and actionable recommendations for "
+                                  "follow-up, ensuring a thorough and precise response.")
         self.self_prompt = "Strictly follow the instruction."
         self.chat_prompt = "You are a helpful assistant to chat with the user."
         self.enable_finish = enable_finish
```
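The new prompt spans five source lines but is a single string: Python implicitly concatenates adjacent string literals inside parentheses at compile time. A minimal sketch of the pattern, with the strings shortened for illustration:

```python
# Adjacent string literals inside parentheses fuse into one string at
# compile time: note there are no commas between the parts.
prompt_multi_step = ("You are a highly skilled medical assistant "
                     "tasked with analyzing medical records in detail.")
assert "assistant tasked" in prompt_multi_step  # the pieces joined seamlessly
```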
```diff
@@ -281,76 +283,76 @@ class TxAgent:
                                  temperature=None,
                                  return_gradio_history=True):
 
+        logger.debug(f"Running function call stream with input: {fcall_str[:100]}...")
         function_call_json, message = self.tooluniverse.extract_function_call_json(
             fcall_str, return_message=return_message, verbose=False)
         call_results = []
         special_tool_call = ''
-
         gradio_history = []
 
-        if function_call_json is not None:
-            if isinstance(function_call_json, list):
-                for i in range(len(function_call_json)):
-                    if function_call_json[i]["name"] == 'Finish':
-                        special_tool_call = 'Finish'
-                        break
-                    elif function_call_json[i]["name"] == 'Tool_RAG':
-                        new_tools_prompt, call_result = self.tool_RAG(
-                            message=message,
-                            existing_tools_prompt=existing_tools_prompt,
-                            rag_num=self.step_rag_num,
-                            return_call_result=True)
-                        existing_tools_prompt += new_tools_prompt
-                    elif function_call_json[i]["name"] == 'DirectResponse':
-                        call_result = function_call_json[i]['arguments']['respose']
-                        special_tool_call = 'DirectResponse'
-                    elif function_call_json[i]["name"] == 'RequireClarification':
-                        call_result = function_call_json[i]['arguments']['unclear_question']
-                        special_tool_call = 'RequireClarification'
-                    elif function_call_json[i]["name"] == 'CallAgent':
-                        if call_agent_level < 2 and call_agent:
-                            solution_plan = function_call_json[i]['arguments']['solution']
-                            full_message = (
-                                message_for_call_agent +
-                                "\nYou must follow the following plan to answer the question: " +
-                                str(solution_plan)
-                            )
-                            sub_agent_task = "Sub TxAgent plan: " + \
-                                str(solution_plan)
-                            call_result = yield from self.run_gradio_chat(
-                                full_message, history=[], temperature=temperature,
-                                max_new_tokens=1024, max_token=99999,
-                                call_agent=False, call_agent_level=call_agent_level,
-                                conversation=None,
-                                sub_agent_task=sub_agent_task)
-
-                            if call_result is not None and isinstance(call_result, str):
-                                call_result = call_result.split('[FinalAnswer]')[-1]
-                            else:
-                                call_result = "⚠️ No content returned from sub-agent."
-                        else:
-                            call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
-                    else:
-                        call_result = self.tooluniverse.run_one_function(
-                            function_call_json[i])
-
-                    call_id = self.tooluniverse.call_id_gen()
-                    function_call_json[i]["call_id"] = call_id
-                    call_results.append({
-                        "role": "tool",
-                        "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
-                    })
-                    if return_gradio_history and function_call_json[i]["name"] != 'Finish':
-                        if function_call_json[i]["name"] == 'Tool_RAG':
-                            gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
-                                "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
-                        else:
-                            gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
-                                "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
-        else:
+        if function_call_json is None:
+            logger.warning("No valid function call JSON extracted")
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "…
+                "content": json.dumps({"content": "Invalid function call format."})
             })
+            if return_gradio_history:
+                gradio_history.append(ChatMessage(role="assistant", content="Error: Invalid tool call format."))
+            yield call_results, existing_tools_prompt or [], special_tool_call, gradio_history
+            return
+
+        if isinstance(function_call_json, list):
+            for i in range(len(function_call_json)):
+                logger.debug(f"Processing tool call: {function_call_json[i]}")
+                if function_call_json[i]["name"] == 'Finish':
+                    special_tool_call = 'Finish'
+                    break
+                elif function_call_json[i]["name"] == 'Tool_RAG':
+                    new_tools_prompt, call_result = self.tool_RAG(
+                        message=message,
+                        existing_tools_prompt=existing_tools_prompt,
+                        rag_num=self.step_rag_num,
+                        return_call_result=True)
+                    existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
+                elif function_call_json[i]["name"] == 'DirectResponse':
+                    call_result = function_call_json[i]['arguments']['respose']
+                    special_tool_call = 'DirectResponse'
+                elif function_call_json[i]["name"] == 'RequireClarification':
+                    call_result = function_call_json[i]['arguments']['unclear_question']
+                    special_tool_call = 'RequireClarification'
+                elif function_call_json[i]["name"] == 'CallAgent':
+                    if call_agent_level < 2 and call_agent:
+                        solution_plan = function_call_json[i]['arguments']['solution']
+                        full_message = (
+                            message_for_call_agent +
+                            "\nYou must follow the following plan to answer the question: " +
+                            str(solution_plan)
+                        )
+                        sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
+                        sub_result = yield from self.run_gradio_chat(
+                            full_message, history=[], temperature=temperature,
+                            max_new_tokens=1024, max_token=99999,
+                            call_agent=False, call_agent_level=call_agent_level,
+                            conversation=None,
+                            sub_agent_task=sub_agent_task)
+                        call_result = sub_result if isinstance(sub_result, str) else "No content from sub-agent."
+                        if '[FinalAnswer]' in call_result:
+                            call_result = call_result.split('[FinalAnswer]')[-1].strip()
+                    else:
+                        call_result = "CallAgent disabled. Proceeding with reasoning."
+                else:
+                    call_result = self.tooluniverse.run_one_function(function_call_json[i])
+
+                call_id = self.tooluniverse.call_id_gen()
+                function_call_json[i]["call_id"] = call_id
+                call_results.append({
+                    "role": "tool",
+                    "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
+                })
+
+                if return_gradio_history and function_call_json[i]["name"] != 'Finish':
+                    metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
+                    gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
 
         revised_messages = [{
             "role": "assistant",
```
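Both the old and new CallAgent branches lean on a generator subtlety: `yield from` forwards the sub-generator's streamed updates to the outer consumer and then evaluates to the sub-generator's `return` value. A minimal self-contained sketch of that mechanism (toy functions, not TxAgent code):

```python
# Toy illustration: `yield from` streams a sub-generator's yields outward
# AND evaluates to the sub-generator's `return` value.
def sub_agent():
    yield "…thinking…"            # streamed to whoever iterates the outer generator
    return "[FinalAnswer] 42"     # becomes the value of the `yield from` expression

def outer_agent():
    call_result = yield from sub_agent()   # re-yields "…thinking…", then binds the return
    yield "sub-agent said: " + call_result.split('[FinalAnswer]')[-1].strip()

for update in outer_agent():
    print(update)
# …thinking…
# sub-agent said: 42
```

The same mechanism is what lets the caller later in this diff write `function_call_result = yield from self.run_function_call_stream(...)` while Gradio updates keep streaming through.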
```diff
@@ -359,23 +361,24 @@ class TxAgent:
         }] + call_results
 
         if return_gradio_history:
-            …
+            logger.debug(f"Yielding gradio history with {len(gradio_history)} entries")
+            yield revised_messages, existing_tools_prompt or [], special_tool_call, gradio_history
         else:
-            …
+            yield revised_messages, existing_tools_prompt or [], special_tool_call
 
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
         if conversation[-1]['role'] == 'assistant':
             conversation.append(
-                {'role': 'tool', 'content': 'Errors …
+                {'role': 'tool', 'content': 'Errors occurred; provide a detailed final answer based on current information.'})
         finish_tools_prompt = self.add_finish_tools([])
 
         last_outputs_str = self.llm_infer(messages=conversation,
                                           temperature=temperature,
                                           tools=finish_tools_prompt,
-                                          output_begin_string='…
+                                          output_begin_string='[FinalAnswer]',
                                           skip_special_tokens=True,
                                           max_new_tokens=max_new_tokens, max_token=max_token)
-
+        logger.debug(f"Forced finish output: {last_outputs_str[:100]}...")
        return last_outputs_str
 
     def run_multistep_agent(self, message: str,
```
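A recurring fix in this commit replaces `existing_tools_prompt += new_tools_prompt` and bare yields of `existing_tools_prompt` with the `existing_tools_prompt or []` guard. A small sketch of the failure it prevents, assuming the parameter can arrive as `None` (the helper name here is illustrative):

```python
def extend_tools(existing_tools_prompt=None):
    """Illustrative helper; `existing_tools_prompt=None` mirrors the likely default."""
    new_tools_prompt = ["tool_spec_a"]
    # `existing_tools_prompt += new_tools_prompt` raises TypeError when the value
    # is None; `or []` substitutes an empty list so concatenation always works.
    return (existing_tools_prompt or []) + new_tools_prompt

print(extend_tools())              # ['tool_spec_a']
print(extend_tools(["tool_0"]))    # ['tool_0', 'tool_spec_a']
```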
```diff
@@ -782,7 +785,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
             message,
             conversation=conversation,
             history=history)
-        history = []
+        history = []  # Reset history to avoid duplication
 
         next_round = True
         function_call_messages = []
```
```diff
@@ -801,7 +804,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
         logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
 
         if last_outputs:
-            …
+            function_call_result = yield from self.run_function_call_stream(
                 last_outputs, return_message=True,
                 existing_tools_prompt=picked_tools_prompt,
                 message_for_call_agent=message,
```
```diff
@@ -809,9 +812,27 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
                 call_agent_level=call_agent_level,
                 temperature=temperature)
 
-            …
+            # Ensure function_call_result is valid
+            if not function_call_result:
+                logger.warning("Empty result from run_function_call_stream")
+                error_msg = "Error: Tool call processing failed."
+                history.append(ChatMessage(role="assistant", content=error_msg))
+                yield history
+                return error_msg
+
+            function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
+
+            # Deduplicate history entries
+            unique_history = []
+            seen_contents = set()
+            for msg in current_gradio_history:
+                if msg.content not in seen_contents:
+                    unique_history.append(msg)
+                    seen_contents.add(msg.content)
+            history.extend(unique_history)
 
             if special_tool_call == 'Finish' and function_call_messages:
+                history.append(ChatMessage(role="assistant", content=function_call_messages[0]['content']))
                 yield history
                 next_round = False
                 conversation.extend(function_call_messages)
```
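The new deduplication block is the order-preserving "seen set" idiom keyed on message content. Isolated as a plain function, with a namedtuple standing in for Gradio's `ChatMessage`, it behaves like this:

```python
from collections import namedtuple

# Stand-in for gradio's ChatMessage; only `.content` matters for the dedup.
Msg = namedtuple("Msg", ["role", "content"])

def dedup_by_content(messages):
    """Keep the first occurrence of each content string, preserving order."""
    seen, unique = set(), []
    for msg in messages:
        if msg.content not in seen:
            unique.append(msg)
            seen.add(msg.content)
    return unique

msgs = [Msg("assistant", "a"), Msg("assistant", "b"), Msg("assistant", "a")]
print([m.content for m in dedup_by_content(msgs)])  # ['a', 'b']
```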
```diff
@@ -833,11 +854,13 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
 
             if function_call_messages:
                 conversation.extend(function_call_messages)
-                yield history
             else:
                 next_round = False
-                …
-                …
+                content = ''.join(last_outputs).replace("</s>", "")
+                history.append(ChatMessage(role="assistant", content=content))
+                conversation.append({"role": "assistant", "content": content})
+                yield history
+                return content
 
             if self.enable_checker:
                 good_status, wrong_info = checker.check_conversation()
```
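The new else-branch materializes the assistant turn from the raw streamed chunks, joining them and stripping the `</s>` marker (the Llama-family end-of-sequence token) before the text reaches the chat history. In isolation:

```python
# Streamed model chunks often end with the EOS marker "</s>";
# joining then replacing yields the clean assistant message.
last_outputs = ["The recommended ", "dose is 5 mg.", "</s>"]
content = ''.join(last_outputs).replace("</s>", "")
print(content)  # The recommended dose is 5 mg.
```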
```diff
@@ -875,7 +898,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
             parts = last_thought.split('[FinalAnswer]', 1)
             final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
             history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-            yield history
             history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
             yield history
             next_round = False
```
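Every final-answer branch in these hunks uses the same `split('[FinalAnswer]', 1)` idiom; with `maxsplit=1` the result has at most two parts, so the `len(parts) == 2` check lets the two-way unpack degrade gracefully when the marker is missing. A minimal sketch:

```python
def split_final(text):
    # maxsplit=1 -> at most two parts, so a two-way unpack is safe after the check
    parts = text.split('[FinalAnswer]', 1)
    final_thought, final_answer = parts if len(parts) == 2 else (text, "")
    return final_thought, final_answer

print(split_final("Reasoning steps…[FinalAnswer]Take 5 mg daily."))
# ('Reasoning steps…', 'Take 5 mg daily.')
print(split_final("No marker emitted."))
# ('No marker emitted.', '')
```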
```diff
@@ -894,12 +916,10 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
                 parts = last_outputs_str.split('[FinalAnswer]', 1)
                 final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                 history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                yield history
                 history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                yield history
             else:
                 history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                yield history
+            yield history
         else:
             error_msg = "The number of reasoning rounds exceeded the limit."
             history.append(ChatMessage(role="assistant", content=error_msg))
```
```diff
@@ -918,10 +938,8 @@ Generate **one summarized sentence** about "function calls' responses" with nece…
             parts = last_outputs_str.split('[FinalAnswer]', 1)
             final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
             history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-            yield history
             history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-            yield history
         else:
             history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-            yield history
+        yield history
         return error_msg
```