Ali2206 committed
Commit 8228971 · verified · 1 parent: 0b3aa6e

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +132 -86
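
The diff below replaces ad-hoc print() calls with a module-level logger. The logger's definition sits outside the hunks shown, so the setup the new call sites assume is inferred here rather than quoted from the commit; a minimal sketch:

    import logging

    # Module-level logger for txagent.py (assumed; only the logger.* call
    # sites are visible in the hunks below, not the definition).
    logger = logging.getLogger(__name__)

The new call sites pass lazy %-style arguments, e.g. logger.debug("Round %d", current_round), so the string is only formatted when the record is actually emitted at the active log level.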
src/txagent/txagent.py CHANGED
@@ -67,42 +67,57 @@ class TxAgent:
67
  self.enable_checker = enable_checker
68
  self.additional_default_tools = additional_default_tools
69
  self.print_self_values()
 
70
 
71
  def init_model(self):
72
- self.load_models()
73
- self.load_tooluniverse()
74
- self.load_tool_desc_embedding()
 
75
 
76
  def print_self_values(self):
77
  for attr, value in self.__dict__.items():
78
- print(f"{attr}: {value}")
79
 
80
  def load_models(self, model_name=None):
81
  if model_name is not None:
82
  if model_name == self.model_name:
 
83
  return f"The model {model_name} is already loaded."
84
  self.model_name = model_name
85
 
 
86
  self.model = LLM(model=self.model_name)
87
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
88
  self.tokenizer = self.model.get_tokenizer()
89
-
90
  return f"Model {model_name} loaded successfully."
91
 
92
  def load_tooluniverse(self):
 
93
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
94
  self.tooluniverse.load_tools()
95
  special_tools = self.tooluniverse.prepare_tool_prompts(
96
  self.tooluniverse.tool_category_dicts["special_tools"])
97
  self.special_tools_name = [tool['name'] for tool in special_tools]
 
98
 
99
  def load_tool_desc_embedding(self):
 
100
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
 
101
 
102
  def rag_infer(self, query, top_k=5):
 
103
  return self.rag_model.rag_infer(query, top_k)
104
 
105
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
 
106
  picked_tools_prompt = []
107
  picked_tools_prompt = self.add_special_tools(
108
  picked_tools_prompt, call_agent=call_agent)
@@ -114,9 +129,11 @@ class TxAgent:
114
  if not call_agent:
115
  picked_tools_prompt += self.tool_RAG(
116
  message=message, rag_num=self.init_rag_num)
 
117
  return picked_tools_prompt, call_agent_level
118
 
119
  def initialize_conversation(self, message, conversation=None, history=None):
 
120
  if conversation is None:
121
  conversation = []
122
 
@@ -125,7 +142,7 @@ class TxAgent:
125
  if history is not None:
126
  if len(history) == 0:
127
  conversation = []
128
- print("clear conversation successfully")
129
  else:
130
  for i in range(len(history)):
131
  if history[i]['role'] == 'user':
@@ -139,7 +156,7 @@ class TxAgent:
139
  {"role": "assistant", "content": history[i]['content']})
140
 
141
  conversation.append({"role": "user", "content": message})
142
-
143
  return conversation
144
 
145
  def tool_RAG(self, message=None,
@@ -147,6 +164,7 @@ class TxAgent:
147
  existing_tools_prompt=[],
148
  rag_num=5,
149
  return_call_result=False):
 
150
  extra_factor = 30
151
  if picked_tool_names is None:
152
  assert picked_tool_names is not None or message is not None
@@ -165,39 +183,43 @@ class TxAgent:
165
  picked_tools)
166
  if return_call_result:
167
  return picked_tools_prompt, picked_tool_names
 
168
  return picked_tools_prompt
169
 
170
  def add_special_tools(self, tools, call_agent=False):
 
171
  if self.enable_finish:
172
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
173
  'Finish', return_prompt=True))
174
- print("Finish tool is added")
175
  if call_agent:
176
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
177
  'CallAgent', return_prompt=True))
178
- print("CallAgent tool is added")
179
  else:
180
  if self.enable_rag:
181
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
182
  'Tool_RAG', return_prompt=True))
183
- print("Tool_RAG tool is added")
184
 
185
  if self.additional_default_tools is not None:
186
  for each_tool_name in self.additional_default_tools:
187
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
188
  each_tool_name, return_prompt=True)
189
  if tool_prompt is not None:
190
- print(f"{each_tool_name} tool is added")
191
  tools.append(tool_prompt)
192
  return tools
193
 
194
  def add_finish_tools(self, tools):
 
195
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
196
  'Finish', return_prompt=True))
197
- print("Finish tool is added")
198
  return tools
199
 
200
  def set_system_prompt(self, conversation, sys_prompt):
 
201
  if len(conversation) == 0:
202
  conversation.append(
203
  {"role": "system", "content": sys_prompt})
@@ -213,6 +235,7 @@ class TxAgent:
213
  call_agent_level=None,
214
  temperature=None):
215
 
 
216
  function_call_json, message = self.tooluniverse.extract_function_call_json(
217
  fcall_str, return_message=return_message, verbose=False)
218
  call_results = []
@@ -220,7 +243,7 @@ class TxAgent:
220
  if function_call_json is not None:
221
  if isinstance(function_call_json, list):
222
  for i in range(len(function_call_json)):
223
- print("\033[94mTool Call:\033[0m", function_call_json[i])
224
  if function_call_json[i]["name"] == 'Finish':
225
  special_tool_call = 'Finish'
226
  break
@@ -255,7 +278,7 @@ class TxAgent:
255
 
256
  call_id = self.tooluniverse.call_id_gen()
257
  function_call_json[i]["call_id"] = call_id
258
- print("\033[94mTool Call Result:\033[0m", call_result)
259
  call_results.append({
260
  "role": "tool",
261
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
@@ -271,7 +294,7 @@ class TxAgent:
271
  "content": message.strip(),
272
  "tool_calls": json.dumps(function_call_json)
273
  }] + call_results
274
-
275
  return revised_messages, existing_tools_prompt, special_tool_call
276
 
277
  def run_function_call_stream(self, fcall_str,
@@ -283,7 +306,7 @@ class TxAgent:
283
  temperature=None,
284
  return_gradio_history=True):
285
 
286
- logger.debug(f"Running function call stream with input: {fcall_str[:100]}...")
287
  function_call_json, message = self.tooluniverse.extract_function_call_json(
288
  fcall_str, return_message=return_message, verbose=False)
289
  call_results = []
@@ -303,7 +326,7 @@ class TxAgent:
303
 
304
  if isinstance(function_call_json, list):
305
  for i in range(len(function_call_json)):
306
- logger.debug(f"Processing tool call: {function_call_json[i]}")
307
  if function_call_json[i]["name"] == 'Finish':
308
  special_tool_call = 'Finish'
309
  break
@@ -361,12 +384,14 @@ class TxAgent:
361
  }] + call_results
362
 
363
  if return_gradio_history:
364
- logger.debug(f"Yielding gradio history with {len(gradio_history)} entries")
365
  yield revised_messages, existing_tools_prompt or [], special_tool_call, gradio_history
366
  else:
367
  yield revised_messages, existing_tools_prompt or [], special_tool_call
 
368
 
369
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
 
370
  if conversation[-1]['role'] == 'assistant':
371
  conversation.append(
372
  {'role': 'tool', 'content': 'Errors occurred; provide a detailed final answer based on current information.'})
@@ -378,7 +403,7 @@ class TxAgent:
378
  output_begin_string='[FinalAnswer]',
379
  skip_special_tokens=True,
380
  max_new_tokens=max_new_tokens, max_token=max_token)
381
- logger.debug(f"Forced finish output: {last_outputs_str[:100]}...")
382
  return last_outputs_str
383
 
384
  def run_multistep_agent(self, message: str,
@@ -388,7 +413,7 @@ class TxAgent:
388
  max_round: int = 20,
389
  call_agent=False,
390
  call_agent_level=0) -> str:
391
- print("\033[1;32;40mstart\033[0m")
392
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
393
  call_agent, call_agent_level, message)
394
  conversation = self.initialize_conversation(message)
@@ -407,6 +432,7 @@ class TxAgent:
407
  try:
408
  while next_round and current_round < max_round:
409
  current_round += 1
 
410
  if len(outputs) > 0:
411
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
412
  last_outputs, return_message=True,
@@ -424,12 +450,12 @@ class TxAgent:
424
  function_call_messages[0]['content'])
425
  content = function_call_messages[0]['content']
426
  if content is None:
 
427
  return "❌ No content returned after Finish tool call."
 
428
  return content.split('[FinalAnswer]')[-1]
429
 
430
  if (self.enable_summary or token_overflow) and not call_agent:
431
- if token_overflow:
432
- print("token_overflow, using summary")
433
  enable_summary = True
434
  last_status = self.function_result_summary(
435
  conversation, status=last_status, enable_summary=enable_summary)
@@ -440,14 +466,14 @@ class TxAgent:
440
  function_call_messages))
441
  else:
442
  next_round = False
443
- conversation.extend(
444
- [{"role": "assistant", "content": ''.join(last_outputs)}])
445
- return ''.join(last_outputs).replace("</s>", "")
 
446
  if self.enable_checker:
447
  good_status, wrong_info = checker.check_conversation()
448
  if not good_status:
449
- print(
450
- "Internal error in reasoning: " + wrong_info)
451
  break
452
  last_outputs = []
453
  outputs.append("### TxAgent:\n")
@@ -458,7 +484,7 @@ class TxAgent:
458
  max_new_tokens=max_new_tokens, max_token=max_token,
459
  check_token_status=True)
460
  if last_outputs_str is None:
461
- print("The number of tokens exceeds the maximum limit.")
462
  if self.force_finish:
463
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
464
  else:
@@ -466,20 +492,22 @@ class TxAgent:
466
  else:
467
  last_outputs.append(last_outputs_str)
468
  if max_round == current_round:
469
- print("The number of rounds exceeds the maximum limit!")
470
  if self.force_finish:
471
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
472
  else:
 
473
  return None
474
 
475
  except Exception as e:
476
- print(f"Error: {e}")
477
  if self.force_finish:
478
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
479
  else:
480
  return None
481
 
482
  def build_logits_processor(self, messages, llm):
 
483
  tokenizer = llm.get_tokenizer()
484
  if self.avoid_repeat and len(messages) > 2:
485
  assistant_messages = []
@@ -491,14 +519,14 @@ class TxAgent:
491
  forbidden_ids = [tokenizer.encode(
492
  msg, add_special_tokens=False) for msg in assistant_messages]
493
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
494
- else:
495
- return None
496
 
497
  def llm_infer(self, messages, temperature=0.1, tools=None,
498
  output_begin_string=None, max_new_tokens=2048,
499
  max_token=None, skip_special_tokens=True,
500
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
501
 
 
502
  if model is None:
503
  model = self.model
504
 
@@ -519,16 +547,15 @@ class TxAgent:
519
  input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0]
520
  num_input_tokens = len(input_tokens)
521
  if num_input_tokens > max_token:
522
- logger.info(f"Number of input tokens before inference: {num_input_tokens}")
523
- logger.info("The number of tokens exceeds the maximum limit!!!!")
524
  max_prompt_tokens = max_token - max_new_tokens - 100
525
  if max_prompt_tokens > 0:
526
  truncated_input = self.tokenizer.decode(input_tokens[:max_prompt_tokens])
527
  prompt = truncated_input
528
- logger.info(f"Prompt truncated to {len(self.tokenizer.encode(prompt, return_tensors='pt')[0])} tokens")
529
  token_overflow = True
530
  else:
531
- logger.warning("Max prompt tokens too small, cannot truncate effectively")
532
  torch.cuda.empty_cache()
533
  gc.collect()
534
  return None, token_overflow
@@ -538,7 +565,7 @@ class TxAgent:
538
  sampling_params=sampling_params,
539
  )
540
  output = output[0].outputs[0].text
541
- print("\033[92m" + output + "\033[0m")
542
  if check_token_status and max_token is not None:
543
  return output, token_overflow
544
 
@@ -549,7 +576,7 @@ class TxAgent:
549
  max_new_tokens: int,
550
  max_token: int) -> str:
551
 
552
- print("\033[1;32;40mstart self agent\033[0m")
553
  conversation = []
554
  conversation = self.set_system_prompt(conversation, self.self_prompt)
555
  conversation.append({"role": "user", "content": message})
@@ -563,7 +590,7 @@ class TxAgent:
563
  max_new_tokens: int,
564
  max_token: int) -> str:
565
 
566
- print("\033[1;32;40mstart chat agent\033[0m")
567
  conversation = []
568
  conversation = self.set_system_prompt(conversation, self.chat_prompt)
569
  conversation.append({"role": "user", "content": message})
@@ -578,7 +605,7 @@ class TxAgent:
578
  max_new_tokens: int,
579
  max_token: int) -> str:
580
 
581
- print("\033[1;32;40mstart format agent\033[0m")
582
  if '[FinalAnswer]' in answer:
583
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
584
  elif "\n\n" in answer:
@@ -588,12 +615,13 @@ class TxAgent:
588
  if len(possible_final_answer) == 1:
589
  choice = possible_final_answer[0]
590
  if choice in ['A', 'B', 'C', 'D', 'E']:
 
591
  return choice
592
  elif len(possible_final_answer) > 1:
593
  if possible_final_answer[1] == ':':
594
  choice = possible_final_answer[0]
595
  if choice in ['A', 'B', 'C', 'D', 'E']:
596
- print("choice", choice)
597
  return choice
598
 
599
  conversation = []
@@ -611,7 +639,7 @@ class TxAgent:
611
  temperature: float,
612
  max_new_tokens: int,
613
  max_token: int) -> str:
614
- print("\033[1;32;40mSummarized Tool Result:\033[0m")
615
  generate_tool_result_summary_training_prompt = """Thought and function calls:
616
  {thought_calls}
617
  Function calls' responses:
@@ -632,9 +660,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
632
 
633
  if '[' in output:
634
  output = output.split('[')[0]
 
635
  return output
636
 
637
  def function_result_summary(self, input_list, status, enable_summary):
 
638
  if 'tool_call_step' not in status:
639
  status['tool_call_step'] = 0
640
 
@@ -682,13 +712,14 @@ Generate **one summarized sentence** about "function calls' responses" with nece
682
  this_thought_calls = None
683
  else:
684
  if len(function_response) != 0:
685
- print("internal summary")
686
  status['summarized_step'] += 1
687
  result_summary = self.run_summary_agent(
688
  thought_calls=this_thought_calls,
689
  function_response=function_response,
690
  temperature=0.1,
691
  max_new_tokens=1024,
 
692
  max_token=99999
693
  )
694
 
@@ -729,15 +760,18 @@ Generate **one summarized sentence** about "function calls' responses" with nece
729
  last_call_idx+1, {'role': 'tool', 'content': result_summary})
730
  status['summarized_index'] = last_call_idx + 2
731
 
 
732
  return status
733
 
734
  def update_parameters(self, **kwargs):
 
735
  for key, value in kwargs.items():
736
  if hasattr(self, key):
737
  setattr(self, key, value)
738
 
739
  updated_attributes = {key: value for key,
740
  value in kwargs.items() if hasattr(self, key)}
 
741
  return updated_attributes
742
 
743
  def run_gradio_chat(self, message: str,
@@ -763,45 +797,49 @@ Generate **one summarized sentence** about "function calls' responses" with nece
763
  Returns:
764
  str: Final assistant message.
765
  """
766
- logger.debug(f"[TxAgent] Chat started, message: {message[:100]}...")
767
- print("\033[1;32;40m[TxAgent] Chat started\033[0m")
768
-
769
- if not message or len(message.strip()) < 5:
770
- yield "Please provide a valid message or upload files to analyze."
771
- return "Invalid input."
772
-
773
- if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
774
- return ""
775
-
776
- outputs = []
777
- last_outputs = []
778
 
779
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
780
- call_agent,
781
- call_agent_level,
782
- message)
783
 
784
- conversation = self.initialize_conversation(
785
- message,
786
- conversation=conversation,
787
- history=history)
788
- history = [] # Reset history to avoid duplication
 
789
 
790
- next_round = True
791
- function_call_messages = []
792
- current_round = 0
793
- enable_summary = False
794
- last_status = {}
795
- token_overflow = False
 
796
 
797
- if self.enable_checker:
798
- checker = ReasoningTraceChecker(
799
- message, conversation, init_index=len(conversation))
800
 
801
- try:
802
  while next_round and current_round < max_round:
803
  current_round += 1
804
- logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
805
 
806
  if last_outputs:
807
  function_call_result = yield from self.run_function_call_stream(
@@ -812,13 +850,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
812
  call_agent_level=call_agent_level,
813
  temperature=temperature)
814
 
815
- # Ensure function_call_result is valid
816
  if not function_call_result:
817
  logger.warning("Empty result from run_function_call_stream")
818
- error_msg = "Error: Tool call processing failed."
819
- history.append(ChatMessage(role="assistant", content=error_msg))
820
  yield history
821
- return error_msg
822
 
823
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
824
 
@@ -830,9 +866,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
830
  unique_history.append(msg)
831
  seen_contents.add(msg.content)
832
  history.extend(unique_history)
 
833
 
834
  if special_tool_call == 'Finish' and function_call_messages:
835
  history.append(ChatMessage(role="assistant", content=function_call_messages[0]['content']))
 
836
  yield history
837
  next_round = False
838
  conversation.extend(function_call_messages)
@@ -841,6 +879,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
841
  elif special_tool_call in ['RequireClarification', 'DirectResponse']:
842
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
843
  history.append(ChatMessage(role="assistant", content=last_msg.content))
 
844
  yield history
845
  next_round = False
846
  return last_msg.content
@@ -849,8 +888,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
849
  enable_summary = True
850
 
851
  last_status = self.function_result_summary(
852
- conversation, status=last_status,
853
- enable_summary=enable_summary)
854
 
855
  if function_call_messages:
856
  conversation.extend(function_call_messages)
@@ -859,13 +897,14 @@ Generate **one summarized sentence** about "function calls' responses" with nece
859
  content = ''.join(last_outputs).replace("</s>", "")
860
  history.append(ChatMessage(role="assistant", content=content))
861
  conversation.append({"role": "assistant", "content": content})
 
862
  yield history
863
  return content
864
 
865
  if self.enable_checker:
866
  good_status, wrong_info = checker.check_conversation()
867
  if not good_status:
868
- logger.warning(f"Checker flagged reasoning error: {wrong_info}")
869
  break
870
 
871
  last_outputs = []
@@ -879,10 +918,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
879
  seed=seed,
880
  check_token_status=True)
881
 
882
- logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
 
883
 
884
  if last_outputs_str is None:
885
- logger.warning("llm_infer returned None, likely due to token overflow")
886
  error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
887
  history.append(ChatMessage(role="assistant", content=error_msg))
888
  yield history
@@ -899,16 +939,18 @@ Generate **one summarized sentence** about "function calls' responses" with nece
899
  final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
900
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
901
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
 
902
  yield history
903
  next_round = False
904
  else:
905
  history.append(ChatMessage(role="assistant", content=last_thought))
 
906
  yield history
907
 
908
  last_outputs.append(last_outputs_str)
909
 
910
  if next_round:
911
- logger.info("Max rounds reached, forcing finish")
912
  if self.force_finish:
913
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
914
  conversation, temperature, max_new_tokens, max_token)
@@ -919,17 +961,20 @@ Generate **one summarized sentence** about "function calls' responses" with nece
919
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
920
  else:
921
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
 
922
  yield history
923
  else:
924
  error_msg = "The number of reasoning rounds exceeded the limit."
925
  history.append(ChatMessage(role="assistant", content=error_msg))
 
926
  yield history
927
  return error_msg
928
 
929
  except Exception as e:
930
- logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
931
  error_msg = f"An error occurred: {e}"
932
  history.append(ChatMessage(role="assistant", content=error_msg))
 
933
  yield history
934
  if self.force_finish:
935
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
@@ -941,5 +986,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
941
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
942
  else:
943
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
 
944
  yield history
945
  return error_msg
 
67
  self.enable_checker = enable_checker
68
  self.additional_default_tools = additional_default_tools
69
  self.print_self_values()
70
+ logger.info("TxAgent initialized with model_name=%s, rag_model_name=%s", model_name, rag_model_name)
71
 
72
  def init_model(self):
73
+ logger.info("Initializing model: %s", self.model_name)
74
+ try:
75
+ self.load_models()
76
+ self.load_tooluniverse()
77
+ self.load_tool_desc_embedding()
78
+ logger.info("Model initialization complete")
79
+ except Exception as e:
80
+ logger.error("Failed to initialize model: %s", e, exc_info=True)
81
+ raise
82
 
83
  def print_self_values(self):
84
  for attr, value in self.__dict__.items():
85
+ logger.debug("%s: %s", attr, value)
86
 
87
  def load_models(self, model_name=None):
88
  if model_name is not None:
89
  if model_name == self.model_name:
90
+ logger.debug("Model %s already loaded", model_name)
91
  return f"The model {model_name} is already loaded."
92
  self.model_name = model_name
93
 
94
+ logger.debug("Loading model %s", self.model_name)
95
  self.model = LLM(model=self.model_name)
96
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
97
  self.tokenizer = self.model.get_tokenizer()
98
+ logger.info("Model %s loaded successfully", self.model_name)
99
  return f"Model {model_name} loaded successfully."
100
 
101
  def load_tooluniverse(self):
102
+ logger.debug("Loading tool universe")
103
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
104
  self.tooluniverse.load_tools()
105
  special_tools = self.tooluniverse.prepare_tool_prompts(
106
  self.tooluniverse.tool_category_dicts["special_tools"])
107
  self.special_tools_name = [tool['name'] for tool in special_tools]
108
+ logger.debug("Tool universe loaded with %d special tools", len(self.special_tools_name))
109
 
110
  def load_tool_desc_embedding(self):
111
+ logger.debug("Loading tool description embeddings")
112
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
113
+ logger.debug("Tool description embeddings loaded")
114
 
115
  def rag_infer(self, query, top_k=5):
116
+ logger.debug("Running RAG inference with query: %s", query[:50])
117
  return self.rag_model.rag_infer(query, top_k)
118
 
119
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
120
+ logger.debug("Initializing tools prompt, call_agent=%s, level=%d", call_agent, call_agent_level)
121
  picked_tools_prompt = []
122
  picked_tools_prompt = self.add_special_tools(
123
  picked_tools_prompt, call_agent=call_agent)
 
129
  if not call_agent:
130
  picked_tools_prompt += self.tool_RAG(
131
  message=message, rag_num=self.init_rag_num)
132
+ logger.debug("Tools prompt initialized with %d tools", len(picked_tools_prompt))
133
  return picked_tools_prompt, call_agent_level
134
 
135
  def initialize_conversation(self, message, conversation=None, history=None):
136
+ logger.debug("Initializing conversation with message: %s", message[:50])
137
  if conversation is None:
138
  conversation = []
139
 
 
142
  if history is not None:
143
  if len(history) == 0:
144
  conversation = []
145
+ logger.debug("Cleared conversation")
146
  else:
147
  for i in range(len(history)):
148
  if history[i]['role'] == 'user':
 
156
  {"role": "assistant", "content": history[i]['content']})
157
 
158
  conversation.append({"role": "user", "content": message})
159
+ logger.debug("Conversation initialized with %d messages", len(conversation))
160
  return conversation
161
 
162
  def tool_RAG(self, message=None,
 
164
  existing_tools_prompt=[],
165
  rag_num=5,
166
  return_call_result=False):
167
+ logger.debug("Running tool RAG, message=%s, rag_num=%d", message[:50] if message else None, rag_num)
168
  extra_factor = 30
169
  if picked_tool_names is None:
170
  assert picked_tool_names is not None or message is not None
 
183
  picked_tools)
184
  if return_call_result:
185
  return picked_tools_prompt, picked_tool_names
186
+ logger.debug("Tool RAG returned %d tools", len(picked_tools_prompt))
187
  return picked_tools_prompt
188
 
189
  def add_special_tools(self, tools, call_agent=False):
190
+ logger.debug("Adding special tools, call_agent=%s", call_agent)
191
  if self.enable_finish:
192
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
193
  'Finish', return_prompt=True))
194
+ logger.debug("Finish tool added")
195
  if call_agent:
196
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
197
  'CallAgent', return_prompt=True))
198
+ logger.debug("CallAgent tool added")
199
  else:
200
  if self.enable_rag:
201
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
202
  'Tool_RAG', return_prompt=True))
203
+ logger.debug("Tool_RAG tool added")
204
 
205
  if self.additional_default_tools is not None:
206
  for each_tool_name in self.additional_default_tools:
207
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
208
  each_tool_name, return_prompt=True)
209
  if tool_prompt is not None:
210
+ logger.debug("%s tool added", each_tool_name)
211
  tools.append(tool_prompt)
212
  return tools
213
 
214
  def add_finish_tools(self, tools):
215
+ logger.debug("Adding finish tools")
216
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
217
  'Finish', return_prompt=True))
218
+ logger.debug("Finish tool added")
219
  return tools
220
 
221
  def set_system_prompt(self, conversation, sys_prompt):
222
+ logger.debug("Setting system prompt")
223
  if len(conversation) == 0:
224
  conversation.append(
225
  {"role": "system", "content": sys_prompt})
 
235
  call_agent_level=None,
236
  temperature=None):
237
 
238
+ logger.debug("Running function call with input: %s", fcall_str[:50])
239
  function_call_json, message = self.tooluniverse.extract_function_call_json(
240
  fcall_str, return_message=return_message, verbose=False)
241
  call_results = []
 
243
  if function_call_json is not None:
244
  if isinstance(function_call_json, list):
245
  for i in range(len(function_call_json)):
246
+ logger.debug("Tool Call: %s", function_call_json[i])
247
  if function_call_json[i]["name"] == 'Finish':
248
  special_tool_call = 'Finish'
249
  break
 
278
 
279
  call_id = self.tooluniverse.call_id_gen()
280
  function_call_json[i]["call_id"] = call_id
281
+ logger.debug("Tool Call Result: %s", call_result)
282
  call_results.append({
283
  "role": "tool",
284
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
 
294
  "content": message.strip(),
295
  "tool_calls": json.dumps(function_call_json)
296
  }] + call_results
297
+ logger.debug("Function call completed, returning %d messages", len(revised_messages))
298
  return revised_messages, existing_tools_prompt, special_tool_call
299
 
300
  def run_function_call_stream(self, fcall_str,
 
306
  temperature=None,
307
  return_gradio_history=True):
308
 
309
+ logger.debug("Running function call stream with input: %s", fcall_str[:50])
310
  function_call_json, message = self.tooluniverse.extract_function_call_json(
311
  fcall_str, return_message=return_message, verbose=False)
312
  call_results = []
 
326
 
327
  if isinstance(function_call_json, list):
328
  for i in range(len(function_call_json)):
329
+ logger.debug("Processing tool call: %s", function_call_json[i])
330
  if function_call_json[i]["name"] == 'Finish':
331
  special_tool_call = 'Finish'
332
  break
 
384
  }] + call_results
385
 
386
  if return_gradio_history:
387
+ logger.debug("Yielding gradio history with %d entries", len(gradio_history))
388
  yield revised_messages, existing_tools_prompt or [], special_tool_call, gradio_history
389
  else:
390
  yield revised_messages, existing_tools_prompt or [], special_tool_call
391
+ logger.debug("Function call stream completed")
392
 
393
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
394
+ logger.debug("Forcing answer due to unfinished reasoning")
395
  if conversation[-1]['role'] == 'assistant':
396
  conversation.append(
397
  {'role': 'tool', 'content': 'Errors occurred; provide a detailed final answer based on current information.'})
 
403
  output_begin_string='[FinalAnswer]',
404
  skip_special_tokens=True,
405
  max_new_tokens=max_new_tokens, max_token=max_token)
406
+ logger.debug("Forced finish output: %s", last_outputs_str[:100])
407
  return last_outputs_str
408
 
409
  def run_multistep_agent(self, message: str,
 
413
  max_round: int = 20,
414
  call_agent=False,
415
  call_agent_level=0) -> str:
416
+ logger.info("Starting multistep agent with message: %s", message[:50])
417
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
418
  call_agent, call_agent_level, message)
419
  conversation = self.initialize_conversation(message)
 
432
  try:
433
  while next_round and current_round < max_round:
434
  current_round += 1
435
+ logger.debug("Round %d", current_round)
436
  if len(outputs) > 0:
437
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
438
  last_outputs, return_message=True,
 
450
  function_call_messages[0]['content'])
451
  content = function_call_messages[0]['content']
452
  if content is None:
453
+ logger.warning("No content after Finish tool call")
454
  return "❌ No content returned after Finish tool call."
455
+ logger.debug("Returning final content: %s", content[:50])
456
  return content.split('[FinalAnswer]')[-1]
457
 
458
  if (self.enable_summary or token_overflow) and not call_agent:
 
 
459
  enable_summary = True
460
  last_status = self.function_result_summary(
461
  conversation, status=last_status, enable_summary=enable_summary)
 
466
  function_call_messages))
467
  else:
468
  next_round = False
469
+ content = ''.join(last_outputs).replace("</s>", "")
470
+ logger.debug("Returning content: %s", content[:50])
471
+ return content
472
+
473
  if self.enable_checker:
474
  good_status, wrong_info = checker.check_conversation()
475
  if not good_status:
476
+ logger.warning("Internal error in reasoning: %s", wrong_info)
 
477
  break
478
  last_outputs = []
479
  outputs.append("### TxAgent:\n")
 
484
  max_new_tokens=max_new_tokens, max_token=max_token,
485
  check_token_status=True)
486
  if last_outputs_str is None:
487
+ logger.warning("Token limit exceeded")
488
  if self.force_finish:
489
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
490
  else:
 
492
  else:
493
  last_outputs.append(last_outputs_str)
494
  if max_round == current_round:
495
+ logger.warning("Max rounds exceeded")
496
  if self.force_finish:
497
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
498
  else:
499
+ logger.debug("No output due to max rounds")
500
  return None
501
 
502
  except Exception as e:
503
+ logger.error("Error in multistep agent: %s", e, exc_info=True)
504
  if self.force_finish:
505
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
506
  else:
507
  return None
508
 
509
  def build_logits_processor(self, messages, llm):
510
+ logger.debug("Building logits processor")
511
  tokenizer = llm.get_tokenizer()
512
  if self.avoid_repeat and len(messages) > 2:
513
  assistant_messages = []
 
519
  forbidden_ids = [tokenizer.encode(
520
  msg, add_special_tokens=False) for msg in assistant_messages]
521
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
522
+ return None
 
523
 
524
  def llm_infer(self, messages, temperature=0.1, tools=None,
525
  output_begin_string=None, max_new_tokens=2048,
526
  max_token=None, skip_special_tokens=True,
527
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
528
 
529
+ logger.debug("Running LLM inference with %d messages", len(messages))
530
  if model is None:
531
  model = self.model
532
 
 
547
  input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0]
548
  num_input_tokens = len(input_tokens)
549
  if num_input_tokens > max_token:
550
+ logger.info("Input tokens: %d, max_token: %d", num_input_tokens, max_token)
 
551
  max_prompt_tokens = max_token - max_new_tokens - 100
552
  if max_prompt_tokens > 0:
553
  truncated_input = self.tokenizer.decode(input_tokens[:max_prompt_tokens])
554
  prompt = truncated_input
555
+ logger.info("Truncated to %d tokens", len(self.tokenizer.encode(prompt, return_tensors='pt')[0]))
556
  token_overflow = True
557
  else:
558
+ logger.warning("Cannot truncate effectively")
559
  torch.cuda.empty_cache()
560
  gc.collect()
561
  return None, token_overflow
 
565
  sampling_params=sampling_params,
566
  )
567
  output = output[0].outputs[0].text
568
+ logger.debug("LLM output: %s", output[:50])
569
  if check_token_status and max_token is not None:
570
  return output, token_overflow
571
 
 
576
  max_new_tokens: int,
577
  max_token: int) -> str:
578
 
579
+ logger.info("Starting self agent with message: %s", message[:50])
580
  conversation = []
581
  conversation = self.set_system_prompt(conversation, self.self_prompt)
582
  conversation.append({"role": "user", "content": message})
 
590
  max_new_tokens: int,
591
  max_token: int) -> str:
592
 
593
+ logger.info("Starting chat agent with message: %s", message[:50])
594
  conversation = []
595
  conversation = self.set_system_prompt(conversation, self.chat_prompt)
596
  conversation.append({"role": "user", "content": message})
 
605
  max_new_tokens: int,
606
  max_token: int) -> str:
607
 
608
+ logger.info("Starting format agent")
609
  if '[FinalAnswer]' in answer:
610
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
611
  elif "\n\n" in answer:
 
615
  if len(possible_final_answer) == 1:
616
  choice = possible_final_answer[0]
617
  if choice in ['A', 'B', 'C', 'D', 'E']:
618
+ logger.debug("Returning choice: %s", choice)
619
  return choice
620
  elif len(possible_final_answer) > 1:
621
  if possible_final_answer[1] == ':':
622
  choice = possible_final_answer[0]
623
  if choice in ['A', 'B', 'C', 'D', 'E']:
624
+ logger.debug("Returning choice: %s", choice)
625
  return choice
626
 
627
  conversation = []
 
639
  temperature: float,
640
  max_new_tokens: int,
641
  max_token: int) -> str:
642
+ logger.info("Running summary agent")
643
  generate_tool_result_summary_training_prompt = """Thought and function calls:
644
  {thought_calls}
645
  Function calls' responses:
 
660
 
661
  if '[' in output:
662
  output = output.split('[')[0]
663
+ logger.debug("Summary output: %s", output)
664
  return output
665
 
666
  def function_result_summary(self, input_list, status, enable_summary):
667
+ logger.debug("Running function result summary, enable_summary=%s", enable_summary)
668
  if 'tool_call_step' not in status:
669
  status['tool_call_step'] = 0
670
 
 
712
  this_thought_calls = None
713
  else:
714
  if len(function_response) != 0:
715
+ logger.debug("Generating internal summary")
716
  status['summarized_step'] += 1
717
  result_summary = self.run_summary_agent(
718
  thought_calls=this_thought_calls,
719
  function_response=function_response,
720
  temperature=0.1,
721
  max_new_tokens=1024,
723
  max_token=99999
724
  )
725
 
 
760
  last_call_idx+1, {'role': 'tool', 'content': result_summary})
761
  status['summarized_index'] = last_call_idx + 2
762
 
763
+ logger.debug("Function result summary completed")
764
  return status
765
 
766
  def update_parameters(self, **kwargs):
767
+ logger.debug("Updating parameters: %s", kwargs)
768
  for key, value in kwargs.items():
769
  if hasattr(self, key):
770
  setattr(self, key, value)
771
 
772
  updated_attributes = {key: value for key,
773
  value in kwargs.items() if hasattr(self, key)}
774
+ logger.debug("Updated attributes: %s", updated_attributes)
775
  return updated_attributes
776
 
777
  def run_gradio_chat(self, message: str,
 
797
  Returns:
798
  str: Final assistant message.
799
  """
800
+ logger.info("[TxAgent] Chat started with message: %s", message[:100])
801
+ logger.debug("Initial history: %s", [msg.content[:50] for msg in history] if history else [])
 
802
 
803
+ # Yield initial message to ensure UI updates
804
+ history.append(ChatMessage(role="assistant", content="Starting analysis..."))
805
+ yield history
806
+ logger.debug("Yielded initial history")
807
 
808
+ try:
809
+ if not message or len(message.strip()) < 5:
810
+ logger.warning("Invalid message detected")
811
+ history.append(ChatMessage(role="assistant", content="Please provide a valid message or upload files to analyze."))
812
+ yield history
813
+ return "Invalid input."
814
 
815
+ if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
816
+ logger.debug("Skipping tool-related message")
817
+ yield history
818
+ return ""
819
+
820
+ outputs = []
821
+ last_outputs = []
822
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
823
+ call_agent, call_agent_level, message)
824
+ conversation = self.initialize_conversation(
825
+ message, conversation=conversation, history=history)
826
+ history = [] # Reset history to avoid duplication
827
+ logger.debug("Conversation initialized with %d messages", len(conversation))
828
+
829
+ next_round = True
830
+ function_call_messages = []
831
+ current_round = 0
832
+ enable_summary = False
833
+ last_status = {}
834
+ token_overflow = False
835
 
836
+ if self.enable_checker:
837
+ checker = ReasoningTraceChecker(
838
+ message, conversation, init_index=len(conversation))
839
 
 
840
  while next_round and current_round < max_round:
841
  current_round += 1
842
+ logger.debug("Round %d, conversation length: %d", current_round, len(conversation))
843
 
844
  if last_outputs:
845
  function_call_result = yield from self.run_function_call_stream(
 
850
  call_agent_level=call_agent_level,
851
  temperature=temperature)
852
 
 
853
  if not function_call_result:
854
  logger.warning("Empty result from run_function_call_stream")
855
+ history.append(ChatMessage(role="assistant", content="Error: Tool call processing failed."))
 
856
  yield history
857
+ return "Error: Tool call processing failed."
858
 
859
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
860
 
 
866
  unique_history.append(msg)
867
  seen_contents.add(msg.content)
868
  history.extend(unique_history)
869
+ logger.debug("Extended history with %d unique messages", len(unique_history))
870
 
871
  if special_tool_call == 'Finish' and function_call_messages:
872
  history.append(ChatMessage(role="assistant", content=function_call_messages[0]['content']))
873
+ logger.debug("Yielding final history after Finish: %s", function_call_messages[0]['content'][:50])
874
  yield history
875
  next_round = False
876
  conversation.extend(function_call_messages)
 
879
  elif special_tool_call in ['RequireClarification', 'DirectResponse']:
880
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
881
  history.append(ChatMessage(role="assistant", content=last_msg.content))
882
+ logger.debug("Yielding history for special tool: %s", last_msg.content[:50])
883
  yield history
884
  next_round = False
885
  return last_msg.content
 
888
  enable_summary = True
889
 
890
  last_status = self.function_result_summary(
891
+ conversation, status=last_status, enable_summary=enable_summary)
 
892
 
893
  if function_call_messages:
894
  conversation.extend(function_call_messages)
 
897
  content = ''.join(last_outputs).replace("</s>", "")
898
  history.append(ChatMessage(role="assistant", content=content))
899
  conversation.append({"role": "assistant", "content": content})
900
+ logger.debug("Yielding history with content: %s", content[:50])
901
  yield history
902
  return content
903
 
904
  if self.enable_checker:
905
  good_status, wrong_info = checker.check_conversation()
906
  if not good_status:
907
+ logger.warning("Checker flagged error: %s", wrong_info)
908
  break
909
 
910
  last_outputs = []
 
918
  seed=seed,
919
  check_token_status=True)
920
 
921
+ logger.debug("llm_infer output: %s, token_overflow: %s",
922
+ last_outputs_str[:50] if last_outputs_str else None, token_overflow)
923
 
924
  if last_outputs_str is None:
925
+ logger.warning("llm_infer returned None")
926
  error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
927
  history.append(ChatMessage(role="assistant", content=error_msg))
928
  yield history
 
939
  final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
940
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
941
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
942
+ logger.debug("Yielding final analysis: %s", final_answer[:50])
943
  yield history
944
  next_round = False
945
  else:
946
  history.append(ChatMessage(role="assistant", content=last_thought))
947
+ logger.debug("Yielding intermediate history: %s", last_thought[:50])
948
  yield history
949
 
950
  last_outputs.append(last_outputs_str)
951
 
952
  if next_round:
953
+ logger.info("Max rounds reached")
954
  if self.force_finish:
955
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
956
  conversation, temperature, max_new_tokens, max_token)
 
961
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
962
  else:
963
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
964
+ logger.debug("Yielding forced final history")
965
  yield history
966
  else:
967
  error_msg = "The number of reasoning rounds exceeded the limit."
968
  history.append(ChatMessage(role="assistant", content=error_msg))
969
+ logger.debug("Yielding max rounds error")
970
  yield history
971
  return error_msg
972
 
973
  except Exception as e:
974
+ logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
975
  error_msg = f"An error occurred: {e}"
976
  history.append(ChatMessage(role="assistant", content=error_msg))
977
+ logger.debug("Yielding error history: %s", error_msg)
978
  yield history
979
  if self.force_finish:
980
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
 
986
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
987
  else:
988
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
989
+ logger.debug("Yielding forced final history after error")
990
  yield history
991
  return error_msg
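
The truncation branch in llm_infer above reserves headroom for generation before trimming the prompt. A standalone sketch of the same pattern, assuming a Hugging Face-style tokenizer with encode/decode; the function name and the 100-token margin default are illustrative, mirroring the hunk rather than quoting it:

    def truncate_prompt(tokenizer, prompt, max_token, max_new_tokens, margin=100):
        """Trim prompt so prompt plus generation fits within max_token tokens.

        Returns (prompt_or_None, token_overflow), as llm_infer does above.
        """
        input_ids = tokenizer.encode(prompt)
        if len(input_ids) <= max_token:
            return prompt, False
        # Leave room for the model's output plus a safety margin.
        max_prompt_tokens = max_token - max_new_tokens - margin
        if max_prompt_tokens <= 0:
            # Cannot truncate effectively; the caller returns None and cleans up.
            return None, True
        return tokenizer.decode(input_ids[:max_prompt_tokens]), True

For example, with max_token=4096 and max_new_tokens=2048, a prompt longer than 4096 tokens is cut to 1948 tokens (4096 - 2048 - 100).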