Ali2206 committed (verified)
Commit c4e7e4a · 1 Parent(s): b0114f5

Update src/txagent/txagent.py

Files changed (1):
  1. src/txagent/txagent.py  +397 -579
src/txagent/txagent.py CHANGED
@@ -1,4 +1,3 @@
1
- import gradio as gr
2
  import os
3
  import sys
4
  import json
@@ -6,13 +5,13 @@ import gc
6
  import numpy as np
7
  from vllm import LLM, SamplingParams
8
  from jinja2 import Template
9
- from typing import List
10
  import types
11
  from tooluniverse import ToolUniverse
12
- from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
15
  import logging
 
16
 
17
  # Configure logging with a more specific logger name
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -21,28 +20,50 @@ logger = logging.getLogger("TxAgent")
21
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
22
 
23
  class TxAgent:
24
- def __init__(self, model_name,
25
- rag_model_name,
26
- tool_files_dict=None,
27
- enable_finish=True,
28
- enable_rag=False,
29
- enable_summary=False,
30
- init_rag_num=0,
31
- step_rag_num=0,
32
- summary_mode='step',
33
- summary_skip_last_k=0,
34
- summary_context_length=None,
35
- force_finish=True,
36
- avoid_repeat=True,
37
- seed=None,
38
- enable_checker=False,
39
- enable_chat=False,
40
- additional_default_tools=None):
41
  self.model_name = model_name
42
  self.tokenizer = None
43
- self.terminators = None
44
  self.rag_model_name = rag_model_name
45
- self.tool_files_dict = tool_files_dict
46
  self.model = None
47
  self.rag_model = ToolRAGModel(rag_model_name)
48
  self.tooluniverse = None
@@ -61,106 +82,166 @@ class TxAgent:
61
  self.avoid_repeat = avoid_repeat
62
  self.seed = seed
63
  self.enable_checker = enable_checker
64
- self.additional_default_tools = additional_default_tools
65
  logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
66
 
67
- def init_model(self):
 
68
  self.load_models()
69
  self.load_tooluniverse()
 
70
 
71
- def load_models(self, model_name=None):
72
  if model_name is not None:
73
  if model_name == self.model_name:
74
  return f"The model {model_name} is already loaded."
75
  self.model_name = model_name
76
 
77
- self.model = LLM(
78
- model=self.model_name,
79
- dtype="float16",
80
- max_model_len=131072,
81
- max_num_batched_tokens=65536, # Increased for A100 80GB
82
- max_num_seqs=512,
83
- gpu_memory_utilization=0.95, # Higher utilization for better performance
84
- trust_remote_code=True,
85
- )
86
- self.chat_template = Template(self.model.get_tokenizer().chat_template)
87
- self.tokenizer = self.model.get_tokenizer()
88
- logger.info(
89
- "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
90
- self.model_name, 131072, 32768, 0.9
91
- )
92
- return f"Model {model_name} loaded successfully."
 
 
 
 
93
 
94
- def load_tooluniverse(self):
95
- self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
96
- self.tooluniverse.load_tools()
97
- special_tools = self.tooluniverse.prepare_tool_prompts(
98
- self.tooluniverse.tool_category_dicts["special_tools"])
99
- self.special_tools_name = [tool['name'] for tool in special_tools]
100
- logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
 
 
 
 
 
101
 
102
- def load_tool_desc_embedding(self):
 
103
  cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
104
- if os.path.exists(cache_path):
105
- self.rag_model.load_cached_embeddings(cache_path)
106
- else:
107
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
108
- self.rag_model.save_embeddings(cache_path)
109
- logger.debug("Tool description embeddings loaded")
 
 
 
 
110
 
111
- def rag_infer(self, query, top_k=5):
112
  return self.rag_model.rag_infer(query, top_k)
113
 
114
- def initialize_tools_prompt(self, call_agent, call_agent_level, message):
115
- picked_tools_prompt = []
116
- picked_tools_prompt = self.add_special_tools(
117
- picked_tools_prompt, call_agent=call_agent)
118
- if call_agent:
119
- call_agent_level += 1
120
- if call_agent_level >= 2:
121
- call_agent = False
122
- return picked_tools_prompt, call_agent_level
123
-
124
- def initialize_conversation(self, message, conversation=None, history=None):
125
  if conversation is None:
126
  conversation = []
127
 
128
- conversation = self.set_system_prompt(
129
- conversation, self.prompt_multi_step)
130
  if history:
131
- for i in range(len(history)):
132
- if history[i]['role'] == 'user':
133
- conversation.append({"role": "user", "content": history[i]['content']})
134
- elif history[i]['role'] == 'assistant':
135
- conversation.append({"role": "assistant", "content": history[i]['content']})
136
  conversation.append({"role": "user", "content": message})
137
  logger.debug("Conversation initialized with %d messages", len(conversation))
138
  return conversation
139
 
140
- def tool_RAG(self, message=None,
141
- picked_tool_names=None,
142
- existing_tools_prompt=[],
143
- rag_num=0,
144
- return_call_result=False):
145
  if not self.enable_rag:
146
- return []
 
147
  extra_factor = 10
148
  if picked_tool_names is None:
149
- assert picked_tool_names is not None or message is not None
150
- picked_tool_names = self.rag_infer(
151
- message, top_k=rag_num * extra_factor)
152
-
153
- picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
154
  picked_tool_names = picked_tool_names_no_special[:rag_num]
155
 
156
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
158
  logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
 
159
  if return_call_result:
160
  return picked_tools_prompt, picked_tool_names
161
  return picked_tools_prompt
162
 
163
- def add_special_tools(self, tools, call_agent=False):
 
164
  if self.enable_finish:
165
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
166
  logger.debug("Finish tool added")
@@ -169,25 +250,37 @@ class TxAgent:
169
  logger.debug("CallAgent tool added")
170
  return tools
171
 
172
- def add_finish_tools(self, tools):
173
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
174
- logger.debug("Finish tool added")
175
- return tools
176
-
177
- def set_system_prompt(self, conversation, sys_prompt):
178
  if not conversation:
179
  conversation.append({"role": "system", "content": sys_prompt})
180
  else:
181
  conversation[0] = {"role": "system", "content": sys_prompt}
182
  return conversation
183
 
184
- def run_function_call(self, fcall_str,
185
- return_message=False,
186
- existing_tools_prompt=None,
187
- message_for_call_agent=None,
188
- call_agent=False,
189
- call_agent_level=None,
190
- temperature=None):
191
  try:
192
  function_call_json, message = self.tooluniverse.extract_function_call_json(
193
  fcall_str, return_message=return_message, verbose=False)
@@ -206,13 +299,12 @@ class TxAgent:
206
  special_tool_call = 'Finish'
207
  break
208
  elif function_call_json[i]["name"] == 'CallAgent':
209
- if call_agent_level < 2 and call_agent:
210
  solution_plan = function_call_json[i]['arguments']['solution']
211
  full_message = (
212
- message_for_call_agent +
213
  "\nYou must follow the following plan to answer the question: " +
214
  str(solution_plan)
215
- )
216
  call_result = self.run_multistep_agent(
217
  full_message, temperature=temperature,
218
  max_new_tokens=512, max_token=131072,
@@ -230,7 +322,11 @@ class TxAgent:
230
  logger.info("Tool Call Result: %s", call_result)
231
  call_results.append({
232
  "role": "tool",
233
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
 
 
 
 
234
  })
235
  else:
236
  call_results.append({
@@ -243,109 +339,99 @@ class TxAgent:
243
  "content": message.strip(),
244
  "tool_calls": json.dumps(function_call_json)
245
  }] + call_results
246
- return revised_messages, existing_tools_prompt, special_tool_call
247
-
248
- def run_function_call_stream(self, fcall_str,
249
- return_message=False,
250
- existing_tools_prompt=None,
251
- message_for_call_agent=None,
252
- call_agent=False,
253
- call_agent_level=None,
254
- temperature=None,
255
- return_gradio_history=True):
256
- try:
257
- function_call_json, message = self.tooluniverse.extract_function_call_json(
258
- fcall_str, return_message=return_message, verbose=False)
259
- except Exception as e:
260
- logger.error("Tool call parsing failed: %s", e)
261
- function_call_json = []
262
- message = fcall_str
263
 
264
- call_results = []
265
- special_tool_call = ''
266
- if return_gradio_history:
267
- gradio_history = []
268
- if function_call_json:
269
- if isinstance(function_call_json, list):
270
- for i in range(len(function_call_json)):
271
- if function_call_json[i]["name"] == 'Finish':
272
- special_tool_call = 'Finish'
273
- break
274
- elif function_call_json[i]["name"] == 'DirectResponse':
275
- call_result = function_call_json[i]['arguments']['respose']
276
- special_tool_call = 'DirectResponse'
277
- elif function_call_json[i]["name"] == 'RequireClarification':
278
- call_result = function_call_json[i]['arguments']['unclear_question']
279
- special_tool_call = 'RequireClarification'
280
- elif function_call_json[i]["name"] == 'CallAgent':
281
- if call_agent_level < 2 and call_agent:
282
- solution_plan = function_call_json[i]['arguments']['solution']
283
- full_message = (
284
- message_for_call_agent +
285
- "\nYou must follow the following plan to answer the question: " +
286
- str(solution_plan)
287
- )
288
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
289
- call_result = yield from self.run_gradio_chat(
290
- full_message, history=[], temperature=temperature,
291
- max_new_tokens=512, max_token=131072,
292
- call_agent=False, call_agent_level=call_agent_level,
293
- conversation=None, sub_agent_task=sub_agent_task)
294
- if call_result is not None and isinstance(call_result, str):
295
- call_result = call_result.split('[FinalAnswer]')[-1]
296
- else:
297
- call_result = "⚠️ No content returned from sub-agent."
298
- else:
299
- call_result = "Error: CallAgent disabled."
300
- else:
301
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
302
- call_id = self.tooluniverse.call_id_gen()
303
- function_call_json[i]["call_id"] = call_id
304
- call_results.append({
305
- "role": "tool",
306
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
307
- })
308
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
309
- metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
310
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
311
- else:
312
- call_results.append({
313
- "role": "tool",
314
- "content": json.dumps({"content": "Invalid or no function call detected."})
315
- })
316
 
317
- revised_messages = [{
318
- "role": "assistant",
319
- "content": message.strip(),
320
- "tool_calls": json.dumps(function_call_json)
321
- }] + call_results
322
- if return_gradio_history:
323
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
324
- return revised_messages, existing_tools_prompt, special_tool_call
325
 
326
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
327
- if conversation[-1]['role'] == 'assistant':
328
- conversation.append(
329
- {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
330
- finish_tools_prompt = self.add_finish_tools([])
331
- last_outputs_str = self.llm_infer(
332
- messages=conversation,
333
- temperature=temperature,
334
- tools=finish_tools_prompt,
335
- output_begin_string='[FinalAnswer]',
336
- skip_special_tokens=True,
337
- max_new_tokens=max_new_tokens,
338
- max_token=max_token)
339
- logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
340
- return last_outputs_str
341
 
342
- def run_multistep_agent(self, message: str,
343
- temperature: float,
344
- max_new_tokens: int,
345
- max_token: int,
346
- max_round: int = 5,
347
- call_agent=False,
348
- call_agent_level=0):
349
  logger.info("Starting multistep agent for message: %s", message[:100])
350
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
351
  call_agent, call_agent_level, message)
@@ -415,196 +501,109 @@ class TxAgent:
415
  conversation, temperature, max_new_tokens, max_token)
416
  return None
417
 
418
- def build_logits_processor(self, messages, llm):
419
- logger.warning("Logits processor disabled due to vLLM V1 limitation")
420
- return None
421
-
422
- def llm_infer(self, messages, temperature=0.1, tools=None,
423
- output_begin_string=None, max_new_tokens=512,
424
- max_token=131072, skip_special_tokens=True,
425
- model=None, tokenizer=None, terminators=None,
426
- seed=None, check_token_status=False):
427
- if model is None:
428
- model = self.model
429
-
430
- logits_processor = self.build_logits_processor(messages, model)
431
- sampling_params = SamplingParams(
432
- temperature=temperature,
433
- max_tokens=max_new_tokens,
434
- seed=seed if seed is not None else self.seed,
435
- )
436
-
437
- prompt = self.chat_template.render(
438
- messages=messages, tools=tools, add_generation_prompt=True)
439
- if output_begin_string is not None:
440
- prompt += output_begin_string
441
-
442
- if check_token_status and max_token is not None:
443
- token_overflow = False
444
- num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
445
- logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
446
- if num_input_tokens > max_token:
447
- torch.cuda.empty_cache()
448
- gc.collect()
449
- logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
450
- return None, True
451
 
452
- output = model.generate(prompt, sampling_params=sampling_params)
453
- output_text = output[0].outputs[0].text
454
- output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
455
- logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
456
- torch.cuda.empty_cache()
457
- gc.collect()
458
- if check_token_status and max_token is not None:
459
- return output_text, token_overflow
460
- return output_text
461
-
462
- def run_self_agent(self, message: str,
463
- temperature: float,
464
- max_new_tokens: int,
465
- max_token: int):
466
- logger.info("Starting self agent")
467
- conversation = self.set_system_prompt([], self.self_prompt)
468
- conversation.append({"role": "user", "content": message})
469
- return self.llm_infer(
470
- messages=conversation,
471
- temperature=temperature,
472
- tools=None,
473
- max_new_tokens=max_new_tokens,
474
- max_token=max_token)
475
 
476
- def run_chat_agent(self, message: str,
477
- temperature: float,
478
- max_new_tokens: int,
479
- max_token: int):
480
- logger.info("Starting chat agent")
481
- conversation = self.set_system_prompt([], self.chat_prompt)
482
- conversation.append({"role": "user", "content": message})
483
- return self.llm_infer(
484
- messages=conversation,
485
- temperature=temperature,
486
- tools=None,
487
- max_new_tokens=max_new_tokens,
488
- max_token=max_token)
489
 
490
- def run_format_agent(self, message: str,
491
- answer: str,
492
- temperature: float,
493
- max_new_tokens: int,
494
- max_token: int):
495
- logger.info("Starting format agent")
496
- if '[FinalAnswer]' in answer:
497
- possible_final_answer = answer.split("[FinalAnswer]")[-1]
498
- elif "\n\n" in answer:
499
- possible_final_answer = answer.split("\n\n")[-1]
500
- else:
501
- possible_final_answer = answer.strip()
502
- if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
503
- return possible_final_answer
504
- elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
505
- return possible_final_answer[0]
506
-
507
- conversation = self.set_system_prompt(
508
- [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.")
509
- conversation.append({"role": "user", "content": message +
510
- "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
511
- return self.llm_infer(
512
  messages=conversation,
513
  temperature=temperature,
514
- tools=None,
 
 
515
  max_new_tokens=max_new_tokens,
516
  max_token=max_token)
 
 
517
 
518
- def run_summary_agent(self, thought_calls: str,
519
- function_response: str,
520
- temperature: float,
521
- max_new_tokens: int,
522
- max_token: int):
523
- logger.info("Summarizing tool result")
524
- prompt = f"""Thought and function calls:
525
- {thought_calls}
526
- Function calls' responses:
527
- \"\"\"
528
- {function_response}
529
- \"\"\"
530
- Summarize the function calls' l responses in one sentence with all necessary information.
531
- """
532
- conversation = [{"role": "user", "content": prompt}]
533
- output = self.llm_infer(
534
- messages=conversation,
535
- temperature=temperature,
536
- tools=None,
537
- max_new_tokens=max_new_tokens,
538
- max_token=max_token)
539
- if '[' in output:
540
- output = output.split('[')[0]
541
- return output
542
-
543
- def function_result_summary(self, input_list, status, enable_summary):
544
- if 'tool_call_step' not in status:
545
- status['tool_call_step'] = 0
546
- for idx in range(len(input_list)):
547
- pos_id = len(input_list) - idx - 1
548
- if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
549
- break
550
-
551
- status['step'] = status.get('step', 0) + 1
552
- if not enable_summary:
553
- return status
554
-
555
- status['summarized_index'] = status.get('summarized_index', 0)
556
- status['summarized_step'] = status.get('summarized_step', 0)
557
- status['previous_length'] = status.get('previous_length', 0)
558
- status['history'] = status.get('history', [])
559
-
560
- function_response = ''
561
- idx = status['summarized_index']
562
- this_thought_calls = None
563
-
564
- while idx < len(input_list):
565
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
566
- (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
567
- if input_list[idx]['role'] == 'assistant':
568
- if function_response:
569
- status['summarized_step'] += 1
570
- result_summary = self.run_summary_agent(
571
- thought_calls=this_thought_calls,
572
- function_response=function_response,
573
- temperature=0.1,
574
- max_new_tokens=512,
575
- max_token=131072)
576
- input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
577
- status['summarized_index'] = last_call_idx + 2
578
- idx += 1
579
- last_call_idx = idx
580
- this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
581
- function_response = ''
582
- elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
583
- function_response += input_list[idx]['content']
584
- del input_list[idx]
585
- idx -= 1
586
- else:
587
- break
588
- idx += 1
589
-
590
- if function_response:
591
- status['summarized_step'] += 1
592
- result_summary = self.run_summary_agent(
593
- thought_calls=this_thought_calls,
594
- function_response=function_response,
595
- temperature=0.1,
596
- max_new_tokens=512,
597
- max_token=131072)
598
- tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
599
- for tool_call in tool_calls:
600
- del tool_call['call_id']
601
- input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
602
- input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
603
- status['summarized_index'] = last_call_idx + 2
604
-
605
- return status
606
-
607
- def update_parameters(self, **kwargs):
608
  updated_attributes = {}
609
  for key, value in kwargs.items():
610
  if hasattr(self, key):
@@ -613,199 +612,18 @@ Summarize the function calls' l responses in one sentence with all necessary inf
613
  logger.info("Updated parameters: %s", updated_attributes)
614
  return updated_attributes
615
 
616
- def run_gradio_chat(self, message: str,
617
- history: list,
618
- temperature: float,
619
- max_new_tokens: int = 2048,
620
- max_token: int = 131072,
621
- call_agent: bool = False,
622
- conversation: gr.State = None,
623
- max_round: int = 5,
624
- seed: int = None,
625
- call_agent_level: int = 0,
626
- sub_agent_task: str = None,
627
- uploaded_files: list = None):
628
- logger.info("Chat started, message: %s", message[:100])
629
- if not message or len(message.strip()) < 5:
630
- yield "Please provide a valid message or upload files to analyze."
631
- return
632
-
633
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
634
- call_agent, call_agent_level, message)
635
- conversation = self.initialize_conversation(
636
- message, conversation, history)
637
- history = []
638
- last_outputs = []
639
-
640
- next_round = True
641
- current_round = 0
642
- enable_summary = False
643
- last_status = {}
644
- token_overflow = False
645
-
646
- try:
647
- while next_round and current_round < max_round:
648
- current_round += 1
649
- logger.debug("Starting round %d/%d", current_round, max_round)
650
- if last_outputs:
651
- function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
652
- last_outputs, return_message=True,
653
- existing_tools_prompt=picked_tools_prompt,
654
- message_for_call_agent=message,
655
- call_agent=call_agent,
656
- call_agent_level=call_agent_level,
657
- temperature=temperature)
658
- history.extend(current_gradio_history)
659
-
660
- if special_tool_call == 'Finish':
661
- logger.info("Finish tool called, ending chat")
662
- yield history
663
- next_round = False
664
- conversation.extend(function_call_messages)
665
- content = function_call_messages[0]['content']
666
- if content:
667
- return content
668
- return "No content returned after Finish tool call."
669
-
670
- elif special_tool_call in ['RequireClarification', 'DirectResponse']:
671
- last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
672
- history.append(ChatMessage(role="assistant", content=last_msg.content))
673
- logger.info("Special tool %s called, ending chat", special_tool_call)
674
- yield history
675
- next_round = False
676
- return last_msg.content
677
-
678
- if (self.enable_summary or token_overflow) and not call_agent:
679
- enable_summary = True
680
- last_status = self.function_result_summary(
681
- conversation, status=last_status, enable_summary=enable_summary)
682
-
683
- if function_call_messages:
684
- conversation.extend(function_call_messages)
685
- yield history
686
- else:
687
- next_round = False
688
- conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
689
- logger.info("No function call messages, ending chat")
690
- return ''.join(last_outputs).replace("</s>", "")
691
-
692
- last_outputs = []
693
- last_outputs_str, token_overflow = self.llm_infer(
694
- messages=conversation,
695
- temperature=temperature,
696
- tools=picked_tools_prompt,
697
- skip_special_tokens=False,
698
- max_new_tokens=max_new_tokens,
699
- max_token=max_token,
700
- seed=seed,
701
- check_token_status=True)
702
-
703
- if last_outputs_str is None:
704
- logger.warning("Token limit exceeded")
705
- if self.force_finish:
706
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
707
- conversation, temperature, max_new_tokens, max_token)
708
- history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
709
- yield history
710
- return last_outputs_str
711
- error_msg = "Token limit exceeded."
712
- history.append(ChatMessage(role="assistant", content=error_msg))
713
- yield history
714
- return error_msg
715
-
716
- last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
717
- for msg in history:
718
- if msg.metadata is not None:
719
- msg.metadata['status'] = 'done'
720
-
721
- if '[FinalAnswer]' in last_thought:
722
- parts = last_thought.split('[FinalAnswer]', 1)
723
- final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
724
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
725
- yield history
726
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
727
- logger.info("Final answer provided: %s", final_answer[:100])
728
- yield history
729
- next_round = False # Ensure we exit after final answer
730
- return final_answer
731
- else:
732
- history.append(ChatMessage(role="assistant", content=last_thought))
733
- yield history
734
-
735
- last_outputs.append(last_outputs_str)
736
-
737
- if next_round:
738
- if self.force_finish:
739
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
740
- conversation, temperature, max_new_tokens, max_token)
741
- parts = last_outputs_str.split('[FinalAnswer]', 1)
742
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
743
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
744
- yield history
745
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
746
- logger.info("Forced final answer: %s", final_answer[:100])
747
- yield history
748
- return final_answer
749
- else:
750
- error_msg = "Reasoning rounds exceeded limit."
751
- history.append(ChatMessage(role="assistant", content=error_msg))
752
- yield history
753
- return error_msg
754
 
755
- except Exception as e:
756
- logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
757
- error_msg = f"Error: {e}"
758
- history.append(ChatMessage(role="assistant", content=error_msg))
759
- yield history
760
- if self.force_finish:
761
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
762
- conversation, temperature, max_new_tokens, max_token)
763
- parts = last_outputs_str.split('[FinalAnswer]', 1)
764
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
765
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
766
- yield history
767
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
768
- logger.info("Forced final answer after error: %s", final_answer[:100])
769
- yield history
770
- return final_answer
771
- return error_msg
772
-
773
- def run_gradio_chat_batch(self, messages: List[str],
774
- temperature: float,
775
- max_new_tokens: int = 2048,
776
- max_token: int = 131072,
777
- call_agent: bool = False,
778
- conversation: List = None,
779
- max_round: int = 5,
780
- seed: int = None,
781
- call_agent_level: int = 0):
782
- """Run batch inference for multiple messages."""
783
- logger.info("Starting batch chat for %d messages", len(messages))
784
- batch_results = []
785
-
786
- for message in messages:
787
- # Initialize conversation for each message
788
- conv = self.initialize_conversation(message, conversation, history=None)
789
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
790
- call_agent, call_agent_level, message)
791
-
792
- # Run single inference for simplicity (extend for multi-round if needed)
793
- output, token_overflow = self.llm_infer(
794
- messages=conv,
795
- temperature=temperature,
796
- tools=picked_tools_prompt,
797
- max_new_tokens=max_new_tokens,
798
- max_token=max_token,
799
- skip_special_tokens=False,
800
- seed=seed,
801
- check_token_status=True
802
- )
803
-
804
- if output is None:
805
- logger.warning("Token limit exceeded for message: %s", message[:100])
806
- batch_results.append("Token limit exceeded.")
807
- else:
808
- batch_results.append(output)
809
-
810
- logger.info("Batch chat completed for %d messages", len(messages))
811
- return batch_results
 
 
1
  import os
2
  import sys
3
  import json
 
5
  import numpy as np
6
  from vllm import LLM, SamplingParams
7
  from jinja2 import Template
8
+ from typing import List, Dict, Optional, Tuple, Union, Generator
9
  import types
10
  from tooluniverse import ToolUniverse
 
11
  from .toolrag import ToolRAGModel
12
  import torch
13
  import logging
14
+ import time
+ from datetime import datetime
15
 
16
  # Configure logging with a more specific logger name
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
22
  class TxAgent:
23
+ def __init__(self,
24
+ model_name: str,
25
+ rag_model_name: str,
26
+ tool_files_dict: Optional[Dict] = None,
27
+ enable_finish: bool = True,
28
+ enable_rag: bool = False,
29
+ enable_summary: bool = False,
30
+ init_rag_num: int = 0,
31
+ step_rag_num: int = 0,
32
+ summary_mode: str = 'step',
33
+ summary_skip_last_k: int = 0,
34
+ summary_context_length: Optional[int] = None,
35
+ force_finish: bool = True,
36
+ avoid_repeat: bool = True,
37
+ seed: Optional[int] = None,
38
+ enable_checker: bool = False,
39
+ enable_chat: bool = False,
40
+ additional_default_tools: Optional[List] = None):
41
+ """
42
+ Initialize the TxAgent with specified configuration.
43
+
44
+ Args:
45
+ model_name: Name of the main LLM model
46
+ rag_model_name: Name of the RAG model
47
+ tool_files_dict: Dictionary of tool files
48
+ enable_finish: Whether to enable the Finish tool
49
+ enable_rag: Whether to enable RAG functionality
50
+ enable_summary: Whether to enable summarization
51
+ init_rag_num: Initial number of RAG tools to retrieve
52
+ step_rag_num: Number of RAG tools to retrieve per step
53
+ summary_mode: Mode for summarization ('step' or 'length')
54
+ summary_skip_last_k: Number of last steps to skip in summarization
55
+ summary_context_length: Context length threshold for summarization
56
+ force_finish: Whether to force finish when max rounds reached
57
+ avoid_repeat: Whether to avoid repeating similar responses
58
+ seed: Random seed for reproducibility
59
+ enable_checker: Whether to enable reasoning trace checker
60
+ enable_chat: Whether to enable chat mode
61
+ additional_default_tools: Additional tools to include by default
62
+ """
63
  self.model_name = model_name
64
  self.tokenizer = None
 
65
  self.rag_model_name = rag_model_name
66
+ self.tool_files_dict = tool_files_dict or {}
67
  self.model = None
68
  self.rag_model = ToolRAGModel(rag_model_name)
69
  self.tooluniverse = None
 
82
  self.avoid_repeat = avoid_repeat
83
  self.seed = seed
84
  self.enable_checker = enable_checker
85
+ self.additional_default_tools = additional_default_tools or []
86
  logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
87
 
88
+ def init_model(self) -> None:
89
+ """Initialize both the main model and tool universe."""
90
  self.load_models()
91
  self.load_tooluniverse()
92
+ logger.info("Model and tools initialized successfully")
93
 
94
+ def load_models(self, model_name: Optional[str] = None) -> str:
95
+ """
96
+ Load the specified model or the default model if none specified.
97
+
98
+ Args:
99
+ model_name: Name of the model to load
100
+
101
+ Returns:
102
+ Status message indicating success or failure
103
+ """
104
  if model_name is not None:
105
  if model_name == self.model_name:
106
  return f"The model {model_name} is already loaded."
107
  self.model_name = model_name
108
 
109
+ try:
110
+ self.model = LLM(
111
+ model=self.model_name,
112
+ dtype="float16",
113
+ max_model_len=131072,
114
+ max_num_batched_tokens=65536,
115
+ max_num_seqs=512,
116
+ gpu_memory_utilization=0.95,
117
+ trust_remote_code=True,
118
+ )
119
+ self.tokenizer = self.model.get_tokenizer()
120
+ self.chat_template = Template(self.tokenizer.chat_template)
121
+ logger.info(
122
+ "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d",
123
+ self.model_name, 131072, 65536
124
+ )
125
+ return f"Model {model_name} loaded successfully."
126
+ except Exception as e:
127
+ logger.error("Failed to load model: %s", str(e))
128
+ raise RuntimeError(f"Failed to load model: {str(e)}")
129
 
130
+ def load_tooluniverse(self) -> None:
131
+ """Load and initialize the tool universe with specified tools."""
132
+ try:
133
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
134
+ self.tooluniverse.load_tools()
135
+ special_tools = self.tooluniverse.prepare_tool_prompts(
136
+ self.tooluniverse.tool_category_dicts["special_tools"])
137
+ self.special_tools_name = [tool['name'] for tool in special_tools]
138
+ logger.info("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
139
+ except Exception as e:
140
+ logger.error("Failed to load tools: %s", str(e))
141
+ raise RuntimeError(f"Failed to load tools: {str(e)}")
142
 
143
+ def load_tool_desc_embedding(self) -> None:
144
+ """Load tool description embeddings from cache or generate new ones."""
145
  cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
146
+ try:
147
+ if os.path.exists(cache_path):
148
+ self.rag_model.load_cached_embeddings(cache_path)
149
+ else:
150
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
151
+ self.rag_model.save_embeddings(cache_path)
152
+ logger.info("Tool description embeddings loaded successfully")
153
+ except Exception as e:
154
+ logger.error("Failed to load tool embeddings: %s", str(e))
155
+ raise RuntimeError(f"Failed to load tool embeddings: {str(e)}")
156
 
157
+ def rag_infer(self, query: str, top_k: int = 5) -> List[str]:
158
+ """
159
+ Perform RAG inference to retrieve relevant tools.
160
+
161
+ Args:
162
+ query: The query to search for
163
+ top_k: Number of top results to return
164
+
165
+ Returns:
166
+ List of relevant tool names
167
+ """
168
+ if not self.enable_rag:
169
+ return []
170
  return self.rag_model.rag_infer(query, top_k)
171
 
172
+ def initialize_conversation(self,
173
+ message: str,
174
+ conversation: Optional[List[Dict]] = None,
175
+ history: Optional[List[Dict]] = None) -> List[Dict]:
176
+ """
177
+ Initialize or extend a conversation with the given message and history.
178
+
179
+ Args:
180
+ message: The new message to add
181
+ conversation: Existing conversation to extend
182
+ history: Chat history to incorporate
183
+
184
+ Returns:
185
+ Updated conversation list
186
+ """
187
  if conversation is None:
188
  conversation = []
189
 
190
+ conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
 
191
  if history:
192
+ for msg in history:
193
+ if msg['role'] == 'user':
194
+ conversation.append({"role": "user", "content": msg['content']})
195
+ elif msg['role'] == 'assistant':
196
+ conversation.append({"role": "assistant", "content": msg['content']})
197
  conversation.append({"role": "user", "content": message})
198
  logger.debug("Conversation initialized with %d messages", len(conversation))
199
  return conversation
200
 
201
+ def tool_RAG(self,
202
+ message: Optional[str] = None,
203
+ picked_tool_names: Optional[List[str]] = None,
204
+ existing_tools_prompt: List = [],
205
+ rag_num: int = 0,
206
+ return_call_result: bool = False) -> Union[List, Tuple[List, List]]:
207
+ """
208
+ Retrieve relevant tools using RAG.
209
+
210
+ Args:
211
+ message: The query message for RAG
212
+ picked_tool_names: Pre-selected tool names
213
+ existing_tools_prompt: Existing tools to include
214
+ rag_num: Number of tools to retrieve
215
+ return_call_result: Whether to return tool names
216
+
217
+ Returns:
218
+ List of tool prompts or tuple with tool names if return_call_result is True
219
+ """
220
  if not self.enable_rag:
221
+ return [] if not return_call_result else ([], [])
222
+
223
  extra_factor = 10
224
  if picked_tool_names is None:
225
+ if message is None:
226
+ raise ValueError("Either message or picked_tool_names must be provided")
227
+ picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
228
+
229
+ picked_tool_names_no_special = [
230
+ tool for tool in picked_tool_names
231
+ if tool not in self.special_tools_name
232
+ ]
233
  picked_tool_names = picked_tool_names_no_special[:rag_num]
234
 
235
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
236
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
237
  logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
238
+
239
  if return_call_result:
240
  return picked_tools_prompt, picked_tool_names
241
  return picked_tools_prompt
242
 
243
+ def add_special_tools(self, tools: List, call_agent: bool = False) -> List:
244
+ """Add special tools (Finish and optionally CallAgent) to the tools list."""
245
  if self.enable_finish:
246
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
247
  logger.debug("Finish tool added")
 
250
  logger.debug("CallAgent tool added")
251
  return tools
252
 
253
+ def set_system_prompt(self, conversation: List[Dict], sys_prompt: str) -> List[Dict]:
254
+ """Set or update the system prompt in the conversation."""
 
 
 
 
255
  if not conversation:
256
  conversation.append({"role": "system", "content": sys_prompt})
257
  else:
258
  conversation[0] = {"role": "system", "content": sys_prompt}
259
  return conversation
260
 
261
+ def run_function_call(self,
262
+ fcall_str: str,
263
+ return_message: bool = False,
264
+ existing_tools_prompt: Optional[List] = None,
265
+ message_for_call_agent: Optional[str] = None,
266
+ call_agent: bool = False,
267
+ call_agent_level: Optional[int] = None,
268
+ temperature: Optional[float] = None) -> Tuple[List[Dict], List, str]:
269
+ """
270
+ Execute function calls from the model's output.
271
+
272
+ Args:
273
+ fcall_str: The function call string from the model
274
+ return_message: Whether to return the message part
275
+ existing_tools_prompt: Existing tools to consider
276
+ message_for_call_agent: Original message for CallAgent
277
+ call_agent: Whether CallAgent is enabled
278
+ call_agent_level: Current CallAgent level
279
+ temperature: Temperature for sub-agent calls
280
+
281
+ Returns:
282
+ Tuple of (revised_messages, tools_prompt, special_tool_call)
283
+ """
284
  try:
285
  function_call_json, message = self.tooluniverse.extract_function_call_json(
286
  fcall_str, return_message=return_message, verbose=False)
 
299
  special_tool_call = 'Finish'
300
  break
301
  elif function_call_json[i]["name"] == 'CallAgent':
302
+ if call_agent_level is not None and call_agent_level < 2 and call_agent:
303
  solution_plan = function_call_json[i]['arguments']['solution']
304
  full_message = (
305
+ (message_for_call_agent or "") +
306
  "\nYou must follow the following plan to answer the question: " +
307
  str(solution_plan)
 
308
  call_result = self.run_multistep_agent(
309
  full_message, temperature=temperature,
310
  max_new_tokens=512, max_token=131072,
 
322
  logger.info("Tool Call Result: %s", call_result)
323
  call_results.append({
324
  "role": "tool",
325
+ "content": json.dumps({
326
+ "tool_name": function_call_json[i]["name"],
327
+ "content": call_result,
328
+ "call_id": call_id
329
+ })
330
  })
331
  else:
332
  call_results.append({
 
339
  "content": message.strip(),
340
  "tool_calls": json.dumps(function_call_json)
341
  }] + call_results
342
+ return revised_messages, existing_tools_prompt or [], special_tool_call
343
+
344
+ def llm_infer(self,
345
+ messages: List[Dict],
346
+ temperature: float = 0.1,
347
+ tools: Optional[List] = None,
348
+ output_begin_string: Optional[str] = None,
349
+ max_new_tokens: int = 512,
350
+ max_token: int = 131072,
351
+ skip_special_tokens: bool = True,
352
+ model: Optional[LLM] = None,
353
+ check_token_status: bool = False) -> Union[str, Tuple[str, bool]]:
354
+ """
355
+ Perform inference using the LLM.
356
+
357
+ Args:
358
+ messages: Conversation history
359
+ temperature: Sampling temperature
360
+ tools: List of tools to include
361
+ output_begin_string: Prefix for output
362
+ max_new_tokens: Maximum new tokens to generate
363
+ max_token: Maximum total tokens allowed
364
+ skip_special_tokens: Whether to skip special tokens
365
+ model: Optional custom model to use
366
+ check_token_status: Whether to check token limits
367
+
368
+ Returns:
369
+ Generated text or tuple with text and overflow flag if check_token_status
370
+ """
371
+ model = model or self.model
372
+ tokenizer = self.tokenizer
373
 
374
+ sampling_params = SamplingParams(
375
+ temperature=temperature,
376
+ max_tokens=max_new_tokens,
377
+ seed=self.seed,
378
+ )
379
 
380
+ prompt = self.chat_template.render(
381
+ messages=messages, tools=tools, add_generation_prompt=True)
382
+ if output_begin_string is not None:
383
+ prompt += output_begin_string
384
 
385
+ token_overflow = False
386
+ if check_token_status and max_token is not None:
387
+ num_input_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
388
+ logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
389
+ if num_input_tokens > max_token:
390
+ torch.cuda.empty_cache()
391
+ gc.collect()
392
+ logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
393
+ return (None, True) if check_token_status else None
394
 
395
+ try:
396
+ output = model.generate(prompt, sampling_params=sampling_params)
397
+ output_text = output[0].outputs[0].text
398
+ output_tokens = len(tokenizer.encode(output_text, add_special_tokens=False))
399
+ logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
400
+
401
+ if skip_special_tokens:
402
+ output_text = output_text.replace("</s>", "").strip()
403
+
404
+ torch.cuda.empty_cache()
405
+ gc.collect()
406
+
407
+ return (output_text, token_overflow) if check_token_status else output_text
408
+ except Exception as e:
409
+ logger.error("Inference failed: %s", str(e))
410
+ raise RuntimeError(f"Inference failed: {str(e)}")
411
+
412
+ def run_multistep_agent(self,
413
+ message: str,
414
+ temperature: float,
415
+ max_new_tokens: int,
416
+ max_token: int,
417
+ max_round: int = 5,
418
+ call_agent: bool = False,
419
+ call_agent_level: int = 0) -> Optional[str]:
420
+ """
421
+ Run multi-step reasoning with the agent.
422
+
423
+ Args:
424
+ message: Input message
425
+ temperature: Sampling temperature
426
+ max_new_tokens: Max new tokens per step
427
+ max_token: Max total tokens
428
+ max_round: Maximum reasoning rounds
429
+ call_agent: Whether to enable CallAgent
430
+ call_agent_level: Current CallAgent level
431
+
432
+ Returns:
433
+ Final answer or None if failed
434
+ """
435
  logger.info("Starting multistep agent for message: %s", message[:100])
436
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
437
  call_agent, call_agent_level, message)
 
501
  conversation, temperature, max_new_tokens, max_token)
502
  return None
503
 
504
+ def analyze_document(self,
505
+ file_path: str,
506
+ temperature: float = 0.1,
507
+ max_new_tokens: int = 2048,
508
+ max_token: int = 131072) -> Dict[str, Union[str, List]]:
509
+ """
510
+ Analyze a document and return structured results.
511
+
512
+ Args:
513
+ file_path: Path to the document
514
+ temperature: Sampling temperature
515
+ max_new_tokens: Max new tokens per step
516
+ max_token: Max total tokens
517
+
518
+ Returns:
519
+ Dictionary with analysis results
520
+ """
521
+ logger.info("Starting document analysis for: %s", file_path)
522
+ start_time = time.time()
523
+
524
+ try:
525
+ extracted_text = self.extract_text(file_path)
526
+ if not extracted_text:
527
+ raise ValueError("Could not extract text from document")
 
 
 
 
 
 
 
 
 
528
 
529
+ chunks = self.split_text(extracted_text)
530
+ batches = self.batch_chunks(chunks, batch_size=1)
531
+ batch_results = []
532
+
533
+ for batch in batches:
534
+ prompt = "\n\n".join(self.build_prompt(chunk) for chunk in batch)
535
+ response = self.run_multistep_agent(
536
+ prompt,
537
+ temperature=temperature,
538
+ max_new_tokens=max_new_tokens,
539
+ max_token=max_token,
540
+ call_agent=False
541
+ )
542
+ batch_results.append(self.clean_response(response or "No response"))
543
 
544
+ combined = "\n\n".join([res for res in batch_results if not res.startswith("❌")])
545
+ if not combined:
546
+ raise ValueError("No valid batch responses generated")
 
 
 
 
 
 
 
 
 
 
547
 
548
+ final_summary = self.generate_final_summary(combined)
549
+
550
+ return {
551
+ "status": "success",
552
+ "summary": final_summary,
553
+ "batch_results": batch_results,
554
+ "processing_time": time.time() - start_time
555
+ }
556
+
557
+ except Exception as e:
558
+ logger.error("Document analysis failed: %s", str(e))
559
+ return {
560
+ "status": "error",
561
+ "message": str(e),
562
+ "processing_time": time.time() - start_time
563
+ }
564
+
565
+ def get_answer_based_on_unfinished_reasoning(self,
566
+ conversation: List[Dict],
567
+ temperature: float,
568
+ max_new_tokens: int,
569
+ max_token: int) -> str:
570
+ """
571
+ Generate a final answer when reasoning is incomplete.
572
+
573
+ Args:
574
+ conversation: Current conversation history
575
+ temperature: Sampling temperature
576
+ max_new_tokens: Max new tokens
577
+ max_token: Max total tokens
578
+
579
+ Returns:
580
+ Final answer string
581
+ """
582
+ if conversation[-1]['role'] == 'assistant':
583
+ conversation.append(
584
+ {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
585
+ finish_tools_prompt = self.add_finish_tools([])
586
+ last_outputs_str = self.llm_infer(
587
  messages=conversation,
588
  temperature=temperature,
589
+ tools=finish_tools_prompt,
590
+ output_begin_string='[FinalAnswer]',
591
+ skip_special_tokens=True,
592
  max_new_tokens=max_new_tokens,
593
  max_token=max_token)
594
+ logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
595
+ return last_outputs_str
596
 
597
+ def update_parameters(self, **kwargs) -> Dict:
598
+ """
599
+ Update agent parameters dynamically.
600
+
601
+ Args:
602
+ kwargs: Parameter names and values to update
603
+
604
+ Returns:
605
+ Dictionary of updated parameters
606
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  updated_attributes = {}
608
  for key, value in kwargs.items():
609
  if hasattr(self, key):
 
612
  logger.info("Updated parameters: %s", updated_attributes)
613
  return updated_attributes
614
 
615
+ def cleanup(self) -> None:
616
+ """Clean up resources and clear memory."""
617
+ if hasattr(self, 'model'):
618
+ del self.model
619
+ if hasattr(self, 'rag_model'):
620
+ del self.rag_model
621
+ if hasattr(self, 'tooluniverse'):
622
+ del self.tooluniverse
623
+ torch.cuda.empty_cache()
624
+ gc.collect()
625
+ logger.info("TxAgent resources cleaned up")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
626
 
627
+ def __del__(self):
628
+ """Destructor to ensure proper cleanup."""
629
+ self.cleanup()
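
For orientation, a minimal usage sketch of the refactored class follows. The import path, model names, tool-file path, and the question string are placeholders and not taken from this commit; only the constructor parameters and the method names (init_model, run_multistep_agent, cleanup) appear in the diff above.

# Minimal usage sketch (assumed import path; placeholder model names and paths).
from txagent.txagent import TxAgent  # import path assumed from the file location

agent = TxAgent(
    model_name="your-main-model",                         # placeholder
    rag_model_name="your-rag-model",                      # placeholder
    tool_files_dict={"new_tool": "tools/new_tool.json"},  # placeholder; the "new_tool" key is read by load_tool_desc_embedding
    enable_rag=False,
)
agent.init_model()  # loads the vLLM model and the ToolUniverse

answer = agent.run_multistep_agent(
    "Example question about a drug interaction.",         # placeholder input
    temperature=0.1,
    max_new_tokens=512,
    max_token=131072,
)
print(answer)

agent.cleanup()  # release GPU memory explicitly rather than relying on __del__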