Ali2206 committed on
Commit 921328d · verified · 1 parent: 6741b3e

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +330 -272
src/txagent/txagent.py CHANGED
@@ -24,17 +24,17 @@ class TxAgent:
24
  rag_model_name,
25
  tool_files_dict=None,
26
  enable_finish=True,
27
- enable_rag=True,
28
  enable_summary=False,
29
- init_rag_num=2, # Reduced for faster initial tool selection
30
- step_rag_num=4, # Reduced for fewer RAG calls
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
34
  force_finish=True,
35
  avoid_repeat=True,
36
  seed=None,
37
- enable_checker=False, # Disabled by default for speed
38
  enable_chat=False,
39
  additional_default_tools=None):
40
  self.model_name = model_name
@@ -45,9 +45,9 @@ class TxAgent:
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
- self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
49
- self.self_prompt = "Follow instructions precisely."
50
- self.chat_prompt = "You are a helpful assistant for clinical queries."
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
@@ -61,28 +61,23 @@ class TxAgent:
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
- logger.debug("TxAgent initialized with parameters: %s", self.__dict__)
65
 
66
  def init_model(self):
67
  self.load_models()
68
  self.load_tooluniverse()
69
- self.load_tool_desc_embedding()
70
-
71
- def print_self_values(self):
72
- for attr, value in self.__dict__.items():
73
- logger.debug("%s: %s", attr, value)
74
 
75
  def load_models(self, model_name=None):
76
- if model_name is not None and model_name == self.model_name:
77
- return f"The model {model_name} is already loaded."
78
- if model_name:
79
  self.model_name = model_name
80
 
81
- self.model = LLM(model=self.model_name, dtype="float16") # Enable FP16
82
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
83
  self.tokenizer = self.model.get_tokenizer()
84
  logger.info("Model %s loaded successfully", self.model_name)
85
- return f"Model {self.model_name} loaded successfully."
86
 
87
  def load_tooluniverse(self):
88
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
@@ -93,7 +88,12 @@ class TxAgent:
93
  logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
94
 
95
  def load_tool_desc_embedding(self):
96
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
97
  logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
@@ -107,39 +107,43 @@ class TxAgent:
107
  call_agent_level += 1
108
  if call_agent_level >= 2:
109
  call_agent = False
110
-
111
- if not call_agent and self.enable_rag:
112
- picked_tools_prompt += self.tool_RAG(
113
- message=message, rag_num=self.init_rag_num)
114
  return picked_tools_prompt, call_agent_level
115
 
116
  def initialize_conversation(self, message, conversation=None, history=None):
117
  if conversation is None:
118
  conversation = []
119
 
120
- conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
 
121
  if history:
122
- conversation.extend(
123
- {"role": h['role'], "content": h['content']}
124
- for h in history if h['role'] in ['user', 'assistant']
125
- )
 
126
  conversation.append({"role": "user", "content": message})
127
  logger.debug("Conversation initialized with %d messages", len(conversation))
128
  return conversation
129
 
130
- def tool_RAG(self, message=None, picked_tool_names=None,
131
- existing_tools_prompt=None, rag_num=4, return_call_result=False):
132
- extra_factor = 10 # Reduced from 30 for efficiency
133
  if picked_tool_names is None:
134
- picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
135
 
136
- picked_tool_names = [
137
- tool for tool in picked_tool_names
138
- if tool not in self.special_tools_name
139
- ][:rag_num]
140
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
141
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
142
- logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
143
  if return_call_result:
144
  return picked_tools_prompt, picked_tool_names
145
  return picked_tools_prompt
@@ -151,15 +155,6 @@ class TxAgent:
151
  if call_agent:
152
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
153
  logger.debug("CallAgent tool added")
154
- elif self.enable_rag:
155
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
156
- logger.debug("Tool_RAG tool added")
157
- if self.additional_default_tools:
158
- for tool_name in self.additional_default_tools:
159
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
160
- if tool_prompt:
161
- tools.append(tool_prompt)
162
- logger.debug("%s tool added", tool_name)
163
  return tools
164
 
165
  def add_finish_tools(self, tools):
@@ -174,43 +169,51 @@ class TxAgent:
174
  conversation[0] = {"role": "system", "content": sys_prompt}
175
  return conversation
176
 
177
- def run_function_call(self, fcall_str, return_message=False,
178
- existing_tools_prompt=None, message_for_call_agent=None,
179
- call_agent=False, call_agent_level=None, temperature=None):
180
  function_call_json, message = self.tooluniverse.extract_function_call_json(
181
  fcall_str, return_message=return_message, verbose=False)
182
  call_results = []
183
  special_tool_call = ''
184
  if function_call_json:
185
- for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
186
- logger.debug("Tool Call: %s", func)
187
- if func["name"] == 'Finish':
188
- special_tool_call = 'Finish'
189
- break
190
- elif func["name"] == 'Tool_RAG':
191
- new_tools_prompt, call_result = self.tool_RAG(
192
- message=message, existing_tools_prompt=existing_tools_prompt,
193
- rag_num=self.step_rag_num, return_call_result=True)
194
- existing_tools_prompt += new_tools_prompt
195
- elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
196
- solution_plan = func['arguments']['solution']
197
- full_message = (
198
- message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
199
- )
200
- call_result = self.run_multistep_agent(
201
- full_message, temperature=temperature, max_new_tokens=512,
202
- max_token=2048, call_agent=False, call_agent_level=call_agent_level)
203
- call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
204
- else:
205
- call_result = self.tooluniverse.run_one_function(func)
206
-
207
- call_id = self.tooluniverse.call_id_gen()
208
- func["call_id"] = call_id
209
- logger.debug("Tool Call Result: %s", call_result)
210
- call_results.append({
211
- "role": "tool",
212
- "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
213
- })
214
  else:
215
  call_results.append({
216
  "role": "tool",
@@ -219,63 +222,68 @@ class TxAgent:
219
 
220
  revised_messages = [{
221
  "role": "assistant",
222
- "content": message.strip() if message else "",
223
  "tool_calls": json.dumps(function_call_json)
224
  }] + call_results
225
  return revised_messages, existing_tools_prompt, special_tool_call
226
 
227
- def run_function_call_stream(self, fcall_str, return_message=False,
228
- existing_tools_prompt=None, message_for_call_agent=None,
229
- call_agent=False, call_agent_level=None, temperature=None,
230
- return_gradio_history=True):
231
  function_call_json, message = self.tooluniverse.extract_function_call_json(
232
  fcall_str, return_message=return_message, verbose=False)
233
  call_results = []
234
  special_tool_call = ''
235
- gradio_history = [] if return_gradio_history else None
 
236
  if function_call_json:
237
- for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
238
- if func["name"] == 'Finish':
239
- special_tool_call = 'Finish'
240
- break
241
- elif func["name"] == 'Tool_RAG':
242
- new_tools_prompt, call_result = self.tool_RAG(
243
- message=message, existing_tools_prompt=existing_tools_prompt,
244
- rag_num=self.step_rag_num, return_call_result=True)
245
- existing_tools_prompt += new_tools_prompt
246
- elif func["name"] == 'DirectResponse':
247
- call_result = func['arguments']['response']
248
- special_tool_call = 'DirectResponse'
249
- elif func["name"] == 'RequireClarification':
250
- call_result = func['arguments']['unclear_question']
251
- special_tool_call = 'RequireClarification'
252
- elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
253
- solution_plan = func['arguments']['solution']
254
- full_message = (
255
- message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
256
- )
257
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
258
- call_result = yield from self.run_gradio_chat(
259
- full_message, history=[], temperature=temperature,
260
- max_new_tokens=512, max_token=2048, call_agent=False,
261
- call_agent_level=call_agent_level, conversation=None,
262
- sub_agent_task=sub_agent_task)
263
- call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
264
- else:
265
- call_result = self.tooluniverse.run_one_function(func)
266
-
267
- call_id = self.tooluniverse.call_id_gen()
268
- func["call_id"] = call_id
269
- call_results.append({
270
- "role": "tool",
271
- "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
272
- })
273
- if return_gradio_history and func["name"] != 'Finish':
274
- title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
275
- gradio_history.append(ChatMessage(
276
- role="assistant", content=str(call_result),
277
- metadata={"title": title, "log": str(func['arguments'])}
278
- ))
279
  else:
280
  call_results.append({
281
  "role": "tool",
@@ -284,25 +292,37 @@ class TxAgent:
284
 
285
  revised_messages = [{
286
  "role": "assistant",
287
- "content": message.strip() if message else "",
288
  "tool_calls": json.dumps(function_call_json)
289
  }] + call_results
290
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
 
 
291
 
292
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
293
  if conversation[-1]['role'] == 'assistant':
294
  conversation.append(
295
- {'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
296
  finish_tools_prompt = self.add_finish_tools([])
297
- output = self.llm_infer(
298
- messages=conversation, temperature=temperature, tools=finish_tools_prompt,
299
- output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
300
- logger.debug("Unfinished reasoning output: %s", output)
301
- return output
302
-
303
- def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
304
- max_token: int, max_round: int = 10, call_agent=False, call_agent_level=0):
305
- logger.debug("Starting multistep agent for message: %s", message[:100])
306
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
307
  call_agent, call_agent_level, message)
308
  conversation = self.initialize_conversation(message)
@@ -314,53 +334,57 @@ class TxAgent:
314
  enable_summary = False
315
  last_status = {}
316
 
317
- if self.enable_checker:
318
- checker = ReasoningTraceChecker(message, conversation)
319
-
320
  while next_round and current_round < max_round:
321
  current_round += 1
322
- if last_outputs:
323
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
324
- last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
325
- message_for_call_agent=message, call_agent=call_agent,
326
- call_agent_level=call_agent_level, temperature=temperature)
 
 
 
327
 
328
  if special_tool_call == 'Finish':
329
  next_round = False
330
  conversation.extend(function_call_messages)
331
  content = function_call_messages[0]['content']
332
- return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."
 
 
333
 
334
  if (self.enable_summary or token_overflow) and not call_agent:
335
  enable_summary = True
336
  last_status = self.function_result_summary(
337
- conversation, status=last_status, enable_summary=enable_summary)
338
 
339
  if function_call_messages:
340
  conversation.extend(function_call_messages)
341
  outputs.append(tool_result_format(function_call_messages))
342
  else:
343
  next_round = False
 
344
  return ''.join(last_outputs).replace("</s>", "")
345
 
346
- if self.enable_checker:
347
- good_status, wrong_info = checker.check_conversation()
348
- if not good_status:
349
- logger.warning("Checker error: %s", wrong_info)
350
- break
351
-
352
  last_outputs = []
 
353
  last_outputs_str, token_overflow = self.llm_infer(
354
- messages=conversation, temperature=temperature, tools=picked_tools_prompt,
355
- max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
356
  if last_outputs_str is None:
 
357
  if self.force_finish:
358
  return self.get_answer_based_on_unfinished_reasoning(
359
  conversation, temperature, max_new_tokens, max_token)
360
  return "❌ Token limit exceeded."
361
  last_outputs.append(last_outputs_str)
362
 
363
- if current_round >= max_round:
364
  logger.warning("Max rounds exceeded")
365
  if self.force_finish:
366
  return self.get_answer_based_on_unfinished_reasoning(
@@ -370,16 +394,16 @@ class TxAgent:
370
  def build_logits_processor(self, messages, llm):
371
  tokenizer = llm.get_tokenizer()
372
  if self.avoid_repeat and len(messages) > 2:
373
- assistant_messages = [
374
- m['content'] for m in messages[-3:] if m['role'] == 'assistant'
375
- ][:2]
376
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
377
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
378
  return None
379
 
380
- def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
381
- max_new_tokens=512, max_token=2048, skip_special_tokens=True,
382
- model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
 
 
383
  if model is None:
384
  model = self.model
385
 
@@ -388,73 +412,108 @@ class TxAgent:
388
  temperature=temperature,
389
  max_tokens=max_new_tokens,
390
  seed=seed if seed is not None else self.seed,
391
- logits_processors=logits_processor
392
  )
393
 
394
- prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
395
- if output_begin_string:
 
396
  prompt += output_begin_string
397
 
398
- if check_token_status and max_token:
 
399
  num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
400
  if num_input_tokens > max_token:
401
  torch.cuda.empty_cache()
402
  gc.collect()
403
  logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
404
  return None, True
405
- logger.debug("Input tokens: %d", num_input_tokens)
406
 
407
  output = model.generate(prompt, sampling_params=sampling_params)
408
  output = output[0].outputs[0].text
409
  logger.debug("Inference output: %s", output[:100])
410
- torch.cuda.empty_cache() # Clear CUDA cache
411
- if check_token_status:
412
- return output, False
 
413
  return output
414
 
415
- def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
416
- logger.debug("Starting self agent")
417
  conversation = self.set_system_prompt([], self.self_prompt)
418
  conversation.append({"role": "user", "content": message})
419
- return self.llm_infer(messages=conversation, temperature=temperature,
420
- max_new_tokens=max_new_tokens, max_token=max_token)
421
-
422
- def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
423
- logger.debug("Starting chat agent")
424
  conversation = self.set_system_prompt([], self.chat_prompt)
425
  conversation.append({"role": "user", "content": message})
426
- return self.llm_infer(messages=conversation, temperature=temperature,
427
- max_new_tokens=max_new_tokens, max_token=max_token)
428
-
429
- def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
430
- logger.debug("Starting format agent")
431
  if '[FinalAnswer]' in answer:
432
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
433
  elif "\n\n" in answer:
434
  possible_final_answer = answer.split("\n\n")[-1]
435
  else:
436
  possible_final_answer = answer.strip()
437
-
438
- if len(possible_final_answer) >= 1 and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
439
- return possible_final_answer[0]
440
  elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
441
  return possible_final_answer[0]
442
 
443
  conversation = self.set_system_prompt(
444
- [], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
445
- conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
446
- return self.llm_infer(messages=conversation, temperature=temperature,
447
- max_new_tokens=max_new_tokens, max_token=max_token)
448
-
449
- def run_summary_agent(self, thought_calls: str, function_response: str,
450
- temperature: float, max_new_tokens: int, max_token: int):
451
- logger.debug("Starting summary agent")
452
- prompt = f"""Thought and function calls: {thought_calls}
453
- Function responses: {function_response}
454
- Summarize the function responses in one sentence with all necessary information."""
455
  conversation = [{"role": "user", "content": prompt}]
456
- output = self.llm_infer(messages=conversation, temperature=temperature,
457
- max_new_tokens=max_new_tokens, max_token=max_token)
458
  if '[' in output:
459
  output = output.split('[')[0]
460
  return output
@@ -462,55 +521,43 @@ Summarize the function responses in one sentence with all necessary information.
462
  def function_result_summary(self, input_list, status, enable_summary):
463
  if 'tool_call_step' not in status:
464
  status['tool_call_step'] = 0
465
- if 'step' not in status:
466
- status['step'] = 0
467
- status['step'] += 1
468
-
469
  for idx in range(len(input_list)):
470
  pos_id = len(input_list) - idx - 1
471
  if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
472
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
473
- status['tool_call_step'] += 1
474
  break
475
 
 
476
  if not enable_summary:
477
  return status
478
 
479
- if 'summarized_index' not in status:
480
- status['summarized_index'] = 0
481
- if 'summarized_step' not in status:
482
- status['summarized_step'] = 0
483
- if 'previous_length' not in status:
484
- status['previous_length'] = 0
485
- if 'history' not in status:
486
- status['history'] = []
487
 
488
- status['history'].append(
489
- self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
490
-
491
- idx = status['summarized_index']
492
  function_response = ''
 
493
  this_thought_calls = None
 
494
  while idx < len(input_list):
495
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
496
  (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
497
  if input_list[idx]['role'] == 'assistant':
498
- if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
499
- this_thought_calls = None
500
- else:
501
- if function_response:
502
- status['summarized_step'] += 1
503
- result_summary = self.run_summary_agent(
504
- thought_calls=this_thought_calls, function_response=function_response,
505
- temperature=0.1, max_new_tokens=512, max_token=2048)
506
- input_list.insert(
507
- last_call_idx + 1, {'role': 'tool', 'content': result_summary})
508
- status['summarized_index'] = last_call_idx + 2
509
- idx += 1
510
- last_call_idx = idx
511
- this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
512
- function_response = ''
513
- elif input_list[idx]['role'] == 'tool' and this_thought_calls:
514
  function_response += input_list[idx]['content']
515
  del input_list[idx]
516
  idx -= 1
@@ -521,14 +568,16 @@ Summarize the function responses in one sentence with all necessary information.
521
  if function_response:
522
  status['summarized_step'] += 1
523
  result_summary = self.run_summary_agent(
524
- thought_calls=this_thought_calls, function_response=function_response,
525
- temperature=0.1, max_new_tokens=512, max_token=2048)
526
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
527
  for tool_call in tool_calls:
528
  del tool_call['call_id']
529
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
530
- input_list.insert(
531
- last_call_idx + 1, {'role': 'tool', 'content': result_summary})
532
  status['summarized_index'] = last_call_idx + 2
533
 
534
  return status
@@ -539,22 +588,26 @@ Summarize the function responses in one sentence with all necessary information.
539
  if hasattr(self, key):
540
  setattr(self, key, value)
541
  updated_attributes[key] = value
542
- logger.debug("Updated parameters: %s", updated_attributes)
543
  return updated_attributes
544
 
545
- def run_gradio_chat(self, message: str, history: list, temperature: float,
546
- max_new_tokens: int, max_token: int, call_agent: bool,
547
- conversation: gr.State, max_round: int = 10, seed: int = None,
548
- call_agent_level: int = 0, sub_agent_task: str = None,
549
  uploaded_files: list = None):
550
- logger.debug("Chat started, message: %s", message[:100])
551
  if not message or len(message.strip()) < 5:
552
  yield "Please provide a valid message or upload files to analyze."
553
  return
554
 
555
- if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
556
- return
557
-
558
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
559
  call_agent, call_agent_level, message)
560
  conversation = self.initialize_conversation(
@@ -567,18 +620,17 @@ Summarize the function responses in one sentence with all necessary information.
567
  last_status = {}
568
  token_overflow = False
569
 
570
- if self.enable_checker:
571
- checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
572
-
573
  try:
574
  while next_round and current_round < max_round:
575
  current_round += 1
576
- last_outputs = []
577
  if last_outputs:
578
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
579
- last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
580
- message_for_call_agent=message, call_agent=call_agent,
581
- call_agent_level=call_agent_level, temperature=temperature)
582
  history.extend(current_gradio_history)
583
 
584
  if special_tool_call == 'Finish':
@@ -587,7 +639,7 @@ Summarize the function responses in one sentence with all necessary information.
587
  conversation.extend(function_call_messages)
588
  return function_call_messages[0]['content']
589
 
590
- if special_tool_call in ['RequireClarification', 'DirectResponse']:
591
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
592
  history.append(ChatMessage(role="assistant", content=last_msg.content))
593
  yield history
@@ -604,19 +656,22 @@ Summarize the function responses in one sentence with all necessary information.
604
  yield history
605
  else:
606
  next_round = False
 
607
  return ''.join(last_outputs).replace("</s>", "")
608
 
609
- if self.enable_checker:
610
- good_status, wrong_info = checker.check_conversation()
611
- if not good_status:
612
- logger.warning("Checker error: %s", wrong_info)
613
- break
614
-
615
  last_outputs_str, token_overflow = self.llm_infer(
616
- messages=conversation, temperature=temperature, tools=picked_tools_prompt,
617
- max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
618
 
619
  if last_outputs_str is None:
 
620
  if self.force_finish:
621
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
622
  conversation, temperature, max_new_tokens, max_token)
@@ -630,7 +685,7 @@ Summarize the function responses in one sentence with all necessary information.
630
 
631
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
632
  for msg in history:
633
- if msg.metadata:
634
  msg.metadata['status'] = 'done'
635
 
636
  if '[FinalAnswer]' in last_thought:
@@ -646,15 +701,18 @@ Summarize the function responses in one sentence with all necessary information.
646
 
647
  last_outputs.append(last_outputs_str)
648
 
649
- if next_round and self.force_finish:
650
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
651
- conversation, temperature, max_new_tokens, max_token)
652
- parts = last_outputs_str.split('[FinalAnswer]', 1)
653
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
654
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
655
- yield history
656
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
657
- yield history
658
 
659
  except Exception as e:
660
  logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
 
24
  rag_model_name,
25
  tool_files_dict=None,
26
  enable_finish=True,
27
+ enable_rag=False,
28
  enable_summary=False,
29
+ init_rag_num=0,
30
+ step_rag_num=0,
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
34
  force_finish=True,
35
  avoid_repeat=True,
36
  seed=None,
37
+ enable_checker=False,
38
  enable_chat=False,
39
  additional_default_tools=None):
40
  self.model_name = model_name
 
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
+ self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
49
+ self.self_prompt = "Strictly follow the instruction."
50
+ self.chat_prompt = "You are a helpful assistant for user chat."
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
 
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
+ logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
65
 
66
  def init_model(self):
67
  self.load_models()
68
  self.load_tooluniverse()
69
 
70
  def load_models(self, model_name=None):
71
+ if model_name is not None:
72
+ if model_name == self.model_name:
73
+ return f"The model {model_name} is already loaded."
74
  self.model_name = model_name
75
 
76
+ self.model = LLM(model=self.model_name, dtype="float16", max_model_len=512, gpu_memory_utilization=0.8)
77
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
78
  self.tokenizer = self.model.get_tokenizer()
79
  logger.info("Model %s loaded successfully", self.model_name)
80
+ return f"Model {model_name} loaded successfully."
81
 
82
  def load_tooluniverse(self):
83
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
 
88
  logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
89
 
90
  def load_tool_desc_embedding(self):
91
+ cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
92
+ if os.path.exists(cache_path):
93
+ self.rag_model.load_cached_embeddings(cache_path)
94
+ else:
95
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
96
+ self.rag_model.save_embeddings(cache_path)
97
  logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
 
107
  call_agent_level += 1
108
  if call_agent_level >= 2:
109
  call_agent = False
110
  return picked_tools_prompt, call_agent_level
111
 
112
  def initialize_conversation(self, message, conversation=None, history=None):
113
  if conversation is None:
114
  conversation = []
115
 
116
+ conversation = self.set_system_prompt(
117
+ conversation, self.prompt_multi_step)
118
  if history:
119
+ for i in range(len(history)):
120
+ if history[i]['role'] == 'user':
121
+ conversation.append({"role": "user", "content": history[i]['content']})
122
+ elif history[i]['role'] == 'assistant':
123
+ conversation.append({"role": "assistant", "content": history[i]['content']})
124
  conversation.append({"role": "user", "content": message})
125
  logger.debug("Conversation initialized with %d messages", len(conversation))
126
  return conversation
127
 
128
+ def tool_RAG(self, message=None,
129
+ picked_tool_names=None,
130
+ existing_tools_prompt=[],
131
+ rag_num=0,
132
+ return_call_result=False):
133
+ if not self.enable_rag:
134
+ return []
135
+ extra_factor = 10
136
  if picked_tool_names is None:
137
+ assert picked_tool_names is not None or message is not None
138
+ picked_tool_names = self.rag_infer(
139
+ message, top_k=rag_num * extra_factor)
140
+
141
+ picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
142
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
143
 
 
 
 
 
144
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
145
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
146
+ logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
147
  if return_call_result:
148
  return picked_tools_prompt, picked_tool_names
149
  return picked_tools_prompt
 
155
  if call_agent:
156
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
157
  logger.debug("CallAgent tool added")
158
  return tools
159
 
160
  def add_finish_tools(self, tools):
 
169
  conversation[0] = {"role": "system", "content": sys_prompt}
170
  return conversation
171
 
172
+ def run_function_call(self, fcall_str,
173
+ return_message=False,
174
+ existing_tools_prompt=None,
175
+ message_for_call_agent=None,
176
+ call_agent=False,
177
+ call_agent_level=None,
178
+ temperature=None):
179
  function_call_json, message = self.tooluniverse.extract_function_call_json(
180
  fcall_str, return_message=return_message, verbose=False)
181
  call_results = []
182
  special_tool_call = ''
183
  if function_call_json:
184
+ if isinstance(function_call_json, list):
185
+ for i in range(len(function_call_json)):
186
+ logger.info("Tool Call: %s", function_call_json[i])
187
+ if function_call_json[i]["name"] == 'Finish':
188
+ special_tool_call = 'Finish'
189
+ break
190
+ elif function_call_json[i]["name"] == 'CallAgent':
191
+ if call_agent_level < 2 and call_agent:
192
+ solution_plan = function_call_json[i]['arguments']['solution']
193
+ full_message = (
194
+ message_for_call_agent +
195
+ "\nYou must follow the following plan to answer the question: " +
196
+ str(solution_plan)
197
+ )
198
+ call_result = self.run_multistep_agent(
199
+ full_message, temperature=temperature,
200
+ max_new_tokens=128, max_token=768,
201
+ call_agent=False, call_agent_level=call_agent_level)
202
+ if call_result is None:
203
+ call_result = "⚠️ No content returned from sub-agent."
204
+ else:
205
+ call_result = call_result.split('[FinalAnswer]')[-1].strip()
206
+ else:
207
+ call_result = "Error: CallAgent disabled."
208
+ else:
209
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
210
+ call_id = self.tooluniverse.call_id_gen()
211
+ function_call_json[i]["call_id"] = call_id
212
+ logger.info("Tool Call Result: %s", call_result)
213
+ call_results.append({
214
+ "role": "tool",
215
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
216
+ })
217
  else:
218
  call_results.append({
219
  "role": "tool",
 
222
 
223
  revised_messages = [{
224
  "role": "assistant",
225
+ "content": message.strip(),
226
  "tool_calls": json.dumps(function_call_json)
227
  }] + call_results
228
  return revised_messages, existing_tools_prompt, special_tool_call
229
 
230
+ def run_function_call_stream(self, fcall_str,
231
+ return_message=False,
232
+ existing_tools_prompt=None,
233
+ message_for_call_agent=None,
234
+ call_agent=False,
235
+ call_agent_level=None,
236
+ temperature=None,
237
+ return_gradio_history=True):
238
  function_call_json, message = self.tooluniverse.extract_function_call_json(
239
  fcall_str, return_message=return_message, verbose=False)
240
  call_results = []
241
  special_tool_call = ''
242
+ if return_gradio_history:
243
+ gradio_history = []
244
  if function_call_json:
245
+ if isinstance(function_call_json, list):
246
+ for i in range(len(function_call_json)):
247
+ if function_call_json[i]["name"] == 'Finish':
248
+ special_tool_call = 'Finish'
249
+ break
250
+ elif function_call_json[i]["name"] == 'DirectResponse':
251
+ call_result = function_call_json[i]['arguments']['response']
252
+ special_tool_call = 'DirectResponse'
253
+ elif function_call_json[i]["name"] == 'RequireClarification':
254
+ call_result = function_call_json[i]['arguments']['unclear_question']
255
+ special_tool_call = 'RequireClarification'
256
+ elif function_call_json[i]["name"] == 'CallAgent':
257
+ if call_agent_level < 2 and call_agent:
258
+ solution_plan = function_call_json[i]['arguments']['solution']
259
+ full_message = (
260
+ message_for_call_agent +
261
+ "\nYou must follow the following plan to answer the question: " +
262
+ str(solution_plan)
263
+ )
264
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
265
+ call_result = yield from self.run_gradio_chat(
266
+ full_message, history=[], temperature=temperature,
267
+ max_new_tokens=128, max_token=768,
268
+ call_agent=False, call_agent_level=call_agent_level,
269
+ conversation=None, sub_agent_task=sub_agent_task)
270
+ if call_result is not None and isinstance(call_result, str):
271
+ call_result = call_result.split('[FinalAnswer]')[-1]
272
+ else:
273
+ call_result = "⚠️ No content returned from sub-agent."
274
+ else:
275
+ call_result = "Error: CallAgent disabled."
276
+ else:
277
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
278
+ call_id = self.tooluniverse.call_id_gen()
279
+ function_call_json[i]["call_id"] = call_id
280
+ call_results.append({
281
+ "role": "tool",
282
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
283
+ })
284
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
285
+ metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
286
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
287
  else:
288
  call_results.append({
289
  "role": "tool",
 
292
 
293
  revised_messages = [{
294
  "role": "assistant",
295
+ "content": message.strip(),
296
  "tool_calls": json.dumps(function_call_json)
297
  }] + call_results
298
+ if return_gradio_history:
299
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
300
+ return revised_messages, existing_tools_prompt, special_tool_call
301
 
302
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
303
  if conversation[-1]['role'] == 'assistant':
304
  conversation.append(
305
+ {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
306
  finish_tools_prompt = self.add_finish_tools([])
307
+ last_outputs_str = self.llm_infer(
308
+ messages=conversation,
309
+ temperature=temperature,
310
+ tools=finish_tools_prompt,
311
+ output_begin_string='[FinalAnswer]',
312
+ skip_special_tokens=True,
313
+ max_new_tokens=max_new_tokens,
314
+ max_token=max_token)
315
+ logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
316
+ return last_outputs_str
317
+
318
+ def run_multistep_agent(self, message: str,
319
+ temperature: float,
320
+ max_new_tokens: int,
321
+ max_token: int,
322
+ max_round: int = 5,
323
+ call_agent=False,
324
+ call_agent_level=0):
325
+ logger.info("Starting multistep agent for message: %s", message[:100])
326
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
327
  call_agent, call_agent_level, message)
328
  conversation = self.initialize_conversation(message)
 
334
  enable_summary = False
335
  last_status = {}
336
 
 
 
 
337
  while next_round and current_round < max_round:
338
  current_round += 1
339
+ if len(outputs) > 0:
340
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
341
+ last_outputs, return_message=True,
342
+ existing_tools_prompt=picked_tools_prompt,
343
+ message_for_call_agent=message,
344
+ call_agent=call_agent,
345
+ call_agent_level=call_agent_level,
346
+ temperature=temperature)
347
 
348
  if special_tool_call == 'Finish':
349
  next_round = False
350
  conversation.extend(function_call_messages)
351
  content = function_call_messages[0]['content']
352
+ if content is None:
353
+ return "❌ No content returned after Finish tool call."
354
+ return content.split('[FinalAnswer]')[-1]
355
 
356
  if (self.enable_summary or token_overflow) and not call_agent:
357
  enable_summary = True
358
  last_status = self.function_result_summary(
359
+ conversation, status=last_status, enable_summary=enable_summary)
360
 
361
  if function_call_messages:
362
  conversation.extend(function_call_messages)
363
  outputs.append(tool_result_format(function_call_messages))
364
  else:
365
  next_round = False
366
+ conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
367
  return ''.join(last_outputs).replace("</s>", "")
368
 
 
 
 
 
 
 
369
  last_outputs = []
370
+ outputs.append("### TxAgent:\n")
371
  last_outputs_str, token_overflow = self.llm_infer(
372
+ messages=conversation,
373
+ temperature=temperature,
374
+ tools=picked_tools_prompt,
375
+ skip_special_tokens=False,
376
+ max_new_tokens=max_new_tokens,
377
+ max_token=max_token,
378
+ check_token_status=True)
379
  if last_outputs_str is None:
380
+ logger.warning("Token limit exceeded")
381
  if self.force_finish:
382
  return self.get_answer_based_on_unfinished_reasoning(
383
  conversation, temperature, max_new_tokens, max_token)
384
  return "❌ Token limit exceeded."
385
  last_outputs.append(last_outputs_str)
386
 
387
+ if max_round == current_round:
388
  logger.warning("Max rounds exceeded")
389
  if self.force_finish:
390
  return self.get_answer_based_on_unfinished_reasoning(
 
394
  def build_logits_processor(self, messages, llm):
395
  tokenizer = llm.get_tokenizer()
396
  if self.avoid_repeat and len(messages) > 2:
397
+ assistant_messages = [msg['content'] for msg in messages[-3:] if msg['role'] == 'assistant'][:2]
 
 
398
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
399
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
400
  return None
401
 
402
+ def llm_infer(self, messages, temperature=0.1, tools=None,
403
+ output_begin_string=None, max_new_tokens=128,
404
+ max_token=768, skip_special_tokens=True,
405
+ model=None, tokenizer=None, terminators=None,
406
+ seed=None, check_token_status=False):
407
  if model is None:
408
  model = self.model
409
 
 
412
  temperature=temperature,
413
  max_tokens=max_new_tokens,
414
  seed=seed if seed is not None else self.seed,
 
415
  )
416
 
417
+ prompt = self.chat_template.render(
418
+ messages=messages, tools=tools, add_generation_prompt=True)
419
+ if output_begin_string is not None:
420
  prompt += output_begin_string
421
 
422
+ if check_token_status and max_token is not None:
423
+ token_overflow = False
424
  num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
425
  if num_input_tokens > max_token:
426
  torch.cuda.empty_cache()
427
  gc.collect()
428
  logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
429
  return None, True
 
430
 
431
  output = model.generate(prompt, sampling_params=sampling_params)
432
  output = output[0].outputs[0].text
433
  logger.debug("Inference output: %s", output[:100])
434
+ torch.cuda.empty_cache()
435
+ gc.collect()
436
+ if check_token_status and max_token is not None:
437
+ return output, token_overflow
438
  return output
439
 
440
+ def run_self_agent(self, message: str,
441
+ temperature: float,
442
+ max_new_tokens: int,
443
+ max_token: int):
444
+ logger.info("Starting self agent")
445
  conversation = self.set_system_prompt([], self.self_prompt)
446
  conversation.append({"role": "user", "content": message})
447
+ return self.llm_infer(
448
+ messages=conversation,
449
+ temperature=temperature,
450
+ tools=None,
451
+ max_new_tokens=max_new_tokens,
452
+ max_token=max_token)
453
+
454
+ def run_chat_agent(self, message: str,
455
+ temperature: float,
456
+ max_new_tokens: int,
457
+ max_token: int):
458
+ logger.info("Starting chat agent")
459
  conversation = self.set_system_prompt([], self.chat_prompt)
460
  conversation.append({"role": "user", "content": message})
461
+ return self.llm_infer(
462
+ messages=conversation,
463
+ temperature=temperature,
464
+ tools=None,
465
+ max_new_tokens=max_new_tokens,
466
+ max_token=max_token)
467
+
468
+ def run_format_agent(self, message: str,
469
+ answer: str,
470
+ temperature: float,
471
+ max_new_tokens: int,
472
+ max_token: int):
473
+ logger.info("Starting format agent")
474
  if '[FinalAnswer]' in answer:
475
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
476
  elif "\n\n" in answer:
477
  possible_final_answer = answer.split("\n\n")[-1]
478
  else:
479
  possible_final_answer = answer.strip()
480
+ if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
481
+ return possible_final_answer
 
482
  elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
483
  return possible_final_answer[0]
484
 
485
  conversation = self.set_system_prompt(
486
+ [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.")
487
+ conversation.append({"role": "user", "content": message +
488
+ "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
489
+ return self.llm_infer(
490
+ messages=conversation,
491
+ temperature=temperature,
492
+ tools=None,
493
+ max_new_tokens=max_new_tokens,
494
+ max_token=max_token)
495
+
496
+ def run_summary_agent(self, thought_calls: str,
497
+ function_response: str,
498
+ temperature: float,
499
+ max_new_tokens: int,
500
+ max_token: int):
501
+ logger.info("Summarizing tool result")
502
+ prompt = f"""Thought and function calls:
503
+ {thought_calls}
504
+ Function calls' responses:
505
+ \"\"\"
506
+ {function_response}
507
+ \"\"\"
508
+ Summarize the function calls' responses in one sentence with all necessary information.
509
+ """
510
  conversation = [{"role": "user", "content": prompt}]
511
+ output = self.llm_infer(
512
+ messages=conversation,
513
+ temperature=temperature,
514
+ tools=None,
515
+ max_new_tokens=max_new_tokens,
516
+ max_token=max_token)
517
  if '[' in output:
518
  output = output.split('[')[0]
519
  return output
 
521
  def function_result_summary(self, input_list, status, enable_summary):
522
  if 'tool_call_step' not in status:
523
  status['tool_call_step'] = 0
524
  for idx in range(len(input_list)):
525
  pos_id = len(input_list) - idx - 1
526
  if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
 
 
527
  break
528
 
529
+ status['step'] = status.get('step', 0) + 1
530
  if not enable_summary:
531
  return status
532
 
533
+ status['summarized_index'] = status.get('summarized_index', 0)
534
+ status['summarized_step'] = status.get('summarized_step', 0)
535
+ status['previous_length'] = status.get('previous_length', 0)
536
537
 
538
  function_response = ''
539
+ idx = status['summarized_index']
540
  this_thought_calls = None
541
+
542
  while idx < len(input_list):
543
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
544
  (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
545
  if input_list[idx]['role'] == 'assistant':
546
+ if function_response:
547
+ status['summarized_step'] += 1
548
+ result_summary = self.run_summary_agent(
549
+ thought_calls=this_thought_calls,
550
+ function_response=function_response,
551
+ temperature=0.1,
552
+ max_new_tokens=128,
553
+ max_token=768)
554
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
555
+ status['summarized_index'] = last_call_idx + 2
556
+ idx += 1
557
+ last_call_idx = idx
558
+ this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
559
+ function_response = ''
560
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
 
561
  function_response += input_list[idx]['content']
562
  del input_list[idx]
563
  idx -= 1
 
568
  if function_response:
569
  status['summarized_step'] += 1
570
  result_summary = self.run_summary_agent(
571
+ thought_calls=this_thought_calls,
572
+ function_response=function_response,
573
+ temperature=0.1,
574
+ max_new_tokens=128,
575
+ max_token=768)
576
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
577
  for tool_call in tool_calls:
578
  del tool_call['call_id']
579
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
580
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
 
581
  status['summarized_index'] = last_call_idx + 2
582
 
583
  return status
 
588
  if hasattr(self, key):
589
  setattr(self, key, value)
590
  updated_attributes[key] = value
591
+ logger.info("Updated parameters: %s", updated_attributes)
592
  return updated_attributes
593
 
594
+ def run_gradio_chat(self, message: str,
595
+ history: list,
596
+ temperature: float,
597
+ max_new_tokens: int,
598
+ max_token: int,
599
+ call_agent: bool,
600
+ conversation: gr.State,
601
+ max_round: int = 5,
602
+ seed: int = None,
603
+ call_agent_level: int = 0,
604
+ sub_agent_task: str = None,
605
  uploaded_files: list = None):
606
+ logger.info("Chat started, message: %s", message[:100])
607
  if not message or len(message.strip()) < 5:
608
  yield "Please provide a valid message or upload files to analyze."
609
  return
610
 
 
 
 
611
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
612
  call_agent, call_agent_level, message)
613
  conversation = self.initialize_conversation(
 
620
  last_status = {}
621
  token_overflow = False
622
 
 
 
 
623
  try:
624
  while next_round and current_round < max_round:
625
  current_round += 1
 
626
  if last_outputs:
627
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
628
+ last_outputs, return_message=True,
629
+ existing_tools_prompt=picked_tools_prompt,
630
+ message_for_call_agent=message,
631
+ call_agent=call_agent,
632
+ call_agent_level=call_agent_level,
633
+ temperature=temperature)
634
  history.extend(current_gradio_history)
635
 
636
  if special_tool_call == 'Finish':
 
639
  conversation.extend(function_call_messages)
640
  return function_call_messages[0]['content']
641
 
642
+ elif special_tool_call in ['RequireClarification', 'DirectResponse']:
643
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
644
  history.append(ChatMessage(role="assistant", content=last_msg.content))
645
  yield history
 
656
  yield history
657
  else:
658
  next_round = False
659
+ conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
660
  return ''.join(last_outputs).replace("</s>", "")
661
 
662
+ last_outputs = []
663
  last_outputs_str, token_overflow = self.llm_infer(
664
+ messages=conversation,
665
+ temperature=temperature,
666
+ tools=picked_tools_prompt,
667
+ skip_special_tokens=False,
668
+ max_new_tokens=max_new_tokens,
669
+ max_token=max_token,
670
+ seed=seed,
671
+ check_token_status=True)
672
 
673
  if last_outputs_str is None:
674
+ logger.warning("Token limit exceeded")
675
  if self.force_finish:
676
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
677
  conversation, temperature, max_new_tokens, max_token)
 
685
 
686
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
687
  for msg in history:
688
+ if msg.metadata is not None:
689
  msg.metadata['status'] = 'done'
690
 
691
  if '[FinalAnswer]' in last_thought:
 
701
 
702
  last_outputs.append(last_outputs_str)
703
 
704
+ if next_round:
705
+ if self.force_finish:
706
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
707
+ conversation, temperature, max_new_tokens, max_token)
708
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
709
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
710
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
711
+ yield history
712
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
713
+ yield history
714
+ else:
715
+ yield "Reasoning rounds exceeded limit."
716
 
717
  except Exception as e:
718
  logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)