Ali2206 commited on
Commit
dbbf9d4
·
verified ·
1 Parent(s): 1675831

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +943 -937
src/txagent/txagent.py CHANGED
@@ -1,937 +1,943 @@
1
- import gradio as gr
2
- import os
3
- import sys
4
- import json
5
- import gc
6
- import numpy as np
7
- from vllm import LLM, SamplingParams
8
- from jinja2 import Template
9
- from typing import List
10
- import types
11
- from tooluniverse import ToolUniverse
12
- from gradio import ChatMessage
13
- from .toolrag import ToolRAGModel
14
-
15
- from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
16
-
17
-
18
- class TxAgent:
19
- def __init__(self, model_name,
20
- rag_model_name,
21
- tool_files_dict=None, # None leads to the default tool files in ToolUniverse
22
- enable_finish=True,
23
- enable_rag=True,
24
- enable_summary=False,
25
- init_rag_num=0,
26
- step_rag_num=10,
27
- summary_mode='step',
28
- summary_skip_last_k=0,
29
- summary_context_length=None,
30
- force_finish=True,
31
- avoid_repeat=True,
32
- seed=None,
33
- enable_checker=False,
34
- enable_chat=False,
35
- additional_default_tools=None,
36
- ):
37
- self.model_name = model_name
38
- self.tokenizer = None
39
- self.terminators = None
40
- self.rag_model_name = rag_model_name
41
- self.tool_files_dict = tool_files_dict
42
- self.model = None
43
- self.rag_model = ToolRAGModel(rag_model_name)
44
- self.tooluniverse = None
45
- # self.tool_desc = None
46
- self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
47
- self.self_prompt = "Strictly follow the instruction."
48
- self.chat_prompt = "You are helpful assistant to chat with the user."
49
- self.enable_finish = enable_finish
50
- self.enable_rag = enable_rag
51
- self.enable_summary = enable_summary
52
- self.summary_mode = summary_mode
53
- self.summary_skip_last_k = summary_skip_last_k
54
- self.summary_context_length = summary_context_length
55
- self.init_rag_num = init_rag_num
56
- self.step_rag_num = step_rag_num
57
- self.force_finish = force_finish
58
- self.avoid_repeat = avoid_repeat
59
- self.seed = seed
60
- self.enable_checker = enable_checker
61
- self.additional_default_tools = additional_default_tools
62
- self.print_self_values()
63
-
64
- def init_model(self):
65
- self.load_models()
66
- self.load_tooluniverse()
67
- self.load_tool_desc_embedding()
68
-
69
- def print_self_values(self):
70
- for attr, value in self.__dict__.items():
71
- print(f"{attr}: {value}")
72
-
73
- def load_models(self, model_name=None):
74
- if model_name is not None:
75
- if model_name == self.model_name:
76
- return f"The model {model_name} is already loaded."
77
- self.model_name = model_name
78
-
79
- self.model = LLM(model=self.model_name)
80
- self.chat_template = Template(self.model.get_tokenizer().chat_template)
81
- self.tokenizer = self.model.get_tokenizer()
82
-
83
- return f"Model {model_name} loaded successfully."
84
-
85
- def load_tooluniverse(self):
86
- self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
87
- self.tooluniverse.load_tools()
88
- special_tools = self.tooluniverse.prepare_tool_prompts(
89
- self.tooluniverse.tool_category_dicts["special_tools"])
90
- self.special_tools_name = [tool['name'] for tool in special_tools]
91
-
92
- def load_tool_desc_embedding(self):
93
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
94
-
95
- def rag_infer(self, query, top_k=5):
96
- return self.rag_model.rag_infer(query, top_k)
97
-
98
- def initialize_tools_prompt(self, call_agent, call_agent_level, message):
99
- picked_tools_prompt = []
100
- picked_tools_prompt = self.add_special_tools(
101
- picked_tools_prompt, call_agent=call_agent)
102
- if call_agent:
103
- call_agent_level += 1
104
- if call_agent_level >= 2:
105
- call_agent = False
106
-
107
- if not call_agent:
108
- picked_tools_prompt += self.tool_RAG(
109
- message=message, rag_num=self.init_rag_num)
110
- return picked_tools_prompt, call_agent_level
111
-
112
- def initialize_conversation(self, message, conversation=None, history=None):
113
- if conversation is None:
114
- conversation = []
115
-
116
- conversation = self.set_system_prompt(
117
- conversation, self.prompt_multi_step)
118
- if history is not None:
119
- if len(history) == 0:
120
- conversation = []
121
- print("clear conversation successfully")
122
- else:
123
- for i in range(len(history)):
124
- if history[i]['role'] == 'user':
125
- if i-1 >= 0 and history[i-1]['role'] == 'assistant':
126
- conversation.append(
127
- {"role": "assistant", "content": history[i-1]['content']})
128
- conversation.append(
129
- {"role": "user", "content": history[i]['content']})
130
- if i == len(history)-1 and history[i]['role'] == 'assistant':
131
- conversation.append(
132
- {"role": "assistant", "content": history[i]['content']})
133
-
134
- conversation.append({"role": "user", "content": message})
135
-
136
- return conversation
137
-
138
- def tool_RAG(self, message=None,
139
- picked_tool_names=None,
140
- existing_tools_prompt=[],
141
- rag_num=5,
142
- return_call_result=False):
143
- extra_factor = 30 # Factor to retrieve more than rag_num
144
- if picked_tool_names is None:
145
- assert picked_tool_names is not None or message is not None
146
- picked_tool_names = self.rag_infer(
147
- message, top_k=rag_num*extra_factor)
148
-
149
- picked_tool_names_no_special = []
150
- for tool in picked_tool_names:
151
- if tool not in self.special_tools_name:
152
- picked_tool_names_no_special.append(tool)
153
- picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
154
- picked_tool_names = picked_tool_names_no_special[:rag_num]
155
-
156
- picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
158
- picked_tools)
159
- if return_call_result:
160
- return picked_tools_prompt, picked_tool_names
161
- return picked_tools_prompt
162
-
163
- def add_special_tools(self, tools, call_agent=False):
164
- if self.enable_finish:
165
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
166
- 'Finish', return_prompt=True))
167
- print("Finish tool is added")
168
- if call_agent:
169
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
170
- 'CallAgent', return_prompt=True))
171
- print("CallAgent tool is added")
172
- else:
173
- if self.enable_rag:
174
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
- 'Tool_RAG', return_prompt=True))
176
- print("Tool_RAG tool is added")
177
-
178
- if self.additional_default_tools is not None:
179
- for each_tool_name in self.additional_default_tools:
180
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
181
- each_tool_name, return_prompt=True)
182
- if tool_prompt is not None:
183
- print(f"{each_tool_name} tool is added")
184
- tools.append(tool_prompt)
185
- return tools
186
-
187
- def add_finish_tools(self, tools):
188
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
189
- 'Finish', return_prompt=True))
190
- print("Finish tool is added")
191
- return tools
192
-
193
- def set_system_prompt(self, conversation, sys_prompt):
194
- if len(conversation) == 0:
195
- conversation.append(
196
- {"role": "system", "content": sys_prompt})
197
- else:
198
- conversation[0] = {"role": "system", "content": sys_prompt}
199
- return conversation
200
-
201
- def run_function_call(self, fcall_str,
202
- return_message=False,
203
- existing_tools_prompt=None,
204
- message_for_call_agent=None,
205
- call_agent=False,
206
- call_agent_level=None,
207
- temperature=None):
208
-
209
- function_call_json, message = self.tooluniverse.extract_function_call_json(
210
- fcall_str, return_message=return_message, verbose=False)
211
- call_results = []
212
- special_tool_call = ''
213
- if function_call_json is not None:
214
- if isinstance(function_call_json, list):
215
- for i in range(len(function_call_json)):
216
- print("\033[94mTool Call:\033[0m", function_call_json[i])
217
- if function_call_json[i]["name"] == 'Finish':
218
- special_tool_call = 'Finish'
219
- break
220
- elif function_call_json[i]["name"] == 'Tool_RAG':
221
- new_tools_prompt, call_result = self.tool_RAG(
222
- message=message,
223
- existing_tools_prompt=existing_tools_prompt,
224
- rag_num=self.step_rag_num,
225
- return_call_result=True)
226
- existing_tools_prompt += new_tools_prompt
227
- elif function_call_json[i]["name"] == 'CallAgent':
228
- if call_agent_level < 2 and call_agent:
229
- solution_plan = function_call_json[i]['arguments']['solution']
230
- full_message = (
231
- message_for_call_agent +
232
- "\nYou must follow the following plan to answer the question: " +
233
- str(solution_plan)
234
- )
235
- call_result = self.run_multistep_agent(
236
- full_message, temperature=temperature,
237
- max_new_tokens=1024, max_token=99999,
238
- call_agent=False, call_agent_level=call_agent_level)
239
- call_result = call_result.split(
240
- '[FinalAnswer]')[-1].strip()
241
- else:
242
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
243
- else:
244
- call_result = self.tooluniverse.run_one_function(
245
- function_call_json[i])
246
-
247
- call_id = self.tooluniverse.call_id_gen()
248
- function_call_json[i]["call_id"] = call_id
249
- print("\033[94mTool Call Result:\033[0m", call_result)
250
- call_results.append({
251
- "role": "tool",
252
- "content": json.dumps({"content": call_result, "call_id": call_id})
253
- })
254
- else:
255
- call_results.append({
256
- "role": "tool",
257
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
258
- })
259
-
260
- revised_messages = [{
261
- "role": "assistant",
262
- "content": message.strip(),
263
- "tool_calls": json.dumps(function_call_json)
264
- }] + call_results
265
-
266
- # Yield the final result.
267
- return revised_messages, existing_tools_prompt, special_tool_call
268
-
269
- def run_function_call_stream(self, fcall_str,
270
- return_message=False,
271
- existing_tools_prompt=None,
272
- message_for_call_agent=None,
273
- call_agent=False,
274
- call_agent_level=None,
275
- temperature=None,
276
- return_gradio_history=True):
277
-
278
- function_call_json, message = self.tooluniverse.extract_function_call_json(
279
- fcall_str, return_message=return_message, verbose=False)
280
- call_results = []
281
- special_tool_call = ''
282
- if return_gradio_history:
283
- gradio_history = []
284
- if function_call_json is not None:
285
- if isinstance(function_call_json, list):
286
- for i in range(len(function_call_json)):
287
- if function_call_json[i]["name"] == 'Finish':
288
- special_tool_call = 'Finish'
289
- break
290
- elif function_call_json[i]["name"] == 'Tool_RAG':
291
- new_tools_prompt, call_result = self.tool_RAG(
292
- message=message,
293
- existing_tools_prompt=existing_tools_prompt,
294
- rag_num=self.step_rag_num,
295
- return_call_result=True)
296
- existing_tools_prompt += new_tools_prompt
297
- elif function_call_json[i]["name"] == 'DirectResponse':
298
- call_result = function_call_json[i]['arguments']['respose']
299
- special_tool_call = 'DirectResponse'
300
- elif function_call_json[i]["name"] == 'RequireClarification':
301
- call_result = function_call_json[i]['arguments']['unclear_question']
302
- special_tool_call = 'RequireClarification'
303
- elif function_call_json[i]["name"] == 'CallAgent':
304
- if call_agent_level < 2 and call_agent:
305
- solution_plan = function_call_json[i]['arguments']['solution']
306
- full_message = (
307
- message_for_call_agent +
308
- "\nYou must follow the following plan to answer the question: " +
309
- str(solution_plan)
310
- )
311
- sub_agent_task = "Sub TxAgent plan: " + \
312
- str(solution_plan)
313
- # When streaming, yield responses as they arrive.
314
- call_result = yield from self.run_gradio_chat(
315
- full_message, history=[], temperature=temperature,
316
- max_new_tokens=1024, max_token=99999,
317
- call_agent=False, call_agent_level=call_agent_level,
318
- conversation=None,
319
- sub_agent_task=sub_agent_task)
320
-
321
- call_result = call_result.split(
322
- '[FinalAnswer]')[-1]
323
- else:
324
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
325
- else:
326
- call_result = self.tooluniverse.run_one_function(
327
- function_call_json[i])
328
-
329
- call_id = self.tooluniverse.call_id_gen()
330
- function_call_json[i]["call_id"] = call_id
331
- call_results.append({
332
- "role": "tool",
333
- "content": json.dumps({"content": call_result, "call_id": call_id})
334
- })
335
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
336
- if function_call_json[i]["name"] == 'Tool_RAG':
337
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
338
- "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
339
-
340
- else:
341
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
342
- "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
343
- else:
344
- call_results.append({
345
- "role": "tool",
346
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
347
- })
348
-
349
- revised_messages = [{
350
- "role": "assistant",
351
- "content": message.strip(),
352
- "tool_calls": json.dumps(function_call_json)
353
- }] + call_results
354
-
355
- # Yield the final result.
356
- if return_gradio_history:
357
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
358
- else:
359
- return revised_messages, existing_tools_prompt, special_tool_call
360
-
361
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
362
- if conversation[-1]['role'] == 'assisant':
363
- conversation.append(
364
- {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
365
- finish_tools_prompt = self.add_finish_tools([])
366
-
367
- last_outputs_str = self.llm_infer(messages=conversation,
368
- temperature=temperature,
369
- tools=finish_tools_prompt,
370
- output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
371
- skip_special_tokens=True,
372
- max_new_tokens=max_new_tokens, max_token=max_token)
373
- print(last_outputs_str)
374
- return last_outputs_str
375
-
376
- def run_multistep_agent(self, message: str,
377
- temperature: float,
378
- max_new_tokens: int,
379
- max_token: int,
380
- max_round: int = 20,
381
- call_agent=False,
382
- call_agent_level=0) -> str:
383
- """
384
- Generate a streaming response using the llama3-8b model.
385
- Args:
386
- message (str): The input message.
387
- temperature (float): The temperature for generating the response.
388
- max_new_tokens (int): The maximum number of new tokens to generate.
389
- Returns:
390
- str: The generated response.
391
- """
392
- print("\033[1;32;40mstart\033[0m")
393
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
394
- call_agent, call_agent_level, message)
395
- conversation = self.initialize_conversation(message)
396
-
397
- outputs = []
398
- last_outputs = []
399
- next_round = True
400
- function_call_messages = []
401
- current_round = 0
402
- token_overflow = False
403
- enable_summary = False
404
- last_status = {}
405
-
406
- if self.enable_checker:
407
- checker = ReasoningTraceChecker(message, conversation)
408
- try:
409
- while next_round and current_round < max_round:
410
- current_round += 1
411
- if len(outputs) > 0:
412
- function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
413
- last_outputs, return_message=True,
414
- existing_tools_prompt=picked_tools_prompt,
415
- message_for_call_agent=message,
416
- call_agent=call_agent,
417
- call_agent_level=call_agent_level,
418
- temperature=temperature)
419
-
420
- if special_tool_call == 'Finish':
421
- next_round = False
422
- conversation.extend(function_call_messages)
423
- if isinstance(function_call_messages[0]['content'], types.GeneratorType):
424
- function_call_messages[0]['content'] = next(
425
- function_call_messages[0]['content'])
426
- return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
427
-
428
- if (self.enable_summary or token_overflow) and not call_agent:
429
- if token_overflow:
430
- print("token_overflow, using summary")
431
- enable_summary = True
432
- last_status = self.function_result_summary(
433
- conversation, status=last_status, enable_summary=enable_summary)
434
-
435
- if function_call_messages is not None:
436
- conversation.extend(function_call_messages)
437
- outputs.append(tool_result_format(
438
- function_call_messages))
439
- else:
440
- next_round = False
441
- conversation.extend(
442
- [{"role": "assistant", "content": ''.join(last_outputs)}])
443
- return ''.join(last_outputs).replace("</s>", "")
444
- if self.enable_checker:
445
- good_status, wrong_info = checker.check_conversation()
446
- if not good_status:
447
- next_round = False
448
- print(
449
- "Internal error in reasoning: " + wrong_info)
450
- break
451
- last_outputs = []
452
- outputs.append("### TxAgent:\n")
453
- last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
454
- temperature=temperature,
455
- tools=picked_tools_prompt,
456
- skip_special_tokens=False,
457
- max_new_tokens=max_new_tokens, max_token=max_token,
458
- check_token_status=True)
459
- if last_outputs_str is None:
460
- next_round = False
461
- print(
462
- "The number of tokens exceeds the maximum limit.")
463
- else:
464
- last_outputs.append(last_outputs_str)
465
- if max_round == current_round:
466
- print("The number of rounds exceeds the maximum limit!")
467
- if self.force_finish:
468
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
469
- else:
470
- return None
471
-
472
- except Exception as e:
473
- print(f"Error: {e}")
474
- if self.force_finish:
475
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
476
- else:
477
- return None
478
-
479
- def build_logits_processor(self, messages, llm):
480
- # Use the tokenizer from the LLM instance.
481
- tokenizer = llm.get_tokenizer()
482
- if self.avoid_repeat and len(messages) > 2:
483
- assistant_messages = []
484
- for i in range(1, len(messages) + 1):
485
- if messages[-i]['role'] == 'assistant':
486
- assistant_messages.append(messages[-i]['content'])
487
- if len(assistant_messages) == 2:
488
- break
489
- forbidden_ids = [tokenizer.encode(
490
- msg, add_special_tokens=False) for msg in assistant_messages]
491
- return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
492
- else:
493
- return None
494
-
495
- def llm_infer(self, messages, temperature=0.1, tools=None,
496
- output_begin_string=None, max_new_tokens=2048,
497
- max_token=None, skip_special_tokens=True,
498
- model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
499
-
500
- if model is None:
501
- model = self.model
502
-
503
- logits_processor = self.build_logits_processor(messages, model)
504
- sampling_params = SamplingParams(
505
- temperature=temperature,
506
- max_tokens=max_new_tokens,
507
- logits_processors=logits_processor,
508
- seed=seed if seed is not None else self.seed,
509
- )
510
-
511
- prompt = self.chat_template.render(
512
- messages=messages, tools=tools, add_generation_prompt=True)
513
- if output_begin_string is not None:
514
- prompt += output_begin_string
515
-
516
- if check_token_status and max_token is not None:
517
- token_overflow = False
518
- num_input_tokens = len(self.tokenizer.encode(
519
- prompt, return_tensors="pt")[0])
520
- if max_token is not None:
521
- if num_input_tokens > max_token:
522
- torch.cuda.empty_cache()
523
- gc.collect()
524
- print("Number of input tokens before inference:",
525
- num_input_tokens)
526
- logger.info(
527
- "The number of tokens exceeds the maximum limit!!!!")
528
- token_overflow = True
529
- return None, token_overflow
530
- output = model.generate(
531
- prompt,
532
- sampling_params=sampling_params,
533
- )
534
- output = output[0].outputs[0].text
535
- print("\033[92m" + output + "\033[0m")
536
- if check_token_status and max_token is not None:
537
- return output, token_overflow
538
-
539
- return output
540
-
541
- def run_self_agent(self, message: str,
542
- temperature: float,
543
- max_new_tokens: int,
544
- max_token: int) -> str:
545
-
546
- print("\033[1;32;40mstart self agent\033[0m")
547
- conversation = []
548
- conversation = self.set_system_prompt(conversation, self.self_prompt)
549
- conversation.append({"role": "user", "content": message})
550
- return self.llm_infer(messages=conversation,
551
- temperature=temperature,
552
- tools=None,
553
- max_new_tokens=max_new_tokens, max_token=max_token)
554
-
555
- def run_chat_agent(self, message: str,
556
- temperature: float,
557
- max_new_tokens: int,
558
- max_token: int) -> str:
559
-
560
- print("\033[1;32;40mstart chat agent\033[0m")
561
- conversation = []
562
- conversation = self.set_system_prompt(conversation, self.chat_prompt)
563
- conversation.append({"role": "user", "content": message})
564
- return self.llm_infer(messages=conversation,
565
- temperature=temperature,
566
- tools=None,
567
- max_new_tokens=max_new_tokens, max_token=max_token)
568
-
569
- def run_format_agent(self, message: str,
570
- answer: str,
571
- temperature: float,
572
- max_new_tokens: int,
573
- max_token: int) -> str:
574
-
575
- print("\033[1;32;40mstart format agent\033[0m")
576
- if '[FinalAnswer]' in answer:
577
- possible_final_answer = answer.split("[FinalAnswer]")[-1]
578
- elif "\n\n" in answer:
579
- possible_final_answer = answer.split("\n\n")[-1]
580
- else:
581
- possible_final_answer = answer.strip()
582
- if len(possible_final_answer) == 1:
583
- choice = possible_final_answer[0]
584
- if choice in ['A', 'B', 'C', 'D', 'E']:
585
- return choice
586
- elif len(possible_final_answer) > 1:
587
- if possible_final_answer[1] == ':':
588
- choice = possible_final_answer[0]
589
- if choice in ['A', 'B', 'C', 'D', 'E']:
590
- print("choice", choice)
591
- return choice
592
-
593
- conversation = []
594
- format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
595
- conversation = self.set_system_prompt(conversation, format_prompt)
596
- conversation.append({"role": "user", "content": message +
597
- "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
598
- return self.llm_infer(messages=conversation,
599
- temperature=temperature,
600
- tools=None,
601
- max_new_tokens=max_new_tokens, max_token=max_token)
602
-
603
- def run_summary_agent(self, thought_calls: str,
604
- function_response: str,
605
- temperature: float,
606
- max_new_tokens: int,
607
- max_token: int) -> str:
608
- print("\033[1;32;40mSummarized Tool Result:\033[0m")
609
- generate_tool_result_summary_training_prompt = """Thought and function calls:
610
- {thought_calls}
611
-
612
- Function calls' responses:
613
- \"\"\"
614
- {function_response}
615
- \"\"\"
616
-
617
- Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
618
-
619
- Directly respond with the summarized sentence of the function calls' responses only.
620
-
621
- Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
622
- """.format(thought_calls=thought_calls, function_response=function_response)
623
- conversation = []
624
- conversation.append(
625
- {"role": "user", "content": generate_tool_result_summary_training_prompt})
626
- output = self.llm_infer(messages=conversation,
627
- temperature=temperature,
628
- tools=None,
629
- max_new_tokens=max_new_tokens, max_token=max_token)
630
-
631
- if '[' in output:
632
- output = output.split('[')[0]
633
- return output
634
-
635
- def function_result_summary(self, input_list, status, enable_summary):
636
- """
637
- Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
638
- Supports 'length' and 'step' modes, and skips the last 'k' groups.
639
-
640
- Parameters:
641
- input_list (list): A list of dictionaries containing role and other information.
642
- summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
643
- summary_context_length (int): The context length threshold for the 'length' mode.
644
- last_processed_index (tuple or int): The last processed index.
645
-
646
- Returns:
647
- list: A list of extracted information from valid sequences.
648
- """
649
- if 'tool_call_step' not in status:
650
- status['tool_call_step'] = 0
651
-
652
- for idx in range(len(input_list)):
653
- pos_id = len(input_list)-idx-1
654
- if input_list[pos_id]['role'] == 'assistant':
655
- if 'tool_calls' in input_list[pos_id]:
656
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
657
- status['tool_call_step'] += 1
658
- break
659
-
660
- if 'step' in status:
661
- status['step'] += 1
662
- else:
663
- status['step'] = 0
664
-
665
- if not enable_summary:
666
- return status
667
-
668
- if 'summarized_index' not in status:
669
- status['summarized_index'] = 0
670
-
671
- if 'summarized_step' not in status:
672
- status['summarized_step'] = 0
673
-
674
- if 'previous_length' not in status:
675
- status['previous_length'] = 0
676
-
677
- if 'history' not in status:
678
- status['history'] = []
679
-
680
- function_response = ''
681
- idx = 0
682
- current_summarized_index = status['summarized_index']
683
-
684
- status['history'].append(self.summary_mode == 'step' and status['summarized_step']
685
- < status['step']-status['tool_call_step']-self.summary_skip_last_k)
686
-
687
- idx = current_summarized_index
688
- while idx < len(input_list):
689
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
690
-
691
- if input_list[idx]['role'] == 'assistant':
692
- if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
693
- this_thought_calls = None
694
- else:
695
- if len(function_response) != 0:
696
- print("internal summary")
697
- status['summarized_step'] += 1
698
- result_summary = self.run_summary_agent(
699
- thought_calls=this_thought_calls,
700
- function_response=function_response,
701
- temperature=0.1,
702
- max_new_tokens=1024,
703
- max_token=99999
704
- )
705
-
706
- input_list.insert(
707
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
708
- status['summarized_index'] = last_call_idx + 2
709
- idx += 1
710
-
711
- last_call_idx = idx
712
- this_thought_calls = input_list[idx]['content'] + \
713
- input_list[idx]['tool_calls']
714
- function_response = ''
715
-
716
- elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
717
- function_response += input_list[idx]['content']
718
- del input_list[idx]
719
- idx -= 1
720
-
721
- else:
722
- break
723
- idx += 1
724
-
725
- if len(function_response) != 0:
726
- status['summarized_step'] += 1
727
- result_summary = self.run_summary_agent(
728
- thought_calls=this_thought_calls,
729
- function_response=function_response,
730
- temperature=0.1,
731
- max_new_tokens=1024,
732
- max_token=99999
733
- )
734
-
735
- tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
736
- for tool_call in tool_calls:
737
- del tool_call['call_id']
738
- input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
739
- input_list.insert(
740
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
741
- status['summarized_index'] = last_call_idx + 2
742
-
743
- return status
744
-
745
- # Following are Gradio related functions
746
-
747
- # General update method that accepts any new arguments through kwargs
748
- def update_parameters(self, **kwargs):
749
- for key, value in kwargs.items():
750
- if hasattr(self, key):
751
- setattr(self, key, value)
752
-
753
- # Return the updated attributes
754
- updated_attributes = {key: value for key,
755
- value in kwargs.items() if hasattr(self, key)}
756
- return updated_attributes
757
-
758
- def run_gradio_chat(self, message: str,
759
- history: list,
760
- temperature: float,
761
- max_new_tokens: int,
762
- max_token: int,
763
- call_agent: bool,
764
- conversation: gr.State,
765
- max_round: int = 20,
766
- seed: int = None,
767
- call_agent_level: int = 0,
768
- sub_agent_task: str = None) -> str:
769
- """
770
- Generate a streaming response using the llama3-8b model.
771
- Args:
772
- message (str): The input message.
773
- history (list): The conversation history used by ChatInterface.
774
- temperature (float): The temperature for generating the response.
775
- max_new_tokens (int): The maximum number of new tokens to generate.
776
- Returns:
777
- str: The generated response.
778
- """
779
- print("\033[1;32;40mstart\033[0m")
780
- print("len(message)", len(message))
781
- if len(message) <= 10:
782
- yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
783
- return "Please provide a valid message."
784
- outputs = []
785
- outputs_str = ''
786
- last_outputs = []
787
-
788
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
789
- call_agent,
790
- call_agent_level,
791
- message)
792
-
793
- conversation = self.initialize_conversation(
794
- message,
795
- conversation=conversation,
796
- history=history)
797
- history = []
798
-
799
- next_round = True
800
- function_call_messages = []
801
- current_round = 0
802
- enable_summary = False
803
- last_status = {} # for summary
804
- token_overflow = False
805
- if self.enable_checker:
806
- checker = ReasoningTraceChecker(
807
- message, conversation, init_index=len(conversation))
808
-
809
- try:
810
- while next_round and current_round < max_round:
811
- current_round += 1
812
- if len(last_outputs) > 0:
813
- function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
814
- last_outputs, return_message=True,
815
- existing_tools_prompt=picked_tools_prompt,
816
- message_for_call_agent=message,
817
- call_agent=call_agent,
818
- call_agent_level=call_agent_level,
819
- temperature=temperature)
820
- history.extend(current_gradio_history)
821
- if special_tool_call == 'Finish':
822
- yield history
823
- next_round = False
824
- conversation.extend(function_call_messages)
825
- return function_call_messages[0]['content']
826
- elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
827
- history.append(
828
- ChatMessage(role="assistant", content=history[-1].content))
829
- yield history
830
- next_round = False
831
- return history[-1].content
832
- if (self.enable_summary or token_overflow) and not call_agent:
833
- if token_overflow:
834
- print("token_overflow, using summary")
835
- enable_summary = True
836
- last_status = self.function_result_summary(
837
- conversation, status=last_status,
838
- enable_summary=enable_summary)
839
- if function_call_messages is not None:
840
- conversation.extend(function_call_messages)
841
- formated_md_function_call_messages = tool_result_format(
842
- function_call_messages)
843
- yield history
844
- else:
845
- next_round = False
846
- conversation.extend(
847
- [{"role": "assistant", "content": ''.join(last_outputs)}])
848
- return ''.join(last_outputs).replace("</s>", "")
849
- if self.enable_checker:
850
- good_status, wrong_info = checker.check_conversation()
851
- if not good_status:
852
- next_round = False
853
- print("Internal error in reasoning: " + wrong_info)
854
- break
855
- last_outputs = []
856
- last_outputs_str, token_overflow = self.llm_infer(
857
- messages=conversation,
858
- temperature=temperature,
859
- tools=picked_tools_prompt,
860
- skip_special_tokens=False,
861
- max_new_tokens=max_new_tokens,
862
- max_token=max_token,
863
- seed=seed,
864
- check_token_status=True)
865
- last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
866
- for each in history:
867
- if each.metadata is not None:
868
- each.metadata['status'] = 'done'
869
- if '[FinalAnswer]' in last_thought:
870
- final_thought, final_answer = last_thought.split(
871
- '[FinalAnswer]')
872
- history.append(
873
- ChatMessage(role="assistant",
874
- content=final_thought.strip())
875
- )
876
- yield history
877
- history.append(
878
- ChatMessage(
879
- role="assistant", content="**Answer**:\n"+final_answer.strip())
880
- )
881
- yield history
882
- else:
883
- history.append(ChatMessage(
884
- role="assistant", content=last_thought))
885
- yield history
886
-
887
- last_outputs.append(last_outputs_str)
888
-
889
- if next_round:
890
- if self.force_finish:
891
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
892
- conversation, temperature, max_new_tokens, max_token)
893
- for each in history:
894
- if each.metadata is not None:
895
- each.metadata['status'] = 'done'
896
- if '[FinalAnswer]' in last_thought:
897
- final_thought, final_answer = last_thought.split(
898
- '[FinalAnswer]')
899
- history.append(
900
- ChatMessage(role="assistant",
901
- content=final_thought.strip())
902
- )
903
- yield history
904
- history.append(
905
- ChatMessage(
906
- role="assistant", content="**Answer**:\n"+final_answer.strip())
907
- )
908
- yield history
909
- else:
910
- yield "The number of rounds exceeds the maximum limit!"
911
-
912
- except Exception as e:
913
- print(f"Error: {e}")
914
- if self.force_finish:
915
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
916
- conversation,
917
- temperature,
918
- max_new_tokens,
919
- max_token)
920
- for each in history:
921
- if each.metadata is not None:
922
- each.metadata['status'] = 'done'
923
- if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
924
- final_thought, final_answer = last_thought.split(
925
- '[FinalAnswer]')
926
- history.append(
927
- ChatMessage(role="assistant",
928
- content=final_thought.strip())
929
- )
930
- yield history
931
- history.append(
932
- ChatMessage(
933
- role="assistant", content="**Answer**:\n"+final_answer.strip())
934
- )
935
- yield history
936
- else:
937
- return None
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import json
5
+ import gc
6
+ import numpy as np
7
+ from vllm import LLM, SamplingParams
8
+ from jinja2 import Template
9
+ from typing import List
10
+ import types
11
+ from tooluniverse import ToolUniverse
12
+ from gradio import ChatMessage
13
+ from .toolrag import ToolRAGModel
14
+
15
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
16
+
17
+
18
+ class TxAgent:
19
+ def __init__(self, model_name,
20
+ rag_model_name,
21
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
22
+ enable_finish=True,
23
+ enable_rag=True,
24
+ enable_summary=False,
25
+ init_rag_num=0,
26
+ step_rag_num=10,
27
+ summary_mode='step',
28
+ summary_skip_last_k=0,
29
+ summary_context_length=None,
30
+ force_finish=True,
31
+ avoid_repeat=True,
32
+ seed=None,
33
+ enable_checker=False,
34
+ enable_chat=False,
35
+ additional_default_tools=None,
36
+ ):
37
+ self.model_name = model_name
38
+ self.tokenizer = None
39
+ self.terminators = None
40
+ self.rag_model_name = rag_model_name
41
+ self.tool_files_dict = tool_files_dict
42
+ self.model = None
43
+ self.rag_model = ToolRAGModel(rag_model_name)
44
+ self.tooluniverse = None
45
+ # self.tool_desc = None
46
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
47
+ self.self_prompt = "Strictly follow the instruction."
48
+ self.chat_prompt = "You are helpful assistant to chat with the user."
49
+ self.enable_finish = enable_finish
50
+ self.enable_rag = enable_rag
51
+ self.enable_summary = enable_summary
52
+ self.summary_mode = summary_mode
53
+ self.summary_skip_last_k = summary_skip_last_k
54
+ self.summary_context_length = summary_context_length
55
+ self.init_rag_num = init_rag_num
56
+ self.step_rag_num = step_rag_num
57
+ self.force_finish = force_finish
58
+ self.avoid_repeat = avoid_repeat
59
+ self.seed = seed
60
+ self.enable_checker = enable_checker
61
+ self.additional_default_tools = additional_default_tools
62
+ self.print_self_values()
63
+
64
+ def init_model(self):
65
+ self.load_models()
66
+ self.load_tooluniverse()
67
+ self.load_tool_desc_embedding()
68
+
69
+ def print_self_values(self):
70
+ for attr, value in self.__dict__.items():
71
+ print(f"{attr}: {value}")
72
+
73
+ def load_models(self, model_name=None):
74
+ if model_name is not None:
75
+ if model_name == self.model_name:
76
+ return f"The model {model_name} is already loaded."
77
+ self.model_name = model_name
78
+
79
+ self.model = LLM(model=self.model_name)
80
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
81
+ self.tokenizer = self.model.get_tokenizer()
82
+
83
+ return f"Model {model_name} loaded successfully."
84
+
85
+ def load_tooluniverse(self):
86
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
87
+ self.tooluniverse.load_tools()
88
+ special_tools = self.tooluniverse.prepare_tool_prompts(
89
+ self.tooluniverse.tool_category_dicts["special_tools"])
90
+ self.special_tools_name = [tool['name'] for tool in special_tools]
91
+
92
+ def load_tool_desc_embedding(self):
93
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
94
+
95
+ def rag_infer(self, query, top_k=5):
96
+ return self.rag_model.rag_infer(query, top_k)
97
+
98
+ def initialize_tools_prompt(self, call_agent, call_agent_level, message):
99
+ picked_tools_prompt = []
100
+ picked_tools_prompt = self.add_special_tools(
101
+ picked_tools_prompt, call_agent=call_agent)
102
+ if call_agent:
103
+ call_agent_level += 1
104
+ if call_agent_level >= 2:
105
+ call_agent = False
106
+
107
+ if not call_agent:
108
+ picked_tools_prompt += self.tool_RAG(
109
+ message=message, rag_num=self.init_rag_num)
110
+ return picked_tools_prompt, call_agent_level
111
+
112
+ def initialize_conversation(self, message, conversation=None, history=None):
113
+ if conversation is None:
114
+ conversation = []
115
+
116
+ conversation = self.set_system_prompt(
117
+ conversation, self.prompt_multi_step)
118
+ if history is not None:
119
+ if len(history) == 0:
120
+ conversation = []
121
+ print("clear conversation successfully")
122
+ else:
123
+ for i in range(len(history)):
124
+ if history[i]['role'] == 'user':
125
+ if i-1 >= 0 and history[i-1]['role'] == 'assistant':
126
+ conversation.append(
127
+ {"role": "assistant", "content": history[i-1]['content']})
128
+ conversation.append(
129
+ {"role": "user", "content": history[i]['content']})
130
+ if i == len(history)-1 and history[i]['role'] == 'assistant':
131
+ conversation.append(
132
+ {"role": "assistant", "content": history[i]['content']})
133
+
134
+ conversation.append({"role": "user", "content": message})
135
+
136
+ return conversation
137
+
138
+ def tool_RAG(self, message=None,
139
+ picked_tool_names=None,
140
+ existing_tools_prompt=[],
141
+ rag_num=5,
142
+ return_call_result=False):
143
+ extra_factor = 30 # Factor to retrieve more than rag_num
144
+ if picked_tool_names is None:
145
+ assert picked_tool_names is not None or message is not None
146
+ picked_tool_names = self.rag_infer(
147
+ message, top_k=rag_num*extra_factor)
148
+
149
+ picked_tool_names_no_special = []
150
+ for tool in picked_tool_names:
151
+ if tool not in self.special_tools_name:
152
+ picked_tool_names_no_special.append(tool)
153
+ picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
154
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
155
+
156
+ picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
158
+ picked_tools)
159
+ if return_call_result:
160
+ return picked_tools_prompt, picked_tool_names
161
+ return picked_tools_prompt
162
+
163
+ def add_special_tools(self, tools, call_agent=False):
164
+ if self.enable_finish:
165
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
166
+ 'Finish', return_prompt=True))
167
+ print("Finish tool is added")
168
+ if call_agent:
169
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
170
+ 'CallAgent', return_prompt=True))
171
+ print("CallAgent tool is added")
172
+ else:
173
+ if self.enable_rag:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'Tool_RAG', return_prompt=True))
176
+ print("Tool_RAG tool is added")
177
+
178
+ if self.additional_default_tools is not None:
179
+ for each_tool_name in self.additional_default_tools:
180
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
181
+ each_tool_name, return_prompt=True)
182
+ if tool_prompt is not None:
183
+ print(f"{each_tool_name} tool is added")
184
+ tools.append(tool_prompt)
185
+ return tools
186
+
187
+ def add_finish_tools(self, tools):
188
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
189
+ 'Finish', return_prompt=True))
190
+ print("Finish tool is added")
191
+ return tools
192
+
193
+ def set_system_prompt(self, conversation, sys_prompt):
194
+ if len(conversation) == 0:
195
+ conversation.append(
196
+ {"role": "system", "content": sys_prompt})
197
+ else:
198
+ conversation[0] = {"role": "system", "content": sys_prompt}
199
+ return conversation
200
+
201
+ def run_function_call(self, fcall_str,
202
+ return_message=False,
203
+ existing_tools_prompt=None,
204
+ message_for_call_agent=None,
205
+ call_agent=False,
206
+ call_agent_level=None,
207
+ temperature=None):
208
+
209
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
210
+ fcall_str, return_message=return_message, verbose=False)
211
+ call_results = []
212
+ special_tool_call = ''
213
+ if function_call_json is not None:
214
+ if isinstance(function_call_json, list):
215
+ for i in range(len(function_call_json)):
216
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
217
+ if function_call_json[i]["name"] == 'Finish':
218
+ special_tool_call = 'Finish'
219
+ break
220
+ elif function_call_json[i]["name"] == 'Tool_RAG':
221
+ new_tools_prompt, call_result = self.tool_RAG(
222
+ message=message,
223
+ existing_tools_prompt=existing_tools_prompt,
224
+ rag_num=self.step_rag_num,
225
+ return_call_result=True)
226
+ existing_tools_prompt += new_tools_prompt
227
+ elif function_call_json[i]["name"] == 'CallAgent':
228
+ if call_agent_level < 2 and call_agent:
229
+ solution_plan = function_call_json[i]['arguments']['solution']
230
+ full_message = (
231
+ message_for_call_agent +
232
+ "\nYou must follow the following plan to answer the question: " +
233
+ str(solution_plan)
234
+ )
235
+ call_result = self.run_multistep_agent(
236
+ full_message, temperature=temperature,
237
+ max_new_tokens=1024, max_token=99999,
238
+ call_agent=False, call_agent_level=call_agent_level)
239
+ call_result = call_result.split(
240
+ '[FinalAnswer]')[-1].strip()
241
+ else:
242
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
243
+ else:
244
+ call_result = self.tooluniverse.run_one_function(
245
+ function_call_json[i])
246
+
247
+ call_id = self.tooluniverse.call_id_gen()
248
+ function_call_json[i]["call_id"] = call_id
249
+ print("\033[94mTool Call Result:\033[0m", call_result)
250
+ call_results.append({
251
+ "role": "tool",
252
+ "content": json.dumps({"content": call_result, "call_id": call_id})
253
+ })
254
+ else:
255
+ call_results.append({
256
+ "role": "tool",
257
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
258
+ })
259
+
260
+ revised_messages = [{
261
+ "role": "assistant",
262
+ "content": message.strip(),
263
+ "tool_calls": json.dumps(function_call_json)
264
+ }] + call_results
265
+
266
+ # Yield the final result.
267
+ return revised_messages, existing_tools_prompt, special_tool_call
268
+
269
+ def run_function_call_stream(self, fcall_str,
270
+ return_message=False,
271
+ existing_tools_prompt=None,
272
+ message_for_call_agent=None,
273
+ call_agent=False,
274
+ call_agent_level=None,
275
+ temperature=None,
276
+ return_gradio_history=True):
277
+
278
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
279
+ fcall_str, return_message=return_message, verbose=False)
280
+ call_results = []
281
+ special_tool_call = ''
282
+ if return_gradio_history:
283
+ gradio_history = []
284
+ if function_call_json is not None:
285
+ if isinstance(function_call_json, list):
286
+ for i in range(len(function_call_json)):
287
+ if function_call_json[i]["name"] == 'Finish':
288
+ special_tool_call = 'Finish'
289
+ break
290
+ elif function_call_json[i]["name"] == 'Tool_RAG':
291
+ new_tools_prompt, call_result = self.tool_RAG(
292
+ message=message,
293
+ existing_tools_prompt=existing_tools_prompt,
294
+ rag_num=self.step_rag_num,
295
+ return_call_result=True)
296
+ existing_tools_prompt += new_tools_prompt
297
+ elif function_call_json[i]["name"] == 'DirectResponse':
298
+ call_result = function_call_json[i]['arguments']['respose']
299
+ special_tool_call = 'DirectResponse'
300
+ elif function_call_json[i]["name"] == 'RequireClarification':
301
+ call_result = function_call_json[i]['arguments']['unclear_question']
302
+ special_tool_call = 'RequireClarification'
303
+ elif function_call_json[i]["name"] == 'CallAgent':
304
+ if call_agent_level < 2 and call_agent:
305
+ solution_plan = function_call_json[i]['arguments']['solution']
306
+ full_message = (
307
+ message_for_call_agent +
308
+ "\nYou must follow the following plan to answer the question: " +
309
+ str(solution_plan)
310
+ )
311
+ sub_agent_task = "Sub TxAgent plan: " + \
312
+ str(solution_plan)
313
+ # When streaming, yield responses as they arrive.
314
+ call_result = yield from self.run_gradio_chat(
315
+ full_message, history=[], temperature=temperature,
316
+ max_new_tokens=1024, max_token=99999,
317
+ call_agent=False, call_agent_level=call_agent_level,
318
+ conversation=None,
319
+ sub_agent_task=sub_agent_task)
320
+
321
+ call_result = call_result.split(
322
+ '[FinalAnswer]')[-1]
323
+ else:
324
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
325
+ else:
326
+ call_result = self.tooluniverse.run_one_function(
327
+ function_call_json[i])
328
+
329
+ call_id = self.tooluniverse.call_id_gen()
330
+ function_call_json[i]["call_id"] = call_id
331
+ call_results.append({
332
+ "role": "tool",
333
+ "content": json.dumps({"content": call_result, "call_id": call_id})
334
+ })
335
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
336
+ if function_call_json[i]["name"] == 'Tool_RAG':
337
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
338
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
339
+
340
+ else:
341
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
342
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
343
+ else:
344
+ call_results.append({
345
+ "role": "tool",
346
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
347
+ })
348
+
349
+ revised_messages = [{
350
+ "role": "assistant",
351
+ "content": message.strip(),
352
+ "tool_calls": json.dumps(function_call_json)
353
+ }] + call_results
354
+
355
+ # Yield the final result.
356
+ if return_gradio_history:
357
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
358
+ else:
359
+ return revised_messages, existing_tools_prompt, special_tool_call
360
+
361
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
362
+ if conversation[-1]['role'] == 'assisant':
363
+ conversation.append(
364
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
365
+ finish_tools_prompt = self.add_finish_tools([])
366
+
367
+ last_outputs_str = self.llm_infer(messages=conversation,
368
+ temperature=temperature,
369
+ tools=finish_tools_prompt,
370
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
371
+ skip_special_tokens=True,
372
+ max_new_tokens=max_new_tokens, max_token=max_token)
373
+ print(last_outputs_str)
374
+ return last_outputs_str
375
+
376
+ def run_multistep_agent(self, message: str,
377
+ temperature: float,
378
+ max_new_tokens: int,
379
+ max_token: int,
380
+ max_round: int = 20,
381
+ call_agent=False,
382
+ call_agent_level=0) -> str:
383
+ """
384
+ Generate a streaming response using the llama3-8b model.
385
+ Args:
386
+ message (str): The input message.
387
+ temperature (float): The temperature for generating the response.
388
+ max_new_tokens (int): The maximum number of new tokens to generate.
389
+ Returns:
390
+ str: The generated response.
391
+ """
392
+ print("\033[1;32;40mstart\033[0m")
393
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
394
+ call_agent, call_agent_level, message)
395
+ conversation = self.initialize_conversation(message)
396
+
397
+ outputs = []
398
+ last_outputs = []
399
+ next_round = True
400
+ function_call_messages = []
401
+ current_round = 0
402
+ token_overflow = False
403
+ enable_summary = False
404
+ last_status = {}
405
+
406
+ if self.enable_checker:
407
+ checker = ReasoningTraceChecker(message, conversation)
408
+ try:
409
+ while next_round and current_round < max_round:
410
+ current_round += 1
411
+ if len(outputs) > 0:
412
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
413
+ last_outputs, return_message=True,
414
+ existing_tools_prompt=picked_tools_prompt,
415
+ message_for_call_agent=message,
416
+ call_agent=call_agent,
417
+ call_agent_level=call_agent_level,
418
+ temperature=temperature)
419
+
420
+ if special_tool_call == 'Finish':
421
+ next_round = False
422
+ conversation.extend(function_call_messages)
423
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
424
+ function_call_messages[0]['content'] = next(
425
+ function_call_messages[0]['content'])
426
+ return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
427
+
428
+ if (self.enable_summary or token_overflow) and not call_agent:
429
+ if token_overflow:
430
+ print("token_overflow, using summary")
431
+ enable_summary = True
432
+ last_status = self.function_result_summary(
433
+ conversation, status=last_status, enable_summary=enable_summary)
434
+
435
+ if function_call_messages is not None:
436
+ conversation.extend(function_call_messages)
437
+ outputs.append(tool_result_format(
438
+ function_call_messages))
439
+ else:
440
+ next_round = False
441
+ conversation.extend(
442
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
443
+ return ''.join(last_outputs).replace("</s>", "")
444
+ if self.enable_checker:
445
+ good_status, wrong_info = checker.check_conversation()
446
+ if not good_status:
447
+ next_round = False
448
+ print(
449
+ "Internal error in reasoning: " + wrong_info)
450
+ break
451
+ last_outputs = []
452
+ outputs.append("### TxAgent:\n")
453
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
454
+ temperature=temperature,
455
+ tools=picked_tools_prompt,
456
+ skip_special_tokens=False,
457
+ max_new_tokens=max_new_tokens, max_token=max_token,
458
+ check_token_status=True)
459
+ if last_outputs_str is None:
460
+ next_round = False
461
+ print(
462
+ "The number of tokens exceeds the maximum limit.")
463
+ else:
464
+ last_outputs.append(last_outputs_str)
465
+ if max_round == current_round:
466
+ print("The number of rounds exceeds the maximum limit!")
467
+ if self.force_finish:
468
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
469
+ else:
470
+ return None
471
+
472
+ except Exception as e:
473
+ print(f"Error: {e}")
474
+ if self.force_finish:
475
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
476
+ else:
477
+ return None
478
+
479
+ def build_logits_processor(self, messages, llm):
480
+ # Use the tokenizer from the LLM instance.
481
+ tokenizer = llm.get_tokenizer()
482
+ if self.avoid_repeat and len(messages) > 2:
483
+ assistant_messages = []
484
+ for i in range(1, len(messages) + 1):
485
+ if messages[-i]['role'] == 'assistant':
486
+ assistant_messages.append(messages[-i]['content'])
487
+ if len(assistant_messages) == 2:
488
+ break
489
+ forbidden_ids = [tokenizer.encode(
490
+ msg, add_special_tokens=False) for msg in assistant_messages]
491
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
492
+ else:
493
+ return None
494
+
495
+ def llm_infer(self, messages, temperature=0.7, tools=None,
496
+ output_begin_string=None, max_new_tokens=2048,
497
+ max_token=None, skip_special_tokens=True,
498
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
499
+
500
+ if model is None:
501
+ model = self.model
502
+
503
+ # Removed logits_processors for vLLM compatibility
504
+ sampling_params = {
505
+ "temperature": temperature,
506
+ "top_k": 40,
507
+ "top_p": 0.95,
508
+ "max_tokens": max_new_tokens,
509
+ "stop": ["</s>", "<|eot_id|>", "<|endoftext|>"],
510
+ "seed": seed if seed is not None else self.seed
511
+ }
512
+
513
+ prompt = self.chat_template.render(
514
+ messages=messages, tools=tools, add_generation_prompt=True
515
+ )
516
+
517
+ if output_begin_string is not None:
518
+ prompt += output_begin_string
519
+
520
+ token_overflow = False
521
+ if check_token_status and max_token is not None:
522
+ num_input_tokens = len(self.tokenizer.encode(
523
+ prompt, return_tensors="pt")[0])
524
+ if num_input_tokens > max_token:
525
+ import gc
526
+ import torch
527
+ torch.cuda.empty_cache()
528
+ gc.collect()
529
+ print("Number of input tokens before inference:", num_input_tokens)
530
+ token_overflow = True
531
+ return None, token_overflow
532
+
533
+ outputs = model.generate(prompt, sampling_params=sampling_params)
534
+
535
+ if isinstance(outputs, list) and len(outputs) > 0:
536
+ output = outputs[0].outputs[0].text
537
+ else:
538
+ output = "[Error: No output]"
539
+
540
+ print("\033[92m" + output + "\033[0m")
541
+
542
+ if check_token_status and max_token is not None:
543
+ return output, token_overflow
544
+
545
+ return output
546
+
547
+ def run_self_agent(self, message: str,
548
+ temperature: float,
549
+ max_new_tokens: int,
550
+ max_token: int) -> str:
551
+
552
+ print("\033[1;32;40mstart self agent\033[0m")
553
+ conversation = []
554
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
555
+ conversation.append({"role": "user", "content": message})
556
+ return self.llm_infer(messages=conversation,
557
+ temperature=temperature,
558
+ tools=None,
559
+ max_new_tokens=max_new_tokens, max_token=max_token)
560
+
561
+ def run_chat_agent(self, message: str,
562
+ temperature: float,
563
+ max_new_tokens: int,
564
+ max_token: int) -> str:
565
+
566
+ print("\033[1;32;40mstart chat agent\033[0m")
567
+ conversation = []
568
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
569
+ conversation.append({"role": "user", "content": message})
570
+ return self.llm_infer(messages=conversation,
571
+ temperature=temperature,
572
+ tools=None,
573
+ max_new_tokens=max_new_tokens, max_token=max_token)
574
+
575
+ def run_format_agent(self, message: str,
576
+ answer: str,
577
+ temperature: float,
578
+ max_new_tokens: int,
579
+ max_token: int) -> str:
580
+
581
+ print("\033[1;32;40mstart format agent\033[0m")
582
+ if '[FinalAnswer]' in answer:
583
+ possible_final_answer = answer.split("[FinalAnswer]")[-1]
584
+ elif "\n\n" in answer:
585
+ possible_final_answer = answer.split("\n\n")[-1]
586
+ else:
587
+ possible_final_answer = answer.strip()
588
+ if len(possible_final_answer) == 1:
589
+ choice = possible_final_answer[0]
590
+ if choice in ['A', 'B', 'C', 'D', 'E']:
591
+ return choice
592
+ elif len(possible_final_answer) > 1:
593
+ if possible_final_answer[1] == ':':
594
+ choice = possible_final_answer[0]
595
+ if choice in ['A', 'B', 'C', 'D', 'E']:
596
+ print("choice", choice)
597
+ return choice
598
+
599
+ conversation = []
600
+ format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
601
+ conversation = self.set_system_prompt(conversation, format_prompt)
602
+ conversation.append({"role": "user", "content": message +
603
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
604
+ return self.llm_infer(messages=conversation,
605
+ temperature=temperature,
606
+ tools=None,
607
+ max_new_tokens=max_new_tokens, max_token=max_token)
608
+
609
+ def run_summary_agent(self, thought_calls: str,
610
+ function_response: str,
611
+ temperature: float,
612
+ max_new_tokens: int,
613
+ max_token: int) -> str:
614
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
615
+ generate_tool_result_summary_training_prompt = """Thought and function calls:
616
+ {thought_calls}
617
+
618
+ Function calls' responses:
619
+ \"\"\"
620
+ {function_response}
621
+ \"\"\"
622
+
623
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
624
+
625
+ Directly respond with the summarized sentence of the function calls' responses only.
626
+
627
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
628
+ """.format(thought_calls=thought_calls, function_response=function_response)
629
+ conversation = []
630
+ conversation.append(
631
+ {"role": "user", "content": generate_tool_result_summary_training_prompt})
632
+ output = self.llm_infer(messages=conversation,
633
+ temperature=temperature,
634
+ tools=None,
635
+ max_new_tokens=max_new_tokens, max_token=max_token)
636
+
637
+ if '[' in output:
638
+ output = output.split('[')[0]
639
+ return output
640
+
641
+ def function_result_summary(self, input_list, status, enable_summary):
642
+ """
643
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
644
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
645
+
646
+ Parameters:
647
+ input_list (list): A list of dictionaries containing role and other information.
648
+ summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
649
+ summary_context_length (int): The context length threshold for the 'length' mode.
650
+ last_processed_index (tuple or int): The last processed index.
651
+
652
+ Returns:
653
+ list: A list of extracted information from valid sequences.
654
+ """
655
+ if 'tool_call_step' not in status:
656
+ status['tool_call_step'] = 0
657
+
658
+ for idx in range(len(input_list)):
659
+ pos_id = len(input_list)-idx-1
660
+ if input_list[pos_id]['role'] == 'assistant':
661
+ if 'tool_calls' in input_list[pos_id]:
662
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
663
+ status['tool_call_step'] += 1
664
+ break
665
+
666
+ if 'step' in status:
667
+ status['step'] += 1
668
+ else:
669
+ status['step'] = 0
670
+
671
+ if not enable_summary:
672
+ return status
673
+
674
+ if 'summarized_index' not in status:
675
+ status['summarized_index'] = 0
676
+
677
+ if 'summarized_step' not in status:
678
+ status['summarized_step'] = 0
679
+
680
+ if 'previous_length' not in status:
681
+ status['previous_length'] = 0
682
+
683
+ if 'history' not in status:
684
+ status['history'] = []
685
+
686
+ function_response = ''
687
+ idx = 0
688
+ current_summarized_index = status['summarized_index']
689
+
690
+ status['history'].append(self.summary_mode == 'step' and status['summarized_step']
691
+ < status['step']-status['tool_call_step']-self.summary_skip_last_k)
692
+
693
+ idx = current_summarized_index
694
+ while idx < len(input_list):
695
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
696
+
697
+ if input_list[idx]['role'] == 'assistant':
698
+ if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
699
+ this_thought_calls = None
700
+ else:
701
+ if len(function_response) != 0:
702
+ print("internal summary")
703
+ status['summarized_step'] += 1
704
+ result_summary = self.run_summary_agent(
705
+ thought_calls=this_thought_calls,
706
+ function_response=function_response,
707
+ temperature=0.1,
708
+ max_new_tokens=1024,
709
+ max_token=99999
710
+ )
711
+
712
+ input_list.insert(
713
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
714
+ status['summarized_index'] = last_call_idx + 2
715
+ idx += 1
716
+
717
+ last_call_idx = idx
718
+ this_thought_calls = input_list[idx]['content'] + \
719
+ input_list[idx]['tool_calls']
720
+ function_response = ''
721
+
722
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
723
+ function_response += input_list[idx]['content']
724
+ del input_list[idx]
725
+ idx -= 1
726
+
727
+ else:
728
+ break
729
+ idx += 1
730
+
731
+ if len(function_response) != 0:
732
+ status['summarized_step'] += 1
733
+ result_summary = self.run_summary_agent(
734
+ thought_calls=this_thought_calls,
735
+ function_response=function_response,
736
+ temperature=0.1,
737
+ max_new_tokens=1024,
738
+ max_token=99999
739
+ )
740
+
741
+ tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
742
+ for tool_call in tool_calls:
743
+ del tool_call['call_id']
744
+ input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
745
+ input_list.insert(
746
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
747
+ status['summarized_index'] = last_call_idx + 2
748
+
749
+ return status
750
+
751
+ # Following are Gradio related functions
752
+
753
+ # General update method that accepts any new arguments through kwargs
754
+ def update_parameters(self, **kwargs):
755
+ for key, value in kwargs.items():
756
+ if hasattr(self, key):
757
+ setattr(self, key, value)
758
+
759
+ # Return the updated attributes
760
+ updated_attributes = {key: value for key,
761
+ value in kwargs.items() if hasattr(self, key)}
762
+ return updated_attributes
763
+
764
+ def run_gradio_chat(self, message: str,
765
+ history: list,
766
+ temperature: float,
767
+ max_new_tokens: int,
768
+ max_token: int,
769
+ call_agent: bool,
770
+ conversation: gr.State,
771
+ max_round: int = 20,
772
+ seed: int = None,
773
+ call_agent_level: int = 0,
774
+ sub_agent_task: str = None) -> str:
775
+ """
776
+ Generate a streaming response using the llama3-8b model.
777
+ Args:
778
+ message (str): The input message.
779
+ history (list): The conversation history used by ChatInterface.
780
+ temperature (float): The temperature for generating the response.
781
+ max_new_tokens (int): The maximum number of new tokens to generate.
782
+ Returns:
783
+ str: The generated response.
784
+ """
785
+ print("\033[1;32;40mstart\033[0m")
786
+ print("len(message)", len(message))
787
+ if len(message) <= 10:
788
+ yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
789
+ return "Please provide a valid message."
790
+ outputs = []
791
+ outputs_str = ''
792
+ last_outputs = []
793
+
794
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
795
+ call_agent,
796
+ call_agent_level,
797
+ message)
798
+
799
+ conversation = self.initialize_conversation(
800
+ message,
801
+ conversation=conversation,
802
+ history=history)
803
+ history = []
804
+
805
+ next_round = True
806
+ function_call_messages = []
807
+ current_round = 0
808
+ enable_summary = False
809
+ last_status = {} # for summary
810
+ token_overflow = False
811
+ if self.enable_checker:
812
+ checker = ReasoningTraceChecker(
813
+ message, conversation, init_index=len(conversation))
814
+
815
+ try:
816
+ while next_round and current_round < max_round:
817
+ current_round += 1
818
+ if len(last_outputs) > 0:
819
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
820
+ last_outputs, return_message=True,
821
+ existing_tools_prompt=picked_tools_prompt,
822
+ message_for_call_agent=message,
823
+ call_agent=call_agent,
824
+ call_agent_level=call_agent_level,
825
+ temperature=temperature)
826
+ history.extend(current_gradio_history)
827
+ if special_tool_call == 'Finish':
828
+ yield history
829
+ next_round = False
830
+ conversation.extend(function_call_messages)
831
+ return function_call_messages[0]['content']
832
+ elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
833
+ history.append(
834
+ ChatMessage(role="assistant", content=history[-1].content))
835
+ yield history
836
+ next_round = False
837
+ return history[-1].content
838
+ if (self.enable_summary or token_overflow) and not call_agent:
839
+ if token_overflow:
840
+ print("token_overflow, using summary")
841
+ enable_summary = True
842
+ last_status = self.function_result_summary(
843
+ conversation, status=last_status,
844
+ enable_summary=enable_summary)
845
+ if function_call_messages is not None:
846
+ conversation.extend(function_call_messages)
847
+ formated_md_function_call_messages = tool_result_format(
848
+ function_call_messages)
849
+ yield history
850
+ else:
851
+ next_round = False
852
+ conversation.extend(
853
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
854
+ return ''.join(last_outputs).replace("</s>", "")
855
+ if self.enable_checker:
856
+ good_status, wrong_info = checker.check_conversation()
857
+ if not good_status:
858
+ next_round = False
859
+ print("Internal error in reasoning: " + wrong_info)
860
+ break
861
+ last_outputs = []
862
+ last_outputs_str, token_overflow = self.llm_infer(
863
+ messages=conversation,
864
+ temperature=temperature,
865
+ tools=picked_tools_prompt,
866
+ skip_special_tokens=False,
867
+ max_new_tokens=max_new_tokens,
868
+ max_token=max_token,
869
+ seed=seed,
870
+ check_token_status=True)
871
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
872
+ for each in history:
873
+ if each.metadata is not None:
874
+ each.metadata['status'] = 'done'
875
+ if '[FinalAnswer]' in last_thought:
876
+ final_thought, final_answer = last_thought.split(
877
+ '[FinalAnswer]')
878
+ history.append(
879
+ ChatMessage(role="assistant",
880
+ content=final_thought.strip())
881
+ )
882
+ yield history
883
+ history.append(
884
+ ChatMessage(
885
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
886
+ )
887
+ yield history
888
+ else:
889
+ history.append(ChatMessage(
890
+ role="assistant", content=last_thought))
891
+ yield history
892
+
893
+ last_outputs.append(last_outputs_str)
894
+
895
+ if next_round:
896
+ if self.force_finish:
897
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
898
+ conversation, temperature, max_new_tokens, max_token)
899
+ for each in history:
900
+ if each.metadata is not None:
901
+ each.metadata['status'] = 'done'
902
+ if '[FinalAnswer]' in last_thought:
903
+ final_thought, final_answer = last_thought.split(
904
+ '[FinalAnswer]')
905
+ history.append(
906
+ ChatMessage(role="assistant",
907
+ content=final_thought.strip())
908
+ )
909
+ yield history
910
+ history.append(
911
+ ChatMessage(
912
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
913
+ )
914
+ yield history
915
+ else:
916
+ yield "The number of rounds exceeds the maximum limit!"
917
+
918
+ except Exception as e:
919
+ print(f"Error: {e}")
920
+ if self.force_finish:
921
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
922
+ conversation,
923
+ temperature,
924
+ max_new_tokens,
925
+ max_token)
926
+ for each in history:
927
+ if each.metadata is not None:
928
+ each.metadata['status'] = 'done'
929
+ if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
930
+ final_thought, final_answer = last_thought.split(
931
+ '[FinalAnswer]')
932
+ history.append(
933
+ ChatMessage(role="assistant",
934
+ content=final_thought.strip())
935
+ )
936
+ yield history
937
+ history.append(
938
+ ChatMessage(
939
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
940
+ )
941
+ yield history
942
+ else:
943
+ return None