Ali2206 committed
Commit 80b0f9f · verified · 1 parent: e7f3d1d

Create src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +811 -0
src/txagent/txagent.py ADDED
@@ -0,0 +1,811 @@
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import json
5
+ import gc
6
+ import numpy as np
7
+ from vllm import LLM, SamplingParams
8
+ from jinja2 import Template
9
+ from typing import List
10
+ import types
11
+ from tooluniverse import ToolUniverse
12
+ from gradio import ChatMessage
13
+ from .toolrag import ToolRAGModel
14
+ import torch
15
+ import logging
16
+
17
+ # Configure logging with a more specific logger name
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
19
+ logger = logging.getLogger("TxAgent")
20
+
21
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
22
+
23
+ class TxAgent:
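+ """Multi-step, tool-calling agent: wraps a vLLM chat model and a ToolRAG retriever over ToolUniverse tools, and drives iterative reason -> call tool -> observe loops with optional summarization of tool results."""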
24
+ def __init__(self, model_name,
25
+ rag_model_name,
26
+ tool_files_dict=None,
27
+ enable_finish=True,
28
+ enable_rag=False,
29
+ enable_summary=False,
30
+ init_rag_num=0,
31
+ step_rag_num=0,
32
+ summary_mode='step',
33
+ summary_skip_last_k=0,
34
+ summary_context_length=None,
35
+ force_finish=True,
36
+ avoid_repeat=True,
37
+ seed=None,
38
+ enable_checker=False,
39
+ enable_chat=False,
40
+ additional_default_tools=None):
41
+ self.model_name = model_name
42
+ self.tokenizer = None
43
+ self.terminators = None
44
+ self.rag_model_name = rag_model_name
45
+ self.tool_files_dict = tool_files_dict
46
+ self.model = None
47
+ self.rag_model = ToolRAGModel(rag_model_name)
48
+ self.tooluniverse = None
49
+ self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
50
+ self.self_prompt = "Strictly follow the instruction."
51
+ self.chat_prompt = "You are a helpful assistant for user chat."
52
+ self.enable_finish = enable_finish
53
+ self.enable_rag = enable_rag
54
+ self.enable_summary = enable_summary
55
+ self.summary_mode = summary_mode
56
+ self.summary_skip_last_k = summary_skip_last_k
57
+ self.summary_context_length = summary_context_length
58
+ self.init_rag_num = init_rag_num
59
+ self.step_rag_num = step_rag_num
60
+ self.force_finish = force_finish
61
+ self.avoid_repeat = avoid_repeat
62
+ self.seed = seed
63
+ self.enable_checker = enable_checker
64
+ self.additional_default_tools = additional_default_tools
65
+ logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
66
+
67
+ def init_model(self):
68
+ self.load_models()
69
+ self.load_tooluniverse()
70
+
71
+ def load_models(self, model_name=None):
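+ """Load the vLLM backend (or switch to model_name); returns early when that model is already loaded."""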
72
+ if model_name is not None:
73
+ if model_name == self.model_name:
74
+ return f"The model {model_name} is already loaded."
75
+ self.model_name = model_name
76
+
77
+ self.model = LLM(
78
+ model=self.model_name,
79
+ dtype="float16",
80
+ max_model_len=131072,
81
+ max_num_batched_tokens=65536, # Increased for A100 80GB
82
+ max_num_seqs=512,
83
+ gpu_memory_utilization=0.95, # Higher utilization for better performance
84
+ trust_remote_code=True,
85
+ )
86
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
87
+ self.tokenizer = self.model.get_tokenizer()
88
+ logger.info(
89
+ "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
90
+ self.model_name, 131072, 65536, 0.95
91
+ )
92
+ return f"Model {self.model_name} loaded successfully."
93
+
94
+ def load_tooluniverse(self):
95
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
96
+ self.tooluniverse.load_tools()
97
+ special_tools = self.tooluniverse.prepare_tool_prompts(
98
+ self.tooluniverse.tool_category_dicts["special_tools"])
99
+ self.special_tools_name = [tool['name'] for tool in special_tools]
100
+ logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
101
+
102
+ def load_tool_desc_embedding(self):
103
+ cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
104
+ if os.path.exists(cache_path):
105
+ self.rag_model.load_cached_embeddings(cache_path)
106
+ else:
107
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
108
+ self.rag_model.save_embeddings(cache_path)
109
+ logger.debug("Tool description embeddings loaded")
110
+
111
+ def rag_infer(self, query, top_k=5):
112
+ return self.rag_model.rag_infer(query, top_k)
113
+
114
+ def initialize_tools_prompt(self, call_agent, call_agent_level, message):
115
+ picked_tools_prompt = []
116
+ picked_tools_prompt = self.add_special_tools(
117
+ picked_tools_prompt, call_agent=call_agent)
118
+ if call_agent:
119
+ call_agent_level += 1
120
+ if call_agent_level >= 2:
121
+ call_agent = False
122
+ return picked_tools_prompt, call_agent_level
123
+
124
+ def initialize_conversation(self, message, conversation=None, history=None):
125
+ if conversation is None:
126
+ conversation = []
127
+
128
+ conversation = self.set_system_prompt(
129
+ conversation, self.prompt_multi_step)
130
+ if history:
131
+ for i in range(len(history)):
132
+ if history[i]['role'] == 'user':
133
+ conversation.append({"role": "user", "content": history[i]['content']})
134
+ elif history[i]['role'] == 'assistant':
135
+ conversation.append({"role": "assistant", "content": history[i]['content']})
136
+ conversation.append({"role": "user", "content": message})
137
+ logger.debug("Conversation initialized with %d messages", len(conversation))
138
+ return conversation
139
+
140
+ def tool_RAG(self, message=None,
141
+ picked_tool_names=None,
142
+ existing_tools_prompt=None,
143
+ rag_num=0,
144
+ return_call_result=False):
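+ """Retrieve up to rag_num tool prompts relevant to message via the RAG model, dropping special tools; returns [] when RAG is disabled."""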
145
+ if not self.enable_rag:
146
+ return []
147
+ extra_factor = 10
148
+ if picked_tool_names is None:
149
+ assert message is not None
150
+ picked_tool_names = self.rag_infer(
151
+ message, top_k=rag_num * extra_factor)
152
+
153
+ picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
154
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
155
+
156
+ picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
158
+ logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
159
+ if return_call_result:
160
+ return picked_tools_prompt, picked_tool_names
161
+ return picked_tools_prompt
162
+
163
+ def add_special_tools(self, tools, call_agent=False):
164
+ if self.enable_finish:
165
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
166
+ logger.debug("Finish tool added")
167
+ if call_agent:
168
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
169
+ logger.debug("CallAgent tool added")
170
+ return tools
171
+
172
+ def add_finish_tools(self, tools):
173
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
174
+ logger.debug("Finish tool added")
175
+ return tools
176
+
177
+ def set_system_prompt(self, conversation, sys_prompt):
178
+ if not conversation:
179
+ conversation.append({"role": "system", "content": sys_prompt})
180
+ else:
181
+ conversation[0] = {"role": "system", "content": sys_prompt}
182
+ return conversation
183
+
184
+ def run_function_call(self, fcall_str,
185
+ return_message=False,
186
+ existing_tools_prompt=None,
187
+ message_for_call_agent=None,
188
+ call_agent=False,
189
+ call_agent_level=None,
190
+ temperature=None):
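+ """Parse the model's tool-call JSON, run each call through ToolUniverse (or a capped sub-agent for CallAgent), and return the assistant and tool messages to append to the conversation."""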
191
+ try:
192
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
193
+ fcall_str, return_message=return_message, verbose=False)
194
+ except Exception as e:
195
+ logger.error("Tool call parsing failed: %s", e)
196
+ function_call_json = []
197
+ message = fcall_str
198
+
199
+ call_results = []
200
+ special_tool_call = ''
201
+ if function_call_json:
202
+ if isinstance(function_call_json, list):
203
+ for i in range(len(function_call_json)):
204
+ logger.info("Tool Call: %s", function_call_json[i])
205
+ if function_call_json[i]["name"] == 'Finish':
206
+ special_tool_call = 'Finish'
207
+ break
208
+ elif function_call_json[i]["name"] == 'CallAgent':
209
+ if call_agent_level < 2 and call_agent:
210
+ solution_plan = function_call_json[i]['arguments']['solution']
211
+ full_message = (
212
+ message_for_call_agent +
213
+ "\nYou must follow the following plan to answer the question: " +
214
+ str(solution_plan)
215
+ )
216
+ call_result = self.run_multistep_agent(
217
+ full_message, temperature=temperature,
218
+ max_new_tokens=512, max_token=131072,
219
+ call_agent=False, call_agent_level=call_agent_level)
220
+ if call_result is None:
221
+ call_result = "⚠️ No content returned from sub-agent."
222
+ else:
223
+ call_result = call_result.split('[FinalAnswer]')[-1].strip()
224
+ else:
225
+ call_result = "Error: CallAgent disabled."
226
+ else:
227
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
228
+ call_id = self.tooluniverse.call_id_gen()
229
+ function_call_json[i]["call_id"] = call_id
230
+ logger.info("Tool Call Result: %s", call_result)
231
+ call_results.append({
232
+ "role": "tool",
233
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
234
+ })
235
+ else:
236
+ call_results.append({
237
+ "role": "tool",
238
+ "content": json.dumps({"content": "Invalid or no function call detected."})
239
+ })
240
+
241
+ revised_messages = [{
242
+ "role": "assistant",
243
+ "content": message.strip(),
244
+ "tool_calls": json.dumps(function_call_json)
245
+ }] + call_results
246
+ return revised_messages, existing_tools_prompt, special_tool_call
247
+
248
+ def run_function_call_stream(self, fcall_str,
249
+ return_message=False,
250
+ existing_tools_prompt=None,
251
+ message_for_call_agent=None,
252
+ call_agent=False,
253
+ call_agent_level=None,
254
+ temperature=None,
255
+ return_gradio_history=True):
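+ """Streaming variant of run_function_call: also handles DirectResponse/RequireClarification and, when requested, returns Gradio ChatMessage entries for the UI."""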
256
+ try:
257
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
258
+ fcall_str, return_message=return_message, verbose=False)
259
+ except Exception as e:
260
+ logger.error("Tool call parsing failed: %s", e)
261
+ function_call_json = []
262
+ message = fcall_str
263
+
264
+ call_results = []
265
+ special_tool_call = ''
266
+ if return_gradio_history:
267
+ gradio_history = []
268
+ if function_call_json:
269
+ if isinstance(function_call_json, list):
270
+ for i in range(len(function_call_json)):
271
+ if function_call_json[i]["name"] == 'Finish':
272
+ special_tool_call = 'Finish'
273
+ break
274
+ elif function_call_json[i]["name"] == 'DirectResponse':
275
+ call_result = function_call_json[i]['arguments']['respose']
276
+ special_tool_call = 'DirectResponse'
277
+ elif function_call_json[i]["name"] == 'RequireClarification':
278
+ call_result = function_call_json[i]['arguments']['unclear_question']
279
+ special_tool_call = 'RequireClarification'
280
+ elif function_call_json[i]["name"] == 'CallAgent':
281
+ if call_agent_level < 2 and call_agent:
282
+ solution_plan = function_call_json[i]['arguments']['solution']
283
+ full_message = (
284
+ message_for_call_agent +
285
+ "\nYou must follow the following plan to answer the question: " +
286
+ str(solution_plan)
287
+ )
288
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
289
+ call_result = yield from self.run_gradio_chat(
290
+ full_message, history=[], temperature=temperature,
291
+ max_new_tokens=512, max_token=131072,
292
+ call_agent=False, call_agent_level=call_agent_level,
293
+ conversation=None, sub_agent_task=sub_agent_task)
294
+ if call_result is not None and isinstance(call_result, str):
295
+ call_result = call_result.split('[FinalAnswer]')[-1]
296
+ else:
297
+ call_result = "⚠️ No content returned from sub-agent."
298
+ else:
299
+ call_result = "Error: CallAgent disabled."
300
+ else:
301
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
302
+ call_id = self.tooluniverse.call_id_gen()
303
+ function_call_json[i]["call_id"] = call_id
304
+ call_results.append({
305
+ "role": "tool",
306
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
307
+ })
308
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
309
+ metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
310
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
311
+ else:
312
+ call_results.append({
313
+ "role": "tool",
314
+ "content": json.dumps({"content": "Invalid or no function call detected."})
315
+ })
316
+
317
+ revised_messages = [{
318
+ "role": "assistant",
319
+ "content": message.strip(),
320
+ "tool_calls": json.dumps(function_call_json)
321
+ }] + call_results
322
+ if return_gradio_history:
323
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
324
+ return revised_messages, existing_tools_prompt, special_tool_call
325
+
326
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
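+ """Force a final answer from the partial reasoning (used when token or round limits are hit) by prompting the model to continue from '[FinalAnswer]'."""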
327
+ if conversation[-1]['role'] == 'assistant':
328
+ conversation.append(
329
+ {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
330
+ finish_tools_prompt = self.add_finish_tools([])
331
+ last_outputs_str = self.llm_infer(
332
+ messages=conversation,
333
+ temperature=temperature,
334
+ tools=finish_tools_prompt,
335
+ output_begin_string='[FinalAnswer]',
336
+ skip_special_tokens=True,
337
+ max_new_tokens=max_new_tokens,
338
+ max_token=max_token)
339
+ logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
340
+ return last_outputs_str
341
+
342
+ def run_multistep_agent(self, message: str,
343
+ temperature: float,
344
+ max_new_tokens: int,
345
+ max_token: int,
346
+ max_round: int = 5,
347
+ call_agent=False,
348
+ call_agent_level=0):
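+ """Drive the reason/tool-call loop for up to max_round rounds and return the text after '[FinalAnswer]' (or None if no answer could be forced)."""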
349
+ logger.info("Starting multistep agent for message: %s", message[:100])
350
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
351
+ call_agent, call_agent_level, message)
352
+ conversation = self.initialize_conversation(message)
353
+ outputs = []
354
+ last_outputs = []
355
+ next_round = True
356
+ current_round = 0
357
+ token_overflow = False
358
+ enable_summary = False
359
+ last_status = {}
360
+
361
+ while next_round and current_round < max_round:
362
+ current_round += 1
363
+ if len(outputs) > 0:
364
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
365
+ last_outputs, return_message=True,
366
+ existing_tools_prompt=picked_tools_prompt,
367
+ message_for_call_agent=message,
368
+ call_agent=call_agent,
369
+ call_agent_level=call_agent_level,
370
+ temperature=temperature)
371
+
372
+ if special_tool_call == 'Finish':
373
+ next_round = False
374
+ conversation.extend(function_call_messages)
375
+ content = function_call_messages[0]['content']
376
+ if content is None:
377
+ return "❌ No content returned after Finish tool call."
378
+ return content.split('[FinalAnswer]')[-1]
379
+
380
+ if (self.enable_summary or token_overflow) and not call_agent:
381
+ enable_summary = True
382
+ last_status = self.function_result_summary(
383
+ conversation, status=last_status, enable_summary=enable_summary)
384
+
385
+ if function_call_messages:
386
+ conversation.extend(function_call_messages)
387
+ outputs.append(tool_result_format(function_call_messages))
388
+ else:
389
+ next_round = False
390
+ conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
391
+ return ''.join(last_outputs).replace("</s>", "")
392
+
393
+ last_outputs = []
394
+ outputs.append("### TxAgent:\n")
395
+ last_outputs_str, token_overflow = self.llm_infer(
396
+ messages=conversation,
397
+ temperature=temperature,
398
+ tools=picked_tools_prompt,
399
+ skip_special_tokens=False,
400
+ max_new_tokens=2048,
401
+ max_token=131072,
402
+ check_token_status=True)
403
+ if last_outputs_str is None:
404
+ logger.warning("Token limit exceeded")
405
+ if self.force_finish:
406
+ return self.get_answer_based_on_unfinished_reasoning(
407
+ conversation, temperature, max_new_tokens, max_token)
408
+ return "❌ Token limit exceeded."
409
+ last_outputs.append(last_outputs_str)
410
+
411
+ if max_round == current_round:
412
+ logger.warning("Max rounds exceeded")
413
+ if self.force_finish:
414
+ return self.get_answer_based_on_unfinished_reasoning(
415
+ conversation, temperature, max_new_tokens, max_token)
416
+ return None
417
+
418
+ def build_logits_processor(self, messages, llm):
419
+ logger.warning("Logits processor disabled due to vLLM V1 limitation")
420
+ return None
421
+
422
+ def llm_infer(self, messages, temperature=0.1, tools=None,
423
+ output_begin_string=None, max_new_tokens=512,
424
+ max_token=131072, skip_special_tokens=True,
425
+ model=None, tokenizer=None, terminators=None,
426
+ seed=None, check_token_status=False):
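+ """Render the chat template, run one vLLM generation, and return the text; with check_token_status=True returns (text, token_overflow), or (None, True) when the prompt exceeds max_token."""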
427
+ if model is None:
428
+ model = self.model
429
+
430
+ logits_processor = self.build_logits_processor(messages, model)
431
+ sampling_params = SamplingParams(
432
+ temperature=temperature,
433
+ max_tokens=max_new_tokens,
434
+ seed=seed if seed is not None else self.seed,
435
+ )
436
+
437
+ prompt = self.chat_template.render(
438
+ messages=messages, tools=tools, add_generation_prompt=True)
439
+ if output_begin_string is not None:
440
+ prompt += output_begin_string
441
+
442
+ if check_token_status and max_token is not None:
443
+ token_overflow = False
444
+ num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
445
+ logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
446
+ if num_input_tokens > max_token:
447
+ torch.cuda.empty_cache()
448
+ gc.collect()
449
+ logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
450
+ return None, True
451
+
452
+ output = model.generate(prompt, sampling_params=sampling_params)
453
+ output_text = output[0].outputs[0].text
454
+ output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
455
+ logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
456
+ torch.cuda.empty_cache()
457
+ gc.collect()
458
+ if check_token_status and max_token is not None:
459
+ return output_text, token_overflow
460
+ return output_text
461
+
462
+ def run_self_agent(self, message: str,
463
+ temperature: float,
464
+ max_new_tokens: int,
465
+ max_token: int):
466
+ logger.info("Starting self agent")
467
+ conversation = self.set_system_prompt([], self.self_prompt)
468
+ conversation.append({"role": "user", "content": message})
469
+ return self.llm_infer(
470
+ messages=conversation,
471
+ temperature=temperature,
472
+ tools=None,
473
+ max_new_tokens=max_new_tokens,
474
+ max_token=max_token)
475
+
476
+ def run_chat_agent(self, message: str,
477
+ temperature: float,
478
+ max_new_tokens: int,
479
+ max_token: int):
480
+ logger.info("Starting chat agent")
481
+ conversation = self.set_system_prompt([], self.chat_prompt)
482
+ conversation.append({"role": "user", "content": message})
483
+ return self.llm_infer(
484
+ messages=conversation,
485
+ temperature=temperature,
486
+ tools=None,
487
+ max_new_tokens=max_new_tokens,
488
+ max_token=max_token)
489
+
490
+ def run_format_agent(self, message: str,
491
+ answer: str,
492
+ temperature: float,
493
+ max_new_tokens: int,
494
+ max_token: int):
495
+ logger.info("Starting format agent")
496
+ if '[FinalAnswer]' in answer:
497
+ possible_final_answer = answer.split("[FinalAnswer]")[-1]
498
+ elif "\n\n" in answer:
499
+ possible_final_answer = answer.split("\n\n")[-1]
500
+ else:
501
+ possible_final_answer = answer.strip()
502
+ if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
503
+ return possible_final_answer
504
+ elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
505
+ return possible_final_answer[0]
506
+
507
+ conversation = self.set_system_prompt(
508
+ [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
509
+ conversation.append({"role": "user", "content": message +
510
+ "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
511
+ return self.llm_infer(
512
+ messages=conversation,
513
+ temperature=temperature,
514
+ tools=None,
515
+ max_new_tokens=max_new_tokens,
516
+ max_token=max_token)
517
+
518
+ def run_summary_agent(self, thought_calls: str,
519
+ function_response: str,
520
+ temperature: float,
521
+ max_new_tokens: int,
522
+ max_token: int):
523
+ logger.info("Summarizing tool result")
524
+ prompt = f"""Thought and function calls:
525
+ {thought_calls}
526
+ Function calls' responses:
527
+ \"\"\"
528
+ {function_response}
529
+ \"\"\"
530
+ Summarize the function calls' responses in one sentence with all necessary information.
531
+ """
532
+ conversation = [{"role": "user", "content": prompt}]
533
+ output = self.llm_infer(
534
+ messages=conversation,
535
+ temperature=temperature,
536
+ tools=None,
537
+ max_new_tokens=max_new_tokens,
538
+ max_token=max_token)
539
+ if '[' in output:
540
+ output = output.split('[')[0]
541
+ return output
542
+
543
+ def function_result_summary(self, input_list, status, enable_summary):
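+ """Compress older tool responses in the conversation into one-sentence summaries (in place), tracking progress in status; controlled by summary_mode ('step' or 'length')."""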
544
+ if 'tool_call_step' not in status:
545
+ status['tool_call_step'] = 0
546
+ for idx in range(len(input_list)):
547
+ pos_id = len(input_list) - idx - 1
548
+ if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
549
+ break
550
+
551
+ status['step'] = status.get('step', 0) + 1
552
+ if not enable_summary:
553
+ return status
554
+
555
+ status['summarized_index'] = status.get('summarized_index', 0)
556
+ status['summarized_step'] = status.get('summarized_step', 0)
557
+ status['previous_length'] = status.get('previous_length', 0)
558
+ status['history'] = status.get('history', [])
559
+
560
+ function_response = ''
561
+ idx = status['summarized_index']
562
+ this_thought_calls = None
563
+
564
+ while idx < len(input_list):
565
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
566
+ (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
567
+ if input_list[idx]['role'] == 'assistant':
568
+ if function_response:
569
+ status['summarized_step'] += 1
570
+ result_summary = self.run_summary_agent(
571
+ thought_calls=this_thought_calls,
572
+ function_response=function_response,
573
+ temperature=0.1,
574
+ max_new_tokens=512,
575
+ max_token=131072)
576
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
577
+ status['summarized_index'] = last_call_idx + 2
578
+ idx += 1
579
+ last_call_idx = idx
580
+ this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
581
+ function_response = ''
582
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
583
+ function_response += input_list[idx]['content']
584
+ del input_list[idx]
585
+ idx -= 1
586
+ else:
587
+ break
588
+ idx += 1
589
+
590
+ if function_response:
591
+ status['summarized_step'] += 1
592
+ result_summary = self.run_summary_agent(
593
+ thought_calls=this_thought_calls,
594
+ function_response=function_response,
595
+ temperature=0.1,
596
+ max_new_tokens=512,
597
+ max_token=131072)
598
+ tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
599
+ for tool_call in tool_calls:
600
+ del tool_call['call_id']
601
+ input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
602
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
603
+ status['summarized_index'] = last_call_idx + 2
604
+
605
+ return status
606
+
607
+ def update_parameters(self, **kwargs):
608
+ updated_attributes = {}
609
+ for key, value in kwargs.items():
610
+ if hasattr(self, key):
611
+ setattr(self, key, value)
612
+ updated_attributes[key] = value
613
+ logger.info("Updated parameters: %s", updated_attributes)
614
+ return updated_attributes
615
+
616
+ def run_gradio_chat(self, message: str,
617
+ history: list,
618
+ temperature: float,
619
+ max_new_tokens: int = 2048,
620
+ max_token: int = 131072,
621
+ call_agent: bool = False,
622
+ conversation: gr.State = None,
623
+ max_round: int = 5,
624
+ seed: int = None,
625
+ call_agent_level: int = 0,
626
+ sub_agent_task: str = None,
627
+ uploaded_files: list = None):
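+ """Generator used by the Gradio UI: yields the updated chat history after each round and returns the final answer string."""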
628
+ logger.info("Chat started, message: %s", message[:100])
629
+ if not message or len(message.strip()) < 5:
630
+ yield "Please provide a valid message or upload files to analyze."
631
+ return
632
+
633
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
634
+ call_agent, call_agent_level, message)
635
+ conversation = self.initialize_conversation(
636
+ message, conversation, history)
637
+ history = []
638
+ last_outputs = []
639
+
640
+ next_round = True
641
+ current_round = 0
642
+ enable_summary = False
643
+ last_status = {}
644
+ token_overflow = False
645
+
646
+ try:
647
+ while next_round and current_round < max_round:
648
+ current_round += 1
649
+ logger.debug("Starting round %d/%d", current_round, max_round)
650
+ if last_outputs:
651
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
652
+ last_outputs, return_message=True,
653
+ existing_tools_prompt=picked_tools_prompt,
654
+ message_for_call_agent=message,
655
+ call_agent=call_agent,
656
+ call_agent_level=call_agent_level,
657
+ temperature=temperature)
658
+ history.extend(current_gradio_history)
659
+
660
+ if special_tool_call == 'Finish':
661
+ logger.info("Finish tool called, ending chat")
662
+ yield history
663
+ next_round = False
664
+ conversation.extend(function_call_messages)
665
+ content = function_call_messages[0]['content']
666
+ if content:
667
+ return content
668
+ return "No content returned after Finish tool call."
669
+
670
+ elif special_tool_call in ['RequireClarification', 'DirectResponse']:
671
+ last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
672
+ history.append(ChatMessage(role="assistant", content=last_msg.content))
673
+ logger.info("Special tool %s called, ending chat", special_tool_call)
674
+ yield history
675
+ next_round = False
676
+ return last_msg.content
677
+
678
+ if (self.enable_summary or token_overflow) and not call_agent:
679
+ enable_summary = True
680
+ last_status = self.function_result_summary(
681
+ conversation, status=last_status, enable_summary=enable_summary)
682
+
683
+ if function_call_messages:
684
+ conversation.extend(function_call_messages)
685
+ yield history
686
+ else:
687
+ next_round = False
688
+ conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
689
+ logger.info("No function call messages, ending chat")
690
+ return ''.join(last_outputs).replace("</s>", "")
691
+
692
+ last_outputs = []
693
+ last_outputs_str, token_overflow = self.llm_infer(
694
+ messages=conversation,
695
+ temperature=temperature,
696
+ tools=picked_tools_prompt,
697
+ skip_special_tokens=False,
698
+ max_new_tokens=max_new_tokens,
699
+ max_token=max_token,
700
+ seed=seed,
701
+ check_token_status=True)
702
+
703
+ if last_outputs_str is None:
704
+ logger.warning("Token limit exceeded")
705
+ if self.force_finish:
706
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
707
+ conversation, temperature, max_new_tokens, max_token)
708
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
709
+ yield history
710
+ return last_outputs_str
711
+ error_msg = "Token limit exceeded."
712
+ history.append(ChatMessage(role="assistant", content=error_msg))
713
+ yield history
714
+ return error_msg
715
+
716
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
717
+ for msg in history:
718
+ if msg.metadata is not None:
719
+ msg.metadata['status'] = 'done'
720
+
721
+ if '[FinalAnswer]' in last_thought:
722
+ parts = last_thought.split('[FinalAnswer]', 1)
723
+ final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
724
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
725
+ yield history
726
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
727
+ logger.info("Final answer provided: %s", final_answer[:100])
728
+ yield history
729
+ next_round = False # Ensure we exit after final answer
730
+ return final_answer
731
+ else:
732
+ history.append(ChatMessage(role="assistant", content=last_thought))
733
+ yield history
734
+
735
+ last_outputs.append(last_outputs_str)
736
+
737
+ if next_round:
738
+ if self.force_finish:
739
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
740
+ conversation, temperature, max_new_tokens, max_token)
741
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
742
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
743
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
744
+ yield history
745
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
746
+ logger.info("Forced final answer: %s", final_answer[:100])
747
+ yield history
748
+ return final_answer
749
+ else:
750
+ error_msg = "Reasoning rounds exceeded limit."
751
+ history.append(ChatMessage(role="assistant", content=error_msg))
752
+ yield history
753
+ return error_msg
754
+
755
+ except Exception as e:
756
+ logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
757
+ error_msg = f"Error: {e}"
758
+ history.append(ChatMessage(role="assistant", content=error_msg))
759
+ yield history
760
+ if self.force_finish:
761
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
762
+ conversation, temperature, max_new_tokens, max_token)
763
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
764
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
765
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
766
+ yield history
767
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
768
+ logger.info("Forced final answer after error: %s", final_answer[:100])
769
+ yield history
770
+ return final_answer
771
+ return error_msg
772
+
773
+ def run_gradio_chat_batch(self, messages: List[str],
774
+ temperature: float,
775
+ max_new_tokens: int = 2048,
776
+ max_token: int = 131072,
777
+ call_agent: bool = False,
778
+ conversation: List = None,
779
+ max_round: int = 5,
780
+ seed: int = None,
781
+ call_agent_level: int = 0):
782
+ """Run batch inference for multiple messages."""
783
+ logger.info("Starting batch chat for %d messages", len(messages))
784
+ batch_results = []
785
+
786
+ for message in messages:
787
+ # Initialize conversation for each message
788
+ conv = self.initialize_conversation(message, conversation, history=None)
789
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
790
+ call_agent, call_agent_level, message)
791
+
792
+ # Run single inference for simplicity (extend for multi-round if needed)
793
+ output, token_overflow = self.llm_infer(
794
+ messages=conv,
795
+ temperature=temperature,
796
+ tools=picked_tools_prompt,
797
+ max_new_tokens=max_new_tokens,
798
+ max_token=max_token,
799
+ skip_special_tokens=False,
800
+ seed=seed,
801
+ check_token_status=True
802
+ )
803
+
804
+ if output is None:
805
+ logger.warning("Token limit exceeded for message: %s", message[:100])
806
+ batch_results.append("Token limit exceeded.")
807
+ else:
808
+ batch_results.append(output)
809
+
810
+ logger.info("Batch chat completed for %d messages", len(messages))
811
+ return batch_results
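
For orientation, here is a minimal usage sketch of the class this commit adds (not part of the diff itself). The import path and the model identifiers are illustrative assumptions; substitute the package layout and checkpoints you actually serve, and note that tool_files_dict is left at its default.

from txagent.txagent import TxAgent  # assumes the src/ layout is installed as the "txagent" package

agent = TxAgent(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # illustrative vLLM-loadable checkpoint
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",  # illustrative embedding model for tool retrieval
    enable_rag=False,                                          # RAG off, so no tool-description embeddings are needed
)
agent.init_model()  # loads the vLLM backend and the ToolUniverse tools

answer = agent.run_multistep_agent(
    "What are the contraindications of metformin?",
    temperature=0.3,
    max_new_tokens=1024,
    max_token=131072,
)
print(answer)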