Ali2206 committed
Commit 9b6dc72 · verified · Parent(s): 4712249

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +39 -75
src/txagent/txagent.py CHANGED
@@ -12,18 +12,17 @@ from tooluniverse import ToolUniverse
 from gradio import ChatMessage
 from .toolrag import ToolRAGModel
 import torch
-# near the top of txagent.py
 import logging
+
 logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
 
-
 class TxAgent:
     def __init__(self, model_name,
                  rag_model_name,
-                 tool_files_dict=None,  # None leads to the default tool files in ToolUniverse
+                 tool_files_dict=None,
                  enable_finish=True,
                  enable_rag=True,
                  enable_summary=False,
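
Note on the logging change above: `logging.basicConfig` configures the root logger and is a no-op once the root logger already has handlers, so the effective level and format depend on import order. A minimal sketch of what the new module-level call produces, assuming nothing else has configured logging first; the logger name below is illustrative:

```python
import logging

# Same call as in the updated txagent.py: root logger at DEBUG with a timestamped format.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

logger = logging.getLogger("txagent.txagent")  # illustrative name
logger.debug("llm_infer output: ...")
# Example output:
# 2025-01-01 12:00:00,000 - txagent.txagent - DEBUG - llm_infer output: ...
```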
@@ -47,10 +46,11 @@ class TxAgent:
         self.model = None
         self.rag_model = ToolRAGModel(rag_model_name)
         self.tooluniverse = None
-        # self.tool_desc = None
-        self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
+        self.prompt_multi_step = ("You are a helpful assistant that solves problems through detailed, step-by-step reasoning "
+                                  "and actions based on your reasoning. Provide comprehensive and clinically precise responses, "
+                                  "including specific diagnoses, tools, and actionable recommendations when analyzing medical data.")
         self.self_prompt = "Strictly follow the instruction."
-        self.chat_prompt = "You are helpful assistant to chat with the user."
+        self.chat_prompt = "You are a helpful assistant to chat with the user."
         self.enable_finish = enable_finish
         self.enable_rag = enable_rag
         self.enable_summary = enable_summary
@@ -145,7 +145,7 @@ class TxAgent:
                  existing_tools_prompt=[],
                  rag_num=5,
                  return_call_result=False):
-        extra_factor = 30  # Factor to retrieve more than rag_num
+        extra_factor = 30
         if picked_tool_names is None:
             assert picked_tool_names is not None or message is not None
             picked_tool_names = self.rag_infer(
@@ -270,7 +270,6 @@ class TxAgent:
             "tool_calls": json.dumps(function_call_json)
         }] + call_results
 
-        # Yield the final result.
         return revised_messages, existing_tools_prompt, special_tool_call
 
     def run_function_call_stream(self, fcall_str,
@@ -364,11 +363,10 @@ class TxAgent:
         else:
             return revised_messages, existing_tools_prompt, special_tool_call
 
-
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
-        if conversation[-1]['role'] == 'assisant':
+        if conversation[-1]['role'] == 'assistant':
             conversation.append(
-                {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
+                {'role': 'tool', 'content': 'Errors happened during the function call, please come up with the final answer with the current information.'})
         finish_tools_prompt = self.add_finish_tools([])
 
         last_outputs_str = self.llm_infer(messages=conversation,
@@ -387,15 +385,6 @@ class TxAgent:
                        max_round: int = 20,
                        call_agent=False,
                        call_agent_level=0) -> str:
-        """
-        Generate a streaming response using the llama3-8b model.
-        Args:
-            message (str): The input message.
-            temperature (float): The temperature for generating the response.
-            max_new_tokens (int): The maximum number of new tokens to generate.
-        Returns:
-            str: The generated response.
-        """
         print("\033[1;32;40mstart\033[0m")
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
@@ -454,7 +443,6 @@ class TxAgent:
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
                     if not good_status:
-                        next_round = False
                         print(
                             "Internal error in reasoning: " + wrong_info)
                         break
@@ -489,7 +477,6 @@ class TxAgent:
         return None
 
     def build_logits_processor(self, messages, llm):
-        # Use the tokenizer from the LLM instance.
         tokenizer = llm.get_tokenizer()
         if self.avoid_repeat and len(messages) > 2:
             assistant_messages = []
@@ -516,7 +503,6 @@ class TxAgent:
         sampling_params = SamplingParams(
             temperature=temperature,
             max_tokens=max_new_tokens,
-
             seed=seed if seed is not None else self.seed,
         )
 
@@ -527,18 +513,23 @@ class TxAgent:
 
         if check_token_status and max_token is not None:
             token_overflow = False
-            num_input_tokens = len(self.tokenizer.encode(
-                prompt, return_tensors="pt")[0])
-            if max_token is not None:
-                if num_input_tokens > max_token:
+            input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0]
+            num_input_tokens = len(input_tokens)
+            if num_input_tokens > max_token:
+                logger.info(f"Number of input tokens before inference: {num_input_tokens}")
+                logger.info("The number of tokens exceeds the maximum limit!!!!")
+                max_prompt_tokens = max_token - max_new_tokens - 100
+                if max_prompt_tokens > 0:
+                    truncated_input = self.tokenizer.decode(input_tokens[:max_prompt_tokens])
+                    prompt = truncated_input
+                    logger.info(f"Prompt truncated to {len(self.tokenizer.encode(prompt, return_tensors='pt')[0])} tokens")
+                    token_overflow = True
+                else:
+                    logger.warning("Max prompt tokens too small, cannot truncate effectively")
                 torch.cuda.empty_cache()
                 gc.collect()
-                print("Number of input tokens before inference:",
-                      num_input_tokens)
-                logger.info(
-                    "The number of tokens exceeds the maximum limit!!!!")
-                token_overflow = True
                 return None, token_overflow
+
         output = model.generate(
             prompt,
             sampling_params=sampling_params,
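
The truncation added above reserves room for generation with `max_prompt_tokens = max_token - max_new_tokens - 100`, where 100 is a fixed safety margin. A standalone sketch of the same rule, assuming a HuggingFace-style tokenizer with `encode`/`decode`; the helper name and `headroom` parameter are illustrative, not part of the commit:

```python
def truncate_prompt(tokenizer, prompt, max_token, max_new_tokens, headroom=100):
    """Return a prompt that fits the context budget, or None if it cannot."""
    input_ids = tokenizer.encode(prompt)
    if len(input_ids) <= max_token:
        return prompt  # already fits, nothing to do
    # Reserve space for the tokens we intend to generate, plus a safety margin.
    max_prompt_tokens = max_token - max_new_tokens - headroom
    if max_prompt_tokens <= 0:
        return None  # budget too small to truncate meaningfully; caller should bail out
    return tokenizer.decode(input_ids[:max_prompt_tokens])
```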
@@ -641,17 +632,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
         return output
 
     def function_result_summary(self, input_list, status, enable_summary):
-        """
-        Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
-        Supports 'length' and 'step' modes, and skips the last 'k' groups.
-        Parameters:
-            input_list (list): A list of dictionaries containing role and other information.
-            summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
-            summary_context_length (int): The context length threshold for the 'length' mode.
-            last_processed_index (tuple or int): The last processed index.
-        Returns:
-            list: A list of extracted information from valid sequences.
-        """
         if 'tool_call_step' not in status:
             status['tool_call_step'] = 0
 
@@ -748,15 +728,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
         return status
 
-    # Following are Gradio related functions
-
-    # General update method that accepts any new arguments through kwargs
     def update_parameters(self, **kwargs):
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
 
-        # Return the updated attributes
        updated_attributes = {key: value for key,
                              value in kwargs.items() if hasattr(self, key)}
        return updated_attributes
@@ -795,7 +771,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             return ""
 
         outputs = []
-        outputs_str = ''
         last_outputs = []
 
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
@@ -867,7 +842,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
                     if not good_status:
-                        print("Checker flagged reasoning error: ", wrong_info)
+                        logger.warning(f"Checker flagged reasoning error: {wrong_info}")
                         break
 
                 last_outputs = []
@@ -884,18 +859,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
 
                 if last_outputs_str is None:
-                    logger.warning("llm_infer returned None due to token overflow")
-                    if self.force_finish:
-                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                            conversation, temperature, max_new_tokens, max_token)
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                        return last_outputs_str
-                    else:
-                        error_msg = "Token limit exceeded. Please reduce input size or increase max_token."
-                        history.append(ChatMessage(role="assistant", content=error_msg))
-                        yield history
-                        return error_msg
+                    logger.warning("llm_infer returned None, likely due to token overflow")
+                    error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
+                    history.append(ChatMessage(role="assistant", content=error_msg))
+                    yield history
+                    return error_msg
 
                 last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
 
@@ -905,14 +873,12 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
                 if '[FinalAnswer]' in last_thought:
                     parts = last_thought.split('[FinalAnswer]', 1)
-                    if len(parts) == 2:
-                        final_thought, final_answer = parts
-                    else:
-                        final_thought, final_answer = last_thought, ""
+                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
                     history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                     yield history
                     history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                     yield history
+                    next_round = False
                 else:
                     history.append(ChatMessage(role="assistant", content=last_thought))
                     yield history
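
The `[FinalAnswer]` handling above collapses the four-line if/else into a conditional expression; with `maxsplit=1` the split yields at most two parts, so both cases are covered. A small illustration with made-up text:

```python
text = "step-by-step reasoning...[FinalAnswer] final recommendation"
parts = text.split('[FinalAnswer]', 1)
final_thought, final_answer = parts if len(parts) == 2 else (text, "")
# final_thought -> "step-by-step reasoning..."
# final_answer  -> " final recommendation"
```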
@@ -920,15 +886,13 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 last_outputs.append(last_outputs_str)
 
             if next_round:
+                logger.info("Max rounds reached, forcing finish")
                 if self.force_finish:
                     last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                         conversation, temperature, max_new_tokens, max_token)
                     if '[FinalAnswer]' in last_outputs_str:
                         parts = last_outputs_str.split('[FinalAnswer]', 1)
-                        if len(parts) == 2:
-                            final_thought, final_answer = parts
-                        else:
-                            final_thought, final_answer = last_outputs_str, ""
+                        final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                         history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                         yield history
                         history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
@@ -937,7 +901,10 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                     history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
                     yield history
                 else:
-                    yield "The number of reasoning rounds exceeded the limit."
+                    error_msg = "The number of reasoning rounds exceeded the limit."
+                    history.append(ChatMessage(role="assistant", content=error_msg))
+                    yield history
+                    return error_msg
 
         except Exception as e:
             logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
@@ -949,10 +916,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 conversation, temperature, max_new_tokens, max_token)
             if '[FinalAnswer]' in last_outputs_str:
                 parts = last_outputs_str.split('[FinalAnswer]', 1)
-                if len(parts) == 2:
-                    final_thought, final_answer = parts
-                else:
-                    final_thought, final_answer = last_outputs_str, ""
+                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                 history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                 yield history
                 history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))