Ali2206 commited on
Commit
3ada3e8
·
verified ·
1 Parent(s): 9b6dc72

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +101 -83
src/txagent/txagent.py CHANGED
@@ -46,9 +46,11 @@ class TxAgent:
46
  self.model = None
47
  self.rag_model = ToolRAGModel(rag_model_name)
48
  self.tooluniverse = None
49
- self.prompt_multi_step = ("You are a helpful assistant that solves problems through detailed, step-by-step reasoning "
50
- "and actions based on your reasoning. Provide comprehensive and clinically precise responses, "
51
- "including specific diagnoses, tools, and actionable recommendations when analyzing medical data.")
 
 
52
  self.self_prompt = "Strictly follow the instruction."
53
  self.chat_prompt = "You are a helpful assistant to chat with the user."
54
  self.enable_finish = enable_finish
@@ -281,76 +283,76 @@ class TxAgent:
281
  temperature=None,
282
  return_gradio_history=True):
283
 
 
284
  function_call_json, message = self.tooluniverse.extract_function_call_json(
285
  fcall_str, return_message=return_message, verbose=False)
286
  call_results = []
287
  special_tool_call = ''
288
- if return_gradio_history:
289
- gradio_history = []
290
- if function_call_json is not None:
291
- if isinstance(function_call_json, list):
292
- for i in range(len(function_call_json)):
293
- if function_call_json[i]["name"] == 'Finish':
294
- special_tool_call = 'Finish'
295
- break
296
- elif function_call_json[i]["name"] == 'Tool_RAG':
297
- new_tools_prompt, call_result = self.tool_RAG(
298
- message=message,
299
- existing_tools_prompt=existing_tools_prompt,
300
- rag_num=self.step_rag_num,
301
- return_call_result=True)
302
- existing_tools_prompt += new_tools_prompt
303
- elif function_call_json[i]["name"] == 'DirectResponse':
304
- call_result = function_call_json[i]['arguments']['respose']
305
- special_tool_call = 'DirectResponse'
306
- elif function_call_json[i]["name"] == 'RequireClarification':
307
- call_result = function_call_json[i]['arguments']['unclear_question']
308
- special_tool_call = 'RequireClarification'
309
- elif function_call_json[i]["name"] == 'CallAgent':
310
- if call_agent_level < 2 and call_agent:
311
- solution_plan = function_call_json[i]['arguments']['solution']
312
- full_message = (
313
- message_for_call_agent +
314
- "\nYou must follow the following plan to answer the question: " +
315
- str(solution_plan)
316
- )
317
- sub_agent_task = "Sub TxAgent plan: " + \
318
- str(solution_plan)
319
- call_result = yield from self.run_gradio_chat(
320
- full_message, history=[], temperature=temperature,
321
- max_new_tokens=1024, max_token=99999,
322
- call_agent=False, call_agent_level=call_agent_level,
323
- conversation=None,
324
- sub_agent_task=sub_agent_task)
325
 
326
- if call_result is not None and isinstance(call_result, str):
327
- call_result = call_result.split('[FinalAnswer]')[-1]
328
- else:
329
- call_result = "⚠️ No content returned from sub-agent."
330
- else:
331
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
332
- else:
333
- call_result = self.tooluniverse.run_one_function(
334
- function_call_json[i])
335
-
336
- call_id = self.tooluniverse.call_id_gen()
337
- function_call_json[i]["call_id"] = call_id
338
- call_results.append({
339
- "role": "tool",
340
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
341
- })
342
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
343
- if function_call_json[i]["name"] == 'Tool_RAG':
344
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
345
- "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
346
- else:
347
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
348
- "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
349
- else:
350
  call_results.append({
351
  "role": "tool",
352
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
353
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  revised_messages = [{
356
  "role": "assistant",
@@ -359,23 +361,24 @@ class TxAgent:
359
  }] + call_results
360
 
361
  if return_gradio_history:
362
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
 
363
  else:
364
- return revised_messages, existing_tools_prompt, special_tool_call
365
 
366
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
367
  if conversation[-1]['role'] == 'assistant':
368
  conversation.append(
369
- {'role': 'tool', 'content': 'Errors happened during the function call, please come up with the final answer with the current information.'})
370
  finish_tools_prompt = self.add_finish_tools([])
371
 
372
  last_outputs_str = self.llm_infer(messages=conversation,
373
  temperature=temperature,
374
  tools=finish_tools_prompt,
375
- output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
376
  skip_special_tokens=True,
377
  max_new_tokens=max_new_tokens, max_token=max_token)
378
- print(last_outputs_str)
379
  return last_outputs_str
380
 
381
  def run_multistep_agent(self, message: str,
@@ -782,7 +785,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
782
  message,
783
  conversation=conversation,
784
  history=history)
785
- history = []
786
 
787
  next_round = True
788
  function_call_messages = []
@@ -801,7 +804,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
801
  logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
802
 
803
  if last_outputs:
804
- function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
805
  last_outputs, return_message=True,
806
  existing_tools_prompt=picked_tools_prompt,
807
  message_for_call_agent=message,
@@ -809,9 +812,27 @@ Generate **one summarized sentence** about "function calls' responses" with nece
809
  call_agent_level=call_agent_level,
810
  temperature=temperature)
811
 
812
- history.extend(current_gradio_history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
  if special_tool_call == 'Finish' and function_call_messages:
 
815
  yield history
816
  next_round = False
817
  conversation.extend(function_call_messages)
@@ -833,11 +854,13 @@ Generate **one summarized sentence** about "function calls' responses" with nece
833
 
834
  if function_call_messages:
835
  conversation.extend(function_call_messages)
836
- yield history
837
  else:
838
  next_round = False
839
- conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
840
- return ''.join(last_outputs).replace("</s>", "")
 
 
 
841
 
842
  if self.enable_checker:
843
  good_status, wrong_info = checker.check_conversation()
@@ -875,7 +898,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
875
  parts = last_thought.split('[FinalAnswer]', 1)
876
  final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
877
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
878
- yield history
879
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
880
  yield history
881
  next_round = False
@@ -894,12 +916,10 @@ Generate **one summarized sentence** about "function calls' responses" with nece
894
  parts = last_outputs_str.split('[FinalAnswer]', 1)
895
  final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
896
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
897
- yield history
898
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
899
- yield history
900
  else:
901
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
902
- yield history
903
  else:
904
  error_msg = "The number of reasoning rounds exceeded the limit."
905
  history.append(ChatMessage(role="assistant", content=error_msg))
@@ -918,10 +938,8 @@ Generate **one summarized sentence** about "function calls' responses" with nece
918
  parts = last_outputs_str.split('[FinalAnswer]', 1)
919
  final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
920
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
921
- yield history
922
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
923
- yield history
924
  else:
925
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
926
- yield history
927
  return error_msg
 
46
  self.model = None
47
  self.rag_model = ToolRAGModel(rag_model_name)
48
  self.tooluniverse = None
49
+ self.prompt_multi_step = ("You are a highly skilled medical assistant tasked with analyzing medical records in detail. "
50
+ "Provide comprehensive, step-by-step reasoning to identify oversights, including specific diagnoses, "
51
+ "medication conflicts, incomplete assessments, and abnormal results. For each point, include clinical "
52
+ "rationale, standardized screening tools (e.g., PCL-5, SCID-5-PD), and actionable recommendations for "
53
+ "follow-up, ensuring a thorough and precise response.")
54
  self.self_prompt = "Strictly follow the instruction."
55
  self.chat_prompt = "You are a helpful assistant to chat with the user."
56
  self.enable_finish = enable_finish
 
283
  temperature=None,
284
  return_gradio_history=True):
285
 
286
+ logger.debug(f"Running function call stream with input: {fcall_str[:100]}...")
287
  function_call_json, message = self.tooluniverse.extract_function_call_json(
288
  fcall_str, return_message=return_message, verbose=False)
289
  call_results = []
290
  special_tool_call = ''
291
+ gradio_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ if function_call_json is None:
294
+ logger.warning("No valid function call JSON extracted")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  call_results.append({
296
  "role": "tool",
297
+ "content": json.dumps({"content": "Invalid function call format."})
298
  })
299
+ if return_gradio_history:
300
+ gradio_history.append(ChatMessage(role="assistant", content="Error: Invalid tool call format."))
301
+ yield call_results, existing_tools_prompt or [], special_tool_call, gradio_history
302
+ return
303
+
304
+ if isinstance(function_call_json, list):
305
+ for i in range(len(function_call_json)):
306
+ logger.debug(f"Processing tool call: {function_call_json[i]}")
307
+ if function_call_json[i]["name"] == 'Finish':
308
+ special_tool_call = 'Finish'
309
+ break
310
+ elif function_call_json[i]["name"] == 'Tool_RAG':
311
+ new_tools_prompt, call_result = self.tool_RAG(
312
+ message=message,
313
+ existing_tools_prompt=existing_tools_prompt,
314
+ rag_num=self.step_rag_num,
315
+ return_call_result=True)
316
+ existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
317
+ elif function_call_json[i]["name"] == 'DirectResponse':
318
+ call_result = function_call_json[i]['arguments']['respose']
319
+ special_tool_call = 'DirectResponse'
320
+ elif function_call_json[i]["name"] == 'RequireClarification':
321
+ call_result = function_call_json[i]['arguments']['unclear_question']
322
+ special_tool_call = 'RequireClarification'
323
+ elif function_call_json[i]["name"] == 'CallAgent':
324
+ if call_agent_level < 2 and call_agent:
325
+ solution_plan = function_call_json[i]['arguments']['solution']
326
+ full_message = (
327
+ message_for_call_agent +
328
+ "\nYou must follow the following plan to answer the question: " +
329
+ str(solution_plan)
330
+ )
331
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
332
+ sub_result = yield from self.run_gradio_chat(
333
+ full_message, history=[], temperature=temperature,
334
+ max_new_tokens=1024, max_token=99999,
335
+ call_agent=False, call_agent_level=call_agent_level,
336
+ conversation=None,
337
+ sub_agent_task=sub_agent_task)
338
+ call_result = sub_result if isinstance(sub_result, str) else "No content from sub-agent."
339
+ if '[FinalAnswer]' in call_result:
340
+ call_result = call_result.split('[FinalAnswer]')[-1].strip()
341
+ else:
342
+ call_result = "CallAgent disabled. Proceeding with reasoning."
343
+ else:
344
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
345
+
346
+ call_id = self.tooluniverse.call_id_gen()
347
+ function_call_json[i]["call_id"] = call_id
348
+ call_results.append({
349
+ "role": "tool",
350
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
351
+ })
352
+
353
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
354
+ metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
355
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
356
 
357
  revised_messages = [{
358
  "role": "assistant",
 
361
  }] + call_results
362
 
363
  if return_gradio_history:
364
+ logger.debug(f"Yielding gradio history with {len(gradio_history)} entries")
365
+ yield revised_messages, existing_tools_prompt or [], special_tool_call, gradio_history
366
  else:
367
+ yield revised_messages, existing_tools_prompt or [], special_tool_call
368
 
369
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
370
  if conversation[-1]['role'] == 'assistant':
371
  conversation.append(
372
+ {'role': 'tool', 'content': 'Errors occurred; provide a detailed final answer based on current information.'})
373
  finish_tools_prompt = self.add_finish_tools([])
374
 
375
  last_outputs_str = self.llm_infer(messages=conversation,
376
  temperature=temperature,
377
  tools=finish_tools_prompt,
378
+ output_begin_string='[FinalAnswer]',
379
  skip_special_tokens=True,
380
  max_new_tokens=max_new_tokens, max_token=max_token)
381
+ logger.debug(f"Forced finish output: {last_outputs_str[:100]}...")
382
  return last_outputs_str
383
 
384
  def run_multistep_agent(self, message: str,
 
785
  message,
786
  conversation=conversation,
787
  history=history)
788
+ history = [] # Reset history to avoid duplication
789
 
790
  next_round = True
791
  function_call_messages = []
 
804
  logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
805
 
806
  if last_outputs:
807
+ function_call_result = yield from self.run_function_call_stream(
808
  last_outputs, return_message=True,
809
  existing_tools_prompt=picked_tools_prompt,
810
  message_for_call_agent=message,
 
812
  call_agent_level=call_agent_level,
813
  temperature=temperature)
814
 
815
+ # Ensure function_call_result is valid
816
+ if not function_call_result:
817
+ logger.warning("Empty result from run_function_call_stream")
818
+ error_msg = "Error: Tool call processing failed."
819
+ history.append(ChatMessage(role="assistant", content=error_msg))
820
+ yield history
821
+ return error_msg
822
+
823
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
824
+
825
+ # Deduplicate history entries
826
+ unique_history = []
827
+ seen_contents = set()
828
+ for msg in current_gradio_history:
829
+ if msg.content not in seen_contents:
830
+ unique_history.append(msg)
831
+ seen_contents.add(msg.content)
832
+ history.extend(unique_history)
833
 
834
  if special_tool_call == 'Finish' and function_call_messages:
835
+ history.append(ChatMessage(role="assistant", content=function_call_messages[0]['content']))
836
  yield history
837
  next_round = False
838
  conversation.extend(function_call_messages)
 
854
 
855
  if function_call_messages:
856
  conversation.extend(function_call_messages)
 
857
  else:
858
  next_round = False
859
+ content = ''.join(last_outputs).replace("</s>", "")
860
+ history.append(ChatMessage(role="assistant", content=content))
861
+ conversation.append({"role": "assistant", "content": content})
862
+ yield history
863
+ return content
864
 
865
  if self.enable_checker:
866
  good_status, wrong_info = checker.check_conversation()
 
898
  parts = last_thought.split('[FinalAnswer]', 1)
899
  final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
900
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
 
901
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
902
  yield history
903
  next_round = False
 
916
  parts = last_outputs_str.split('[FinalAnswer]', 1)
917
  final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
918
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
 
919
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
 
920
  else:
921
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
922
+ yield history
923
  else:
924
  error_msg = "The number of reasoning rounds exceeded the limit."
925
  history.append(ChatMessage(role="assistant", content=error_msg))
 
938
  parts = last_outputs_str.split('[FinalAnswer]', 1)
939
  final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
940
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
 
941
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
 
942
  else:
943
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
944
+ yield history
945
  return error_msg