Ali2206 committed on
Commit 80bc091 · verified · 1 Parent(s): 9aaa3ed

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +321 -611
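For orientation, a minimal usage sketch of the TxAgent class as revised in the diff below. This sketch is not part of the commit: the import path is inferred from the file location (src/txagent/txagent.py), and the model identifiers, question text, and token limits are placeholders.

# Minimal sketch (not from this commit); model ids and limits are placeholders.
from txagent.txagent import TxAgent  # import path inferred from src/txagent/txagent.py

agent = TxAgent(
    model_name="your-base-llm",           # placeholder vLLM-compatible model id
    rag_model_name="your-toolrag-model",  # placeholder ToolRAG embedding model id
    enable_rag=True,
    enable_checker=False,                 # default in this commit (disabled for speed)
    init_rag_num=2,                       # new default: fewer tools retrieved up front
    step_rag_num=4,                       # new default: fewer tools per Tool_RAG call
)
agent.init_model()                        # calls load_models(); loads the vLLM model
agent.load_tooluniverse()                 # tool setup, as defined in the class (may already be covered by init_model)
agent.load_tool_desc_embedding()

answer = agent.run_multistep_agent(
    "Example question for the agent",
    temperature=0.1,
    max_new_tokens=512,
    max_token=2048,
)
print(answer)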
src/txagent/txagent.py CHANGED
@@ -12,33 +12,31 @@ from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
15
- # near the top of txagent.py
16
  import logging
 
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
22
-
23
  class TxAgent:
24
  def __init__(self, model_name,
25
  rag_model_name,
26
- tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
  enable_finish=True,
28
  enable_rag=True,
29
  enable_summary=False,
30
- init_rag_num=0,
31
- step_rag_num=10,
32
  summary_mode='step',
33
  summary_skip_last_k=0,
34
  summary_context_length=None,
35
  force_finish=True,
36
  avoid_repeat=True,
37
  seed=None,
38
- enable_checker=False,
39
  enable_chat=False,
40
- additional_default_tools=None,
41
- ):
42
  self.model_name = model_name
43
  self.tokenizer = None
44
  self.terminators = None
@@ -47,10 +45,9 @@ class TxAgent:
47
  self.model = None
48
  self.rag_model = ToolRAGModel(rag_model_name)
49
  self.tooluniverse = None
50
- # self.tool_desc = None
51
- self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
- self.self_prompt = "Strictly follow the instruction."
53
- self.chat_prompt = "You are helpful assistant to chat with the user."
54
  self.enable_finish = enable_finish
55
  self.enable_rag = enable_rag
56
  self.enable_summary = enable_summary
@@ -64,7 +61,7 @@ class TxAgent:
64
  self.seed = seed
65
  self.enable_checker = enable_checker
66
  self.additional_default_tools = additional_default_tools
67
- self.print_self_values()
68
 
69
  def init_model(self):
70
  self.load_models()
@@ -73,19 +70,19 @@ class TxAgent:
73
 
74
  def print_self_values(self):
75
  for attr, value in self.__dict__.items():
76
- print(f"{attr}: {value}")
77
 
78
  def load_models(self, model_name=None):
79
- if model_name is not None:
80
- if model_name == self.model_name:
81
- return f"The model {model_name} is already loaded."
82
  self.model_name = model_name
83
 
84
- self.model = LLM(model=self.model_name)
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
  self.tokenizer = self.model.get_tokenizer()
87
-
88
- return f"Model {model_name} loaded successfully."
89
 
90
  def load_tooluniverse(self):
91
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
@@ -93,9 +90,11 @@ class TxAgent:
93
  special_tools = self.tooluniverse.prepare_tool_prompts(
94
  self.tooluniverse.tool_category_dicts["special_tools"])
95
  self.special_tools_name = [tool['name'] for tool in special_tools]
 
96
 
97
  def load_tool_desc_embedding(self):
98
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
 
99
 
100
  def rag_infer(self, query, top_k=5):
101
  return self.rag_model.rag_infer(query, top_k)
@@ -109,7 +108,7 @@ class TxAgent:
109
  if call_agent_level >= 2:
110
  call_agent = False
111
 
112
- if not call_agent:
113
  picked_tools_prompt += self.tool_RAG(
114
  message=message, rag_num=self.init_rag_num)
115
  return picked_tools_prompt, call_agent_level
@@ -118,293 +117,198 @@ class TxAgent:
118
  if conversation is None:
119
  conversation = []
120
 
121
- conversation = self.set_system_prompt(
122
- conversation, self.prompt_multi_step)
123
- if history is not None:
124
- if len(history) == 0:
125
- conversation = []
126
- print("clear conversation successfully")
127
- else:
128
- for i in range(len(history)):
129
- if history[i]['role'] == 'user':
130
- if i-1 >= 0 and history[i-1]['role'] == 'assistant':
131
- conversation.append(
132
- {"role": "assistant", "content": history[i-1]['content']})
133
- conversation.append(
134
- {"role": "user", "content": history[i]['content']})
135
- if i == len(history)-1 and history[i]['role'] == 'assistant':
136
- conversation.append(
137
- {"role": "assistant", "content": history[i]['content']})
138
-
139
  conversation.append({"role": "user", "content": message})
140
-
141
  return conversation
142
 
143
- def tool_RAG(self, message=None,
144
- picked_tool_names=None,
145
- existing_tools_prompt=[],
146
- rag_num=5,
147
- return_call_result=False):
148
- extra_factor = 30 # Factor to retrieve more than rag_num
149
  if picked_tool_names is None:
150
- assert picked_tool_names is not None or message is not None
151
- picked_tool_names = self.rag_infer(
152
- message, top_k=rag_num*extra_factor)
153
-
154
- picked_tool_names_no_special = []
155
- for tool in picked_tool_names:
156
- if tool not in self.special_tools_name:
157
- picked_tool_names_no_special.append(tool)
158
- picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
159
- picked_tool_names = picked_tool_names_no_special[:rag_num]
160
 
 
 
 
 
161
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
162
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
163
- picked_tools)
164
  if return_call_result:
165
  return picked_tools_prompt, picked_tool_names
166
  return picked_tools_prompt
167
 
168
  def add_special_tools(self, tools, call_agent=False):
169
  if self.enable_finish:
170
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
- 'Finish', return_prompt=True))
172
- print("Finish tool is added")
173
  if call_agent:
174
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
- 'CallAgent', return_prompt=True))
176
- print("CallAgent tool is added")
177
- else:
178
- if self.enable_rag:
179
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
- 'Tool_RAG', return_prompt=True))
181
- print("Tool_RAG tool is added")
182
-
183
- if self.additional_default_tools is not None:
184
- for each_tool_name in self.additional_default_tools:
185
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
- each_tool_name, return_prompt=True)
187
- if tool_prompt is not None:
188
- print(f"{each_tool_name} tool is added")
189
- tools.append(tool_prompt)
190
  return tools
191
 
192
  def add_finish_tools(self, tools):
193
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
- 'Finish', return_prompt=True))
195
- print("Finish tool is added")
196
  return tools
197
 
198
  def set_system_prompt(self, conversation, sys_prompt):
199
- if len(conversation) == 0:
200
- conversation.append(
201
- {"role": "system", "content": sys_prompt})
202
  else:
203
  conversation[0] = {"role": "system", "content": sys_prompt}
204
  return conversation
205
 
206
- def run_function_call(self, fcall_str,
207
- return_message=False,
208
- existing_tools_prompt=None,
209
- message_for_call_agent=None,
210
- call_agent=False,
211
- call_agent_level=None,
212
- temperature=None):
213
-
214
  function_call_json, message = self.tooluniverse.extract_function_call_json(
215
  fcall_str, return_message=return_message, verbose=False)
216
  call_results = []
217
  special_tool_call = ''
218
- if function_call_json is not None:
219
- if isinstance(function_call_json, list):
220
- for i in range(len(function_call_json)):
221
- print("\033[94mTool Call:\033[0m", function_call_json[i])
222
- if function_call_json[i]["name"] == 'Finish':
223
- special_tool_call = 'Finish'
224
- break
225
- elif function_call_json[i]["name"] == 'Tool_RAG':
226
- new_tools_prompt, call_result = self.tool_RAG(
227
- message=message,
228
- existing_tools_prompt=existing_tools_prompt,
229
- rag_num=self.step_rag_num,
230
- return_call_result=True)
231
- existing_tools_prompt += new_tools_prompt
232
- elif function_call_json[i]["name"] == 'CallAgent':
233
- if call_agent_level < 2 and call_agent:
234
- solution_plan = function_call_json[i]['arguments']['solution']
235
- full_message = (
236
- message_for_call_agent +
237
- "\nYou must follow the following plan to answer the question: " +
238
- str(solution_plan)
239
- )
240
- call_result = self.run_multistep_agent(
241
- full_message, temperature=temperature,
242
- max_new_tokens=1024, max_token=99999,
243
- call_agent=False, call_agent_level=call_agent_level)
244
- if call_result is None:
245
- call_result = "⚠️ No content returned from sub-agent."
246
- else:
247
- call_result = call_result.split('[FinalAnswer]')[-1].strip()
248
- else:
249
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
250
- else:
251
- call_result = self.tooluniverse.run_one_function(
252
- function_call_json[i])
253
-
254
- call_id = self.tooluniverse.call_id_gen()
255
- function_call_json[i]["call_id"] = call_id
256
- print("\033[94mTool Call Result:\033[0m", call_result)
257
- call_results.append({
258
- "role": "tool",
259
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
260
- })
261
  else:
262
  call_results.append({
263
  "role": "tool",
264
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
265
  })
266
 
267
  revised_messages = [{
268
  "role": "assistant",
269
- "content": message.strip(),
270
  "tool_calls": json.dumps(function_call_json)
271
  }] + call_results
272
-
273
- # Yield the final result.
274
  return revised_messages, existing_tools_prompt, special_tool_call
275
 
276
- def run_function_call_stream(self, fcall_str,
277
- return_message=False,
278
- existing_tools_prompt=None,
279
- message_for_call_agent=None,
280
- call_agent=False,
281
- call_agent_level=None,
282
- temperature=None,
283
- return_gradio_history=True):
284
-
285
  function_call_json, message = self.tooluniverse.extract_function_call_json(
286
  fcall_str, return_message=return_message, verbose=False)
287
  call_results = []
288
  special_tool_call = ''
289
- if return_gradio_history:
290
- gradio_history = []
291
- if function_call_json is not None:
292
- if isinstance(function_call_json, list):
293
- for i in range(len(function_call_json)):
294
- if function_call_json[i]["name"] == 'Finish':
295
- special_tool_call = 'Finish'
296
- break
297
- elif function_call_json[i]["name"] == 'Tool_RAG':
298
- new_tools_prompt, call_result = self.tool_RAG(
299
- message=message,
300
- existing_tools_prompt=existing_tools_prompt,
301
- rag_num=self.step_rag_num,
302
- return_call_result=True)
303
- existing_tools_prompt += new_tools_prompt
304
- elif function_call_json[i]["name"] == 'DirectResponse':
305
- call_result = function_call_json[i]['arguments']['respose']
306
- special_tool_call = 'DirectResponse'
307
- elif function_call_json[i]["name"] == 'RequireClarification':
308
- call_result = function_call_json[i]['arguments']['unclear_question']
309
- special_tool_call = 'RequireClarification'
310
- elif function_call_json[i]["name"] == 'CallAgent':
311
- if call_agent_level < 2 and call_agent:
312
- solution_plan = function_call_json[i]['arguments']['solution']
313
- full_message = (
314
- message_for_call_agent +
315
- "\nYou must follow the following plan to answer the question: " +
316
- str(solution_plan)
317
- )
318
- sub_agent_task = "Sub TxAgent plan: " + \
319
- str(solution_plan)
320
- call_result = yield from self.run_gradio_chat(
321
- full_message, history=[], temperature=temperature,
322
- max_new_tokens=1024, max_token=99999,
323
- call_agent=False, call_agent_level=call_agent_level,
324
- conversation=None,
325
- sub_agent_task=sub_agent_task)
326
-
327
- if call_result is not None and isinstance(call_result, str):
328
- call_result = call_result.split('[FinalAnswer]')[-1]
329
- else:
330
- call_result = "⚠️ No content returned from sub-agent."
331
- else:
332
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
333
- else:
334
- call_result = self.tooluniverse.run_one_function(
335
- function_call_json[i])
336
-
337
- call_id = self.tooluniverse.call_id_gen()
338
- function_call_json[i]["call_id"] = call_id
339
- call_results.append({
340
- "role": "tool",
341
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
342
- })
343
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
344
- if function_call_json[i]["name"] == 'Tool_RAG':
345
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
346
- "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
347
- else:
348
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
349
- "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
350
  else:
351
  call_results.append({
352
  "role": "tool",
353
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
354
  })
355
 
356
  revised_messages = [{
357
  "role": "assistant",
358
- "content": message.strip(),
359
  "tool_calls": json.dumps(function_call_json)
360
  }] + call_results
 
361
 
362
- if return_gradio_history:
363
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
364
- else:
365
- return revised_messages, existing_tools_prompt, special_tool_call
366
-
367
-
368
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
369
- if conversation[-1]['role'] == 'assisant':
370
  conversation.append(
371
- {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
372
  finish_tools_prompt = self.add_finish_tools([])
 
 
373
 
374
- last_outputs_str = self.llm_infer(messages=conversation,
375
- temperature=temperature,
376
- tools=finish_tools_prompt,
377
- output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
378
- skip_special_tokens=True,
379
- max_new_tokens=max_new_tokens, max_token=max_token)
380
- print(last_outputs_str)
381
- return last_outputs_str
382
-
383
- def run_multistep_agent(self, message: str,
384
- temperature: float,
385
- max_new_tokens: int,
386
- max_token: int,
387
- max_round: int = 20,
388
- call_agent=False,
389
- call_agent_level=0) -> str:
390
- """
391
- Generate a streaming response using the llama3-8b model.
392
- Args:
393
- message (str): The input message.
394
- temperature (float): The temperature for generating the response.
395
- max_new_tokens (int): The maximum number of new tokens to generate.
396
- Returns:
397
- str: The generated response.
398
- """
399
- print("\033[1;32;40mstart\033[0m")
400
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
401
  call_agent, call_agent_level, message)
402
  conversation = self.initialize_conversation(message)
403
-
404
  outputs = []
405
  last_outputs = []
406
  next_round = True
407
- function_call_messages = []
408
  current_round = 0
409
  token_overflow = False
410
  enable_summary = False
@@ -412,103 +316,70 @@ class TxAgent:
412
 
413
  if self.enable_checker:
414
  checker = ReasoningTraceChecker(message, conversation)
415
- try:
416
- while next_round and current_round < max_round:
417
- current_round += 1
418
- if len(outputs) > 0:
419
- function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
420
- last_outputs, return_message=True,
421
- existing_tools_prompt=picked_tools_prompt,
422
- message_for_call_agent=message,
423
- call_agent=call_agent,
424
- call_agent_level=call_agent_level,
425
- temperature=temperature)
426
-
427
- if special_tool_call == 'Finish':
428
- next_round = False
429
- conversation.extend(function_call_messages)
430
- if isinstance(function_call_messages[0]['content'], types.GeneratorType):
431
- function_call_messages[0]['content'] = next(
432
- function_call_messages[0]['content'])
433
- content = function_call_messages[0]['content']
434
- if content is None:
435
- return "❌ No content returned after Finish tool call."
436
- return content.split('[FinalAnswer]')[-1]
437
 
438
- if (self.enable_summary or token_overflow) and not call_agent:
439
- if token_overflow:
440
- print("token_overflow, using summary")
441
- enable_summary = True
442
- last_status = self.function_result_summary(
443
- conversation, status=last_status, enable_summary=enable_summary)
444
-
445
- if function_call_messages is not None:
446
- conversation.extend(function_call_messages)
447
- outputs.append(tool_result_format(
448
- function_call_messages))
449
- else:
450
- next_round = False
451
- conversation.extend(
452
- [{"role": "assistant", "content": ''.join(last_outputs)}])
453
- return ''.join(last_outputs).replace("</s>", "")
454
- if self.enable_checker:
455
- good_status, wrong_info = checker.check_conversation()
456
- if not good_status:
457
- next_round = False
458
- print(
459
- "Internal error in reasoning: " + wrong_info)
460
- break
461
- last_outputs = []
462
- outputs.append("### TxAgent:\n")
463
- last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
464
- temperature=temperature,
465
- tools=picked_tools_prompt,
466
- skip_special_tokens=False,
467
- max_new_tokens=max_new_tokens, max_token=max_token,
468
- check_token_status=True)
469
- if last_outputs_str is None:
470
- print("The number of tokens exceeds the maximum limit.")
471
- if self.force_finish:
472
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
473
- else:
474
- return "❌ Token limit exceeded — no further steps possible."
475
  else:
476
- last_outputs.append(last_outputs_str)
477
- if max_round == current_round:
478
- print("The number of rounds exceeds the maximum limit!")
479
- if self.force_finish:
480
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
- else:
482
- return None
 
 
483
 
484
- except Exception as e:
485
- print(f"Error: {e}")
486
- if self.force_finish:
487
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
488
- else:
489
- return None
490
 
491
  def build_logits_processor(self, messages, llm):
492
- # Use the tokenizer from the LLM instance.
493
  tokenizer = llm.get_tokenizer()
494
  if self.avoid_repeat and len(messages) > 2:
495
- assistant_messages = []
496
- for i in range(1, len(messages) + 1):
497
- if messages[-i]['role'] == 'assistant':
498
- assistant_messages.append(messages[-i]['content'])
499
- if len(assistant_messages) == 2:
500
- break
501
- forbidden_ids = [tokenizer.encode(
502
- msg, add_special_tokens=False) for msg in assistant_messages]
503
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
504
- else:
505
- return None
506
 
507
- def llm_infer(self, messages, temperature=0.1, tools=None,
508
- output_begin_string=None, max_new_tokens=2048,
509
- max_token=None, skip_special_tokens=True,
510
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
511
-
512
  if model is None:
513
  model = self.model
514
 
@@ -516,333 +387,207 @@ class TxAgent:
516
  sampling_params = SamplingParams(
517
  temperature=temperature,
518
  max_tokens=max_new_tokens,
519
-
520
  seed=seed if seed is not None else self.seed,
 
521
  )
522
 
523
- prompt = self.chat_template.render(
524
- messages=messages, tools=tools, add_generation_prompt=True)
525
- if output_begin_string is not None:
526
  prompt += output_begin_string
527
 
528
- if check_token_status and max_token is not None:
529
- token_overflow = False
530
- num_input_tokens = len(self.tokenizer.encode(
531
- prompt, return_tensors="pt")[0])
532
- if max_token is not None:
533
- if num_input_tokens > max_token:
534
- torch.cuda.empty_cache()
535
- gc.collect()
536
- print("Number of input tokens before inference:",
537
- num_input_tokens)
538
- logger.info(
539
- "The number of tokens exceeds the maximum limit!!!!")
540
- token_overflow = True
541
- return None, token_overflow
542
- output = model.generate(
543
- prompt,
544
- sampling_params=sampling_params,
545
- )
546
- output = output[0].outputs[0].text
547
- print("\033[92m" + output + "\033[0m")
548
- if check_token_status and max_token is not None:
549
- return output, token_overflow
550
 
 
551
  return output
552
 
553
- def run_self_agent(self, message: str,
554
- temperature: float,
555
- max_new_tokens: int,
556
- max_token: int) -> str:
557
-
558
- print("\033[1;32;40mstart self agent\033[0m")
559
- conversation = []
560
- conversation = self.set_system_prompt(conversation, self.self_prompt)
561
  conversation.append({"role": "user", "content": message})
562
- return self.llm_infer(messages=conversation,
563
- temperature=temperature,
564
- tools=None,
565
  max_new_tokens=max_new_tokens, max_token=max_token)
566
 
567
- def run_chat_agent(self, message: str,
568
- temperature: float,
569
- max_new_tokens: int,
570
- max_token: int) -> str:
571
-
572
- print("\033[1;32;40mstart chat agent\033[0m")
573
- conversation = []
574
- conversation = self.set_system_prompt(conversation, self.chat_prompt)
575
  conversation.append({"role": "user", "content": message})
576
- return self.llm_infer(messages=conversation,
577
- temperature=temperature,
578
- tools=None,
579
  max_new_tokens=max_new_tokens, max_token=max_token)
580
 
581
- def run_format_agent(self, message: str,
582
- answer: str,
583
- temperature: float,
584
- max_new_tokens: int,
585
- max_token: int) -> str:
586
-
587
- print("\033[1;32;40mstart format agent\033[0m")
588
  if '[FinalAnswer]' in answer:
589
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
590
  elif "\n\n" in answer:
591
  possible_final_answer = answer.split("\n\n")[-1]
592
  else:
593
  possible_final_answer = answer.strip()
594
- if len(possible_final_answer) == 1:
595
- choice = possible_final_answer[0]
596
- if choice in ['A', 'B', 'C', 'D', 'E']:
597
- return choice
598
- elif len(possible_final_answer) > 1:
599
- if possible_final_answer[1] == ':':
600
- choice = possible_final_answer[0]
601
- if choice in ['A', 'B', 'C', 'D', 'E']:
602
- print("choice", choice)
603
- return choice
604
-
605
- conversation = []
606
- format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
607
- conversation = self.set_system_prompt(conversation, format_prompt)
608
- conversation.append({"role": "user", "content": message +
609
- "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
610
- return self.llm_infer(messages=conversation,
611
- temperature=temperature,
612
- tools=None,
613
  max_new_tokens=max_new_tokens, max_token=max_token)
614
 
615
- def run_summary_agent(self, thought_calls: str,
616
- function_response: str,
617
- temperature: float,
618
- max_new_tokens: int,
619
- max_token: int) -> str:
620
- print("\033[1;32;40mSummarized Tool Result:\033[0m")
621
- generate_tool_result_summary_training_prompt = """Thought and function calls:
622
- {thought_calls}
623
- Function calls' responses:
624
- \"\"\"
625
- {function_response}
626
- \"\"\"
627
- Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
628
- Directly respond with the summarized sentence of the function calls' responses only.
629
- Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
630
- """.format(thought_calls=thought_calls, function_response=function_response)
631
- conversation = []
632
- conversation.append(
633
- {"role": "user", "content": generate_tool_result_summary_training_prompt})
634
- output = self.llm_infer(messages=conversation,
635
- temperature=temperature,
636
- tools=None,
637
  max_new_tokens=max_new_tokens, max_token=max_token)
638
-
639
  if '[' in output:
640
  output = output.split('[')[0]
641
  return output
642
 
643
  def function_result_summary(self, input_list, status, enable_summary):
644
- """
645
- Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
646
- Supports 'length' and 'step' modes, and skips the last 'k' groups.
647
- Parameters:
648
- input_list (list): A list of dictionaries containing role and other information.
649
- summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
650
- summary_context_length (int): The context length threshold for the 'length' mode.
651
- last_processed_index (tuple or int): The last processed index.
652
- Returns:
653
- list: A list of extracted information from valid sequences.
654
- """
655
  if 'tool_call_step' not in status:
656
  status['tool_call_step'] = 0
 
 
 
657
 
658
  for idx in range(len(input_list)):
659
- pos_id = len(input_list)-idx-1
660
- if input_list[pos_id]['role'] == 'assistant':
661
- if 'tool_calls' in input_list[pos_id]:
662
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
663
- status['tool_call_step'] += 1
664
  break
665
 
666
- if 'step' in status:
667
- status['step'] += 1
668
- else:
669
- status['step'] = 0
670
-
671
  if not enable_summary:
672
  return status
673
 
674
  if 'summarized_index' not in status:
675
  status['summarized_index'] = 0
676
-
677
  if 'summarized_step' not in status:
678
  status['summarized_step'] = 0
679
-
680
  if 'previous_length' not in status:
681
  status['previous_length'] = 0
682
-
683
  if 'history' not in status:
684
  status['history'] = []
685
 
686
- function_response = ''
687
- idx = 0
688
- current_summarized_index = status['summarized_index']
689
-
690
- status['history'].append(self.summary_mode == 'step' and status['summarized_step']
691
- < status['step']-status['tool_call_step']-self.summary_skip_last_k)
692
 
693
- idx = current_summarized_index
 
 
694
  while idx < len(input_list):
695
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
696
-
697
  if input_list[idx]['role'] == 'assistant':
698
  if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
699
  this_thought_calls = None
700
  else:
701
- if len(function_response) != 0:
702
- print("internal summary")
703
  status['summarized_step'] += 1
704
  result_summary = self.run_summary_agent(
705
- thought_calls=this_thought_calls,
706
- function_response=function_response,
707
- temperature=0.1,
708
- max_new_tokens=1024,
709
- max_token=99999
710
- )
711
-
712
  input_list.insert(
713
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
714
  status['summarized_index'] = last_call_idx + 2
715
  idx += 1
716
-
717
  last_call_idx = idx
718
- this_thought_calls = input_list[idx]['content'] + \
719
- input_list[idx]['tool_calls']
720
  function_response = ''
721
-
722
- elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
723
  function_response += input_list[idx]['content']
724
  del input_list[idx]
725
  idx -= 1
726
-
727
  else:
728
  break
729
  idx += 1
730
 
731
- if len(function_response) != 0:
732
  status['summarized_step'] += 1
733
  result_summary = self.run_summary_agent(
734
- thought_calls=this_thought_calls,
735
- function_response=function_response,
736
- temperature=0.1,
737
- max_new_tokens=1024,
738
- max_token=99999
739
- )
740
-
741
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
742
  for tool_call in tool_calls:
743
  del tool_call['call_id']
744
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
745
  input_list.insert(
746
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
747
  status['summarized_index'] = last_call_idx + 2
748
 
749
  return status
750
 
751
- # Following are Gradio related functions
752
-
753
- # General update method that accepts any new arguments through kwargs
754
  def update_parameters(self, **kwargs):
 
755
  for key, value in kwargs.items():
756
  if hasattr(self, key):
757
  setattr(self, key, value)
758
-
759
- # Return the updated attributes
760
- updated_attributes = {key: value for key,
761
- value in kwargs.items() if hasattr(self, key)}
762
  return updated_attributes
763
 
764
- def run_gradio_chat(self, message: str,
765
- history: list,
766
- temperature: float,
767
- max_new_tokens: int,
768
- max_token: int,
769
- call_agent: bool,
770
- conversation: gr.State,
771
- max_round: int = 20,
772
- seed: int = None,
773
- call_agent_level: int = 0,
774
- sub_agent_task: str = None,
775
- uploaded_files: list = None) -> str:
776
- """
777
- Generate a streaming response using the loaded model.
778
- Args:
779
- message (str): The input message (with file content if uploaded).
780
- history (list): The conversation history used by ChatInterface.
781
- temperature (float): Sampling temperature.
782
- max_new_tokens (int): Max new tokens.
783
- max_token (int): Max total tokens allowed.
784
- Returns:
785
- str: Final assistant message.
786
- """
787
- logger.debug(f"[TxAgent] Chat started, message: {message[:100]}...")
788
- print("\033[1;32;40m[TxAgent] Chat started\033[0m")
789
-
790
  if not message or len(message.strip()) < 5:
791
  yield "Please provide a valid message or upload files to analyze."
792
- return "Invalid input."
793
 
794
  if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
795
- return ""
796
-
797
- outputs = []
798
- outputs_str = ''
799
- last_outputs = []
800
 
801
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
802
- call_agent,
803
- call_agent_level,
804
- message)
805
-
806
  conversation = self.initialize_conversation(
807
- message,
808
- conversation=conversation,
809
- history=history)
810
  history = []
811
 
812
  next_round = True
813
- function_call_messages = []
814
  current_round = 0
815
  enable_summary = False
816
  last_status = {}
817
  token_overflow = False
818
 
819
  if self.enable_checker:
820
- checker = ReasoningTraceChecker(
821
- message, conversation, init_index=len(conversation))
822
 
823
  try:
824
  while next_round and current_round < max_round:
825
  current_round += 1
826
- logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
827
-
828
  if last_outputs:
829
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
830
- last_outputs, return_message=True,
831
- existing_tools_prompt=picked_tools_prompt,
832
- message_for_call_agent=message,
833
- call_agent=call_agent,
834
- call_agent_level=call_agent_level,
835
- temperature=temperature)
836
-
837
  history.extend(current_gradio_history)
838
 
839
- if special_tool_call == 'Finish' and function_call_messages:
840
  yield history
841
  next_round = False
842
  conversation.extend(function_call_messages)
843
  return function_call_messages[0]['content']
844
 
845
- elif special_tool_call in ['RequireClarification', 'DirectResponse']:
846
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
847
  history.append(ChatMessage(role="assistant", content=last_msg.content))
848
  yield history
@@ -851,64 +596,46 @@ Generate **one summarized sentence** about "function calls' responses" with nece
851
 
852
  if (self.enable_summary or token_overflow) and not call_agent:
853
  enable_summary = True
854
-
855
  last_status = self.function_result_summary(
856
- conversation, status=last_status,
857
- enable_summary=enable_summary)
858
 
859
  if function_call_messages:
860
  conversation.extend(function_call_messages)
861
  yield history
862
  else:
863
  next_round = False
864
- conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
865
  return ''.join(last_outputs).replace("</s>", "")
866
 
867
  if self.enable_checker:
868
  good_status, wrong_info = checker.check_conversation()
869
  if not good_status:
870
- print("Checker flagged reasoning error: ", wrong_info)
871
  break
872
 
873
- last_outputs = []
874
  last_outputs_str, token_overflow = self.llm_infer(
875
- messages=conversation,
876
- temperature=temperature,
877
- tools=picked_tools_prompt,
878
- skip_special_tokens=False,
879
- max_new_tokens=max_new_tokens,
880
- max_token=max_token,
881
- seed=seed,
882
- check_token_status=True)
883
-
884
- logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
885
 
886
  if last_outputs_str is None:
887
- logger.warning("llm_infer returned None due to token overflow")
888
  if self.force_finish:
889
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
890
  conversation, temperature, max_new_tokens, max_token)
891
  history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
892
  yield history
893
  return last_outputs_str
894
- else:
895
- error_msg = "Token limit exceeded. Please reduce input size or increase max_token."
896
- history.append(ChatMessage(role="assistant", content=error_msg))
897
- yield history
898
- return error_msg
899
 
900
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
901
-
902
  for msg in history:
903
- if msg.metadata is not None:
904
  msg.metadata['status'] = 'done'
905
 
906
  if '[FinalAnswer]' in last_thought:
907
  parts = last_thought.split('[FinalAnswer]', 1)
908
- if len(parts) == 2:
909
- final_thought, final_answer = parts
910
- else:
911
- final_thought, final_answer = last_thought, ""
912
  history.append(ChatMessage(role="assistant", content=final_thought.strip()))
913
  yield history
914
  history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
@@ -919,45 +646,28 @@ Generate **one summarized sentence** about "function calls' responses" with nece
919
 
920
  last_outputs.append(last_outputs_str)
921
 
922
- if next_round:
923
- if self.force_finish:
924
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
925
- conversation, temperature, max_new_tokens, max_token)
926
- if '[FinalAnswer]' in last_outputs_str:
927
- parts = last_outputs_str.split('[FinalAnswer]', 1)
928
- if len(parts) == 2:
929
- final_thought, final_answer = parts
930
- else:
931
- final_thought, final_answer = last_outputs_str, ""
932
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
933
- yield history
934
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
935
- yield history
936
- else:
937
- history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
938
- yield history
939
- else:
940
- yield "The number of reasoning rounds exceeded the limit."
941
 
942
  except Exception as e:
943
- logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
944
- error_msg = f"An error occurred: {e}"
945
  history.append(ChatMessage(role="assistant", content=error_msg))
946
  yield history
947
  if self.force_finish:
948
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
949
  conversation, temperature, max_new_tokens, max_token)
950
- if '[FinalAnswer]' in last_outputs_str:
951
- parts = last_outputs_str.split('[FinalAnswer]', 1)
952
- if len(parts) == 2:
953
- final_thought, final_answer = parts
954
- else:
955
- final_thought, final_answer = last_outputs_str, ""
956
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
957
- yield history
958
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
959
- yield history
960
- else:
961
- history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
962
- yield history
963
  return error_msg
 
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
 
15
  import logging
16
+
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
 
22
  class TxAgent:
23
  def __init__(self, model_name,
24
  rag_model_name,
25
+ tool_files_dict=None,
26
  enable_finish=True,
27
  enable_rag=True,
28
  enable_summary=False,
29
+ init_rag_num=2, # Reduced for faster initial tool selection
30
+ step_rag_num=4, # Reduced for fewer RAG calls
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
34
  force_finish=True,
35
  avoid_repeat=True,
36
  seed=None,
37
+ enable_checker=False, # Disabled by default for speed
38
  enable_chat=False,
39
+ additional_default_tools=None):
 
40
  self.model_name = model_name
41
  self.tokenizer = None
42
  self.terminators = None
 
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
+ self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
49
+ self.self_prompt = "Follow instructions precisely."
50
+ self.chat_prompt = "You are a helpful assistant for clinical queries."
 
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
 
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
+ logger.debug("TxAgent initialized with parameters: %s", self.__dict__)
65
 
66
  def init_model(self):
67
  self.load_models()
 
70
 
71
  def print_self_values(self):
72
  for attr, value in self.__dict__.items():
73
+ logger.debug("%s: %s", attr, value)
74
 
75
  def load_models(self, model_name=None):
76
+ if model_name is not None and model_name == self.model_name:
77
+ return f"The model {model_name} is already loaded."
78
+ if model_name:
79
  self.model_name = model_name
80
 
81
+ self.model = LLM(model=self.model_name, dtype="float16") # Enable FP16
82
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
83
  self.tokenizer = self.model.get_tokenizer()
84
+ logger.info("Model %s loaded successfully", self.model_name)
85
+ return f"Model {self.model_name} loaded successfully."
86
 
87
  def load_tooluniverse(self):
88
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
 
90
  special_tools = self.tooluniverse.prepare_tool_prompts(
91
  self.tooluniverse.tool_category_dicts["special_tools"])
92
  self.special_tools_name = [tool['name'] for tool in special_tools]
93
+ logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
94
 
95
  def load_tool_desc_embedding(self):
96
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
97
+ logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
100
  return self.rag_model.rag_infer(query, top_k)
 
108
  if call_agent_level >= 2:
109
  call_agent = False
110
 
111
+ if not call_agent and self.enable_rag:
112
  picked_tools_prompt += self.tool_RAG(
113
  message=message, rag_num=self.init_rag_num)
114
  return picked_tools_prompt, call_agent_level
 
117
  if conversation is None:
118
  conversation = []
119
 
120
+ conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
121
+ if history:
122
+ conversation.extend(
123
+ {"role": h['role'], "content": h['content']}
124
+ for h in history if h['role'] in ['user', 'assistant']
125
+ )
 
 
126
  conversation.append({"role": "user", "content": message})
127
+ logger.debug("Conversation initialized with %d messages", len(conversation))
128
  return conversation
129
 
130
+ def tool_RAG(self, message=None, picked_tool_names=None,
131
+ existing_tools_prompt=None, rag_num=4, return_call_result=False):
132
+ extra_factor = 10 # Reduced from 30 for efficiency
 
 
 
133
  if picked_tool_names is None:
134
+ picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
 
 
135
 
136
+ picked_tool_names = [
137
+ tool for tool in picked_tool_names
138
+ if tool not in self.special_tools_name
139
+ ][:rag_num]
140
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
141
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
142
+ logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
143
  if return_call_result:
144
  return picked_tools_prompt, picked_tool_names
145
  return picked_tools_prompt
146
 
147
  def add_special_tools(self, tools, call_agent=False):
148
  if self.enable_finish:
149
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
150
+ logger.debug("Finish tool added")
 
151
  if call_agent:
152
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
153
+ logger.debug("CallAgent tool added")
154
+ elif self.enable_rag:
155
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
156
+ logger.debug("Tool_RAG tool added")
157
+ if self.additional_default_tools:
158
+ for tool_name in self.additional_default_tools:
159
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
160
+ if tool_prompt:
161
+ tools.append(tool_prompt)
162
+ logger.debug("%s tool added", tool_name)
 
 
 
 
 
163
  return tools
164
 
165
  def add_finish_tools(self, tools):
166
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
167
+ logger.debug("Finish tool added")
 
168
  return tools
169
 
170
  def set_system_prompt(self, conversation, sys_prompt):
171
+ if not conversation:
172
+ conversation.append({"role": "system", "content": sys_prompt})
 
173
  else:
174
  conversation[0] = {"role": "system", "content": sys_prompt}
175
  return conversation
176
 
177
+ def run_function_call(self, fcall_str, return_message=False,
178
+ existing_tools_prompt=None, message_for_call_agent=None,
179
+ call_agent=False, call_agent_level=None, temperature=None):
 
 
 
 
 
180
  function_call_json, message = self.tooluniverse.extract_function_call_json(
181
  fcall_str, return_message=return_message, verbose=False)
182
  call_results = []
183
  special_tool_call = ''
184
+ if function_call_json:
185
+ for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
186
+ logger.debug("Tool Call: %s", func)
187
+ if func["name"] == 'Finish':
188
+ special_tool_call = 'Finish'
189
+ break
190
+ elif func["name"] == 'Tool_RAG':
191
+ new_tools_prompt, call_result = self.tool_RAG(
192
+ message=message, existing_tools_prompt=existing_tools_prompt,
193
+ rag_num=self.step_rag_num, return_call_result=True)
194
+ existing_tools_prompt += new_tools_prompt
195
+ elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
196
+ solution_plan = func['arguments']['solution']
197
+ full_message = (
198
+ message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
199
+ )
200
+ call_result = self.run_multistep_agent(
201
+ full_message, temperature=temperature, max_new_tokens=512,
202
+ max_token=2048, call_agent=False, call_agent_level=call_agent_level)
203
+ call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
204
+ else:
205
+ call_result = self.tooluniverse.run_one_function(func)
206
+
207
+ call_id = self.tooluniverse.call_id_gen()
208
+ func["call_id"] = call_id
209
+ logger.debug("Tool Call Result: %s", call_result)
210
+ call_results.append({
211
+ "role": "tool",
212
+ "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
213
+ })
 
 
214
  else:
215
  call_results.append({
216
  "role": "tool",
217
+ "content": json.dumps({"content": "Invalid function call format."})
218
  })
219
 
220
  revised_messages = [{
221
  "role": "assistant",
222
+ "content": message.strip() if message else "",
223
  "tool_calls": json.dumps(function_call_json)
224
  }] + call_results
 
 
225
  return revised_messages, existing_tools_prompt, special_tool_call
226
 
227
+ def run_function_call_stream(self, fcall_str, return_message=False,
228
+ existing_tools_prompt=None, message_for_call_agent=None,
229
+ call_agent=False, call_agent_level=None, temperature=None,
230
+ return_gradio_history=True):
 
 
 
 
 
231
  function_call_json, message = self.tooluniverse.extract_function_call_json(
232
  fcall_str, return_message=return_message, verbose=False)
233
  call_results = []
234
  special_tool_call = ''
235
+ gradio_history = [] if return_gradio_history else None
236
+ if function_call_json:
237
+ for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
238
+ if func["name"] == 'Finish':
239
+ special_tool_call = 'Finish'
240
+ break
241
+ elif func["name"] == 'Tool_RAG':
242
+ new_tools_prompt, call_result = self.tool_RAG(
243
+ message=message, existing_tools_prompt=existing_tools_prompt,
244
+ rag_num=self.step_rag_num, return_call_result=True)
245
+ existing_tools_prompt += new_tools_prompt
246
+ elif func["name"] == 'DirectResponse':
247
+ call_result = func['arguments']['response']
248
+ special_tool_call = 'DirectResponse'
249
+ elif func["name"] == 'RequireClarification':
250
+ call_result = func['arguments']['unclear_question']
251
+ special_tool_call = 'RequireClarification'
252
+ elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
253
+ solution_plan = func['arguments']['solution']
254
+ full_message = (
255
+ message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
256
+ )
257
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
258
+ call_result = yield from self.run_gradio_chat(
259
+ full_message, history=[], temperature=temperature,
260
+ max_new_tokens=512, max_token=2048, call_agent=False,
261
+ call_agent_level=call_agent_level, conversation=None,
262
+ sub_agent_task=sub_agent_task)
263
+ call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
264
+ else:
265
+ call_result = self.tooluniverse.run_one_function(func)
266
+
267
+ call_id = self.tooluniverse.call_id_gen()
268
+ func["call_id"] = call_id
269
+ call_results.append({
270
+ "role": "tool",
271
+ "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
272
+ })
273
+ if return_gradio_history and func["name"] != 'Finish':
274
+ title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
275
+ gradio_history.append(ChatMessage(
276
+ role="assistant", content=str(call_result),
277
+ metadata={"title": title, "log": str(func['arguments'])}
278
+ ))
 
 
279
  else:
280
  call_results.append({
281
  "role": "tool",
282
+ "content": json.dumps({"content": "Invalid function call format."})
283
  })
284
 
285
  revised_messages = [{
286
  "role": "assistant",
287
+ "content": message.strip() if message else "",
288
  "tool_calls": json.dumps(function_call_json)
289
  }] + call_results
290
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
291
 
292
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
293
+ if conversation[-1]['role'] == 'assistant':
 
 
 
 
 
 
294
  conversation.append(
295
+ {'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
296
  finish_tools_prompt = self.add_finish_tools([])
297
+ output = self.llm_infer(
298
+ messages=conversation, temperature=temperature, tools=finish_tools_prompt,
299
+ output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
300
+ logger.debug("Unfinished reasoning output: %s", output)
301
+ return output
302
 
303
+ def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
304
+ max_token: int, max_round: int = 10, call_agent=False, call_agent_level=0):
305
+ logger.debug("Starting multistep agent for message: %s", message[:100])
 
 
306
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
307
  call_agent, call_agent_level, message)
308
  conversation = self.initialize_conversation(message)
 
309
  outputs = []
310
  last_outputs = []
311
  next_round = True
 
312
  current_round = 0
313
  token_overflow = False
314
  enable_summary = False
 
316
 
317
  if self.enable_checker:
318
  checker = ReasoningTraceChecker(message, conversation)
 
 
319
 
320
+ while next_round and current_round < max_round:
321
+ current_round += 1
322
+ if last_outputs:
323
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
324
+ last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
325
+ message_for_call_agent=message, call_agent=call_agent,
326
+ call_agent_level=call_agent_level, temperature=temperature)
327
+
328
+ if special_tool_call == 'Finish':
329
+ next_round = False
330
+ conversation.extend(function_call_messages)
331
+ content = function_call_messages[0]['content']
332
+ return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."
333
+
334
+ if (self.enable_summary or token_overflow) and not call_agent:
335
+ enable_summary = True
336
+ last_status = self.function_result_summary(
337
+ conversation, status=last_status, enable_summary=enable_summary)
338
+
339
+ if function_call_messages:
340
+ conversation.extend(function_call_messages)
341
+ outputs.append(tool_result_format(function_call_messages))
 
 
342
  else:
343
+ next_round = False
344
+ return ''.join(last_outputs).replace("</s>", "")
345
+
346
+ if self.enable_checker:
347
+ good_status, wrong_info = checker.check_conversation()
348
+ if not good_status:
349
+ logger.warning("Checker error: %s", wrong_info)
350
+ break
351
+
352
+ last_outputs = []
353
+ last_outputs_str, token_overflow = self.llm_infer(
354
+ messages=conversation, temperature=temperature, tools=picked_tools_prompt,
355
+ max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
356
+ if last_outputs_str is None:
357
+ if self.force_finish:
358
+ return self.get_answer_based_on_unfinished_reasoning(
359
+ conversation, temperature, max_new_tokens, max_token)
360
+ return "❌ Token limit exceeded."
361
+ last_outputs.append(last_outputs_str)
362
 
363
+ if current_round >= max_round:
364
+ logger.warning("Max rounds exceeded")
365
+ if self.force_finish:
366
+ return self.get_answer_based_on_unfinished_reasoning(
367
+ conversation, temperature, max_new_tokens, max_token)
368
+ return None
369
 
370
  def build_logits_processor(self, messages, llm):
 
371
  tokenizer = llm.get_tokenizer()
372
  if self.avoid_repeat and len(messages) > 2:
373
+ assistant_messages = [
374
+ m['content'] for m in messages[-3:] if m['role'] == 'assistant'
375
+ ][:2]
376
+ forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
 
 
 
 
377
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
378
+ return None
 
379
 
380
+ def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
381
+ max_new_tokens=512, max_token=2048, skip_special_tokens=True,
 
382
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
 
383
  if model is None:
384
  model = self.model
385
 
 
387
  sampling_params = SamplingParams(
388
  temperature=temperature,
389
  max_tokens=max_new_tokens,
 
390
  seed=seed if seed is not None else self.seed,
391
+ logits_processors=logits_processor
392
  )
393
 
394
+ prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
395
+ if output_begin_string:
 
396
  prompt += output_begin_string
397
 
398
+ if check_token_status and max_token:
399
+ num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
400
+ if num_input_tokens > max_token:
401
+ torch.cuda.empty_cache()
402
+ gc.collect()
403
+ logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
404
+ return None, True
405
+ logger.debug("Input tokens: %d", num_input_tokens)
 
 
406
 
407
+ output = model.generate(prompt, sampling_params=sampling_params)
408
+ output = output[0].outputs[0].text
409
+ logger.debug("Inference output: %s", output[:100])
410
+ torch.cuda.empty_cache() # Clear CUDA cache
411
+ if check_token_status:
412
+ return output, False
413
  return output
414
 
415
+ def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
416
+ logger.debug("Starting self agent")
417
+ conversation = self.set_system_prompt([], self.self_prompt)
 
 
 
 
 
418
  conversation.append({"role": "user", "content": message})
419
+ return self.llm_infer(messages=conversation, temperature=temperature,
 
 
420
  max_new_tokens=max_new_tokens, max_token=max_token)
421
 
422
+ def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
423
+ logger.debug("Starting chat agent")
424
+ conversation = self.set_system_prompt([], self.chat_prompt)
 
 
 
 
 
425
  conversation.append({"role": "user", "content": message})
426
+ return self.llm_infer(messages=conversation, temperature=temperature,
 
 
427
  max_new_tokens=max_new_tokens, max_token=max_token)
428
 
429
+ def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
430
+ logger.debug("Starting format agent")
 
 
 
 
 
431
  if '[FinalAnswer]' in answer:
432
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
433
  elif "\n\n" in answer:
434
  possible_final_answer = answer.split("\n\n")[-1]
435
  else:
436
  possible_final_answer = answer.strip()
437
+
438
+ if len(possible_final_answer) >= 1 and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
439
+ return possible_final_answer[0]
440
+ elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
441
+ return possible_final_answer[0]
442
+
443
+ conversation = self.set_system_prompt(
444
+ [], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
445
+ conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
446
+ return self.llm_infer(messages=conversation, temperature=temperature,
 
 
447
  max_new_tokens=max_new_tokens, max_token=max_token)
448
 
449
+ def run_summary_agent(self, thought_calls: str, function_response: str,
450
+ temperature: float, max_new_tokens: int, max_token: int):
451
+ logger.debug("Starting summary agent")
452
+ prompt = f"""Thought and function calls: {thought_calls}
453
+ Function responses: {function_response}
454
+ Summarize the function responses in one sentence with all necessary information."""
455
+ conversation = [{"role": "user", "content": prompt}]
456
+ output = self.llm_infer(messages=conversation, temperature=temperature,
 
 
457
  max_new_tokens=max_new_tokens, max_token=max_token)
 
458
  if '[' in output:
459
  output = output.split('[')[0]
460
  return output
461
 
462
  def function_result_summary(self, input_list, status, enable_summary):
 
 
463
  if 'tool_call_step' not in status:
464
  status['tool_call_step'] = 0
465
+ if 'step' not in status:
466
+ status['step'] = 0
467
+ status['step'] += 1
468
 
469
  for idx in range(len(input_list)):
470
+ pos_id = len(input_list) - idx - 1
471
+ if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
472
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
473
+ status['tool_call_step'] += 1
 
474
  break
475
 
 
 
 
 
 
476
  if not enable_summary:
477
  return status
478
 
479
  if 'summarized_index' not in status:
480
  status['summarized_index'] = 0
 
481
  if 'summarized_step' not in status:
482
  status['summarized_step'] = 0
 
483
  if 'previous_length' not in status:
484
  status['previous_length'] = 0
 
485
  if 'history' not in status:
486
  status['history'] = []
487
 
488
+ status['history'].append(
489
+ self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
 
 
 
 
490
 
491
+ idx = status['summarized_index']
492
+ function_response = ''
493
+ this_thought_calls = None
494
  while idx < len(input_list):
495
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
496
+ (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
497
  if input_list[idx]['role'] == 'assistant':
498
  if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
499
  this_thought_calls = None
500
  else:
501
+ if function_response:
 
502
  status['summarized_step'] += 1
503
  result_summary = self.run_summary_agent(
504
+ thought_calls=this_thought_calls, function_response=function_response,
505
+ temperature=0.1, max_new_tokens=512, max_token=2048)
 
 
 
 
 
506
  input_list.insert(
507
+ last_call_idx + 1, {'role': 'tool', 'content': result_summary})
508
  status['summarized_index'] = last_call_idx + 2
509
  idx += 1
 
510
  last_call_idx = idx
511
+ this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
 
512
  function_response = ''
513
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls:
 
514
  function_response += input_list[idx]['content']
515
  del input_list[idx]
516
  idx -= 1
 
517
  else:
518
  break
519
  idx += 1
520
 
521
+ if function_response:
522
  status['summarized_step'] += 1
523
  result_summary = self.run_summary_agent(
524
+ thought_calls=this_thought_calls, function_response=function_response,
525
+ temperature=0.1, max_new_tokens=512, max_token=2048)
 
 
 
 
 
526
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
527
  for tool_call in tool_calls:
528
  del tool_call['call_id']
529
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
530
  input_list.insert(
531
+ last_call_idx + 1, {'role': 'tool', 'content': result_summary})
532
  status['summarized_index'] = last_call_idx + 2
533
 
534
  return status
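
The status dictionary is plain state threaded across rounds by run_gradio_chat below; a sketch of how it evolves (the values are illustrative, not taken from a real run):

last_status = {}                               # the first round starts from an empty status
last_status = agent.function_result_summary(
    conversation, status=last_status, enable_summary=True)
# After a few rounds last_status looks roughly like:
#   {'tool_call_step': 1, 'step': 3, 'summarized_index': 4,
#    'summarized_step': 2, 'previous_length': 0, 'history': [False, True]}
# Older tool responses are deleted from the conversation and replaced by a single
# one-sentence 'tool' message produced by run_summary_agent.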

     def update_parameters(self, **kwargs):
+        updated_attributes = {}
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
+                updated_attributes[key] = value
+        logger.debug("Updated parameters: %s", updated_attributes)
         return updated_attributes
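
Only attributes that already exist on the agent are touched; unknown keyword arguments are ignored. A minimal sketch, again assuming an initialized instance named `agent`:

changed = agent.update_parameters(enable_summary=True, step_rag_num=8, not_a_real_attr=123)
# changed == {'enable_summary': True, 'step_rag_num': 8}
# 'not_a_real_attr' is dropped because hasattr(self, key) is False for it.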

+    def run_gradio_chat(self, message: str, history: list, temperature: float,
+                        max_new_tokens: int, max_token: int, call_agent: bool,
+                        conversation: gr.State, max_round: int = 10, seed: int = None,
+                        call_agent_level: int = 0, sub_agent_task: str = None,
+                        uploaded_files: list = None):
+        logger.debug("Chat started, message: %s", message[:100])
         if not message or len(message.strip()) < 5:
             yield "Please provide a valid message or upload files to analyze."
+            return

         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
+            return

         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+            call_agent, call_agent_level, message)
         conversation = self.initialize_conversation(
+            message, conversation, history)
         history = []

         next_round = True
         current_round = 0
         enable_summary = False
         last_status = {}
         token_overflow = False
+        last_outputs = []

         if self.enable_checker:
+            checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))

         try:
             while next_round and current_round < max_round:
                 current_round += 1

                 if last_outputs:
                     function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
+                        last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
+                        message_for_call_agent=message, call_agent=call_agent,
+                        call_agent_level=call_agent_level, temperature=temperature)
                     history.extend(current_gradio_history)

+                    if special_tool_call == 'Finish':
                         yield history
                         next_round = False
                         conversation.extend(function_call_messages)
                         return function_call_messages[0]['content']

+                    if special_tool_call in ['RequireClarification', 'DirectResponse']:
                         last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
                         history.append(ChatMessage(role="assistant", content=last_msg.content))
                         yield history
                         next_round = False
                         return last_msg.content

                     if (self.enable_summary or token_overflow) and not call_agent:
                         enable_summary = True
                     last_status = self.function_result_summary(
+                        conversation, status=last_status, enable_summary=enable_summary)

                     if function_call_messages:
                         conversation.extend(function_call_messages)
                         yield history
                     else:
                         next_round = False
                         return ''.join(last_outputs).replace("</s>", "")

+                last_outputs = []  # keep only this round's output for the next round's tool-call pass

                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
                     if not good_status:
+                        logger.warning("Checker error: %s", wrong_info)
                         break

                 last_outputs_str, token_overflow = self.llm_infer(
+                    messages=conversation, temperature=temperature, tools=picked_tools_prompt,
+                    max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)

                 if last_outputs_str is None:
                     if self.force_finish:
                         last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                             conversation, temperature, max_new_tokens, max_token)
                         history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
                         yield history
                         return last_outputs_str
+                    error_msg = "Token limit exceeded."
+                    history.append(ChatMessage(role="assistant", content=error_msg))
+                    yield history
+                    return error_msg

                 last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]

                 for msg in history:
+                    if msg.metadata:
                         msg.metadata['status'] = 'done'

                 if '[FinalAnswer]' in last_thought:
                     parts = last_thought.split('[FinalAnswer]', 1)
+                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
                     history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                     yield history
                     history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                     yield history
                     next_round = False

                 last_outputs.append(last_outputs_str)

+            if next_round and self.force_finish:
+                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                    conversation, temperature, max_new_tokens, max_token)
+                parts = last_outputs_str.split('[FinalAnswer]', 1)
+                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
+                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
+                yield history
+                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
+                yield history

         except Exception as e:
+            logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
+            error_msg = f"Error: {e}"
             history.append(ChatMessage(role="assistant", content=error_msg))
             yield history
             if self.force_finish:
                 last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                     conversation, temperature, max_new_tokens, max_token)
+                parts = last_outputs_str.split('[FinalAnswer]', 1)
+                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
+                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
+                yield history
+                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
+                yield history
             return error_msg
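
run_gradio_chat is a generator: it yields growing ChatMessage histories (and plain strings for input-validation warnings) and returns a final string. Outside Gradio it can be drained directly; a minimal sketch with assumed argument values, where plain lists stand in for the Gradio state (`agent` is an assumed, already-initialized TxAgent instance):

conversation, history = [], []
for update in agent.run_gradio_chat(
        message="Is it safe to combine ibuprofen and lisinopril?",
        history=history, temperature=0.3, max_new_tokens=1024, max_token=8192,
        call_agent=False, conversation=conversation, max_round=10, seed=42):
    if isinstance(update, list) and update:    # a ChatMessage history snapshot
        print(update[-1].content)
    else:                                      # e.g. the short-message warning string
        print(update)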