Ali2206 committed (verified)
Commit 32df88c · 1 Parent(s): 9c9d2f8

Delete src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +0 -943
src/txagent/txagent.py DELETED
@@ -1,943 +0,0 @@
- import gradio as gr
- import os
- import sys
- import json
- import gc
- import numpy as np
- from vllm import LLM, SamplingParams
- from jinja2 import Template
- from typing import List
- import types
- from tooluniverse import ToolUniverse
- from gradio import ChatMessage
- from .toolrag import ToolRAGModel
-
- from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
-
-
- class TxAgent:
-     def __init__(self, model_name,
-                  rag_model_name,
-                  tool_files_dict=None,  # None leads to the default tool files in ToolUniverse
-                  enable_finish=True,
-                  enable_rag=True,
-                  enable_summary=False,
-                  init_rag_num=0,
-                  step_rag_num=10,
-                  summary_mode='step',
-                  summary_skip_last_k=0,
-                  summary_context_length=None,
-                  force_finish=True,
-                  avoid_repeat=True,
-                  seed=None,
-                  enable_checker=False,
-                  enable_chat=False,
-                  additional_default_tools=None,
-                  ):
-         self.model_name = model_name
-         self.tokenizer = None
-         self.terminators = None
-         self.rag_model_name = rag_model_name
-         self.tool_files_dict = tool_files_dict
-         self.model = None
-         self.rag_model = ToolRAGModel(rag_model_name)
-         self.tooluniverse = None
-         # self.tool_desc = None
-         self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
-         self.self_prompt = "Strictly follow the instruction."
-         self.chat_prompt = "You are a helpful assistant to chat with the user."
-         self.enable_finish = enable_finish
-         self.enable_rag = enable_rag
-         self.enable_summary = enable_summary
-         self.summary_mode = summary_mode
-         self.summary_skip_last_k = summary_skip_last_k
-         self.summary_context_length = summary_context_length
-         self.init_rag_num = init_rag_num
-         self.step_rag_num = step_rag_num
-         self.force_finish = force_finish
-         self.avoid_repeat = avoid_repeat
-         self.seed = seed
-         self.enable_checker = enable_checker
-         self.additional_default_tools = additional_default_tools
-         self.print_self_values()
-
-     def init_model(self):
-         self.load_models()
-         self.load_tooluniverse()
-         self.load_tool_desc_embedding()
-
-     def print_self_values(self):
-         for attr, value in self.__dict__.items():
-             print(f"{attr}: {value}")
-
-     def load_models(self, model_name=None):
-         if model_name is not None:
-             if model_name == self.model_name:
-                 return f"The model {model_name} is already loaded."
-             self.model_name = model_name
-
-         self.model = LLM(model=self.model_name)
-         self.chat_template = Template(self.model.get_tokenizer().chat_template)
-         self.tokenizer = self.model.get_tokenizer()
-
-         return f"Model {self.model_name} loaded successfully."
-
-     def load_tooluniverse(self):
-         self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
-         self.tooluniverse.load_tools()
-         special_tools = self.tooluniverse.prepare_tool_prompts(
-             self.tooluniverse.tool_category_dicts["special_tools"])
-         self.special_tools_name = [tool['name'] for tool in special_tools]
-
-     def load_tool_desc_embedding(self):
-         self.rag_model.load_tool_desc_embedding(self.tooluniverse)
-
-     def rag_infer(self, query, top_k=5):
-         return self.rag_model.rag_infer(query, top_k)
-
-     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
-         picked_tools_prompt = []
-         picked_tools_prompt = self.add_special_tools(
-             picked_tools_prompt, call_agent=call_agent)
-         if call_agent:
-             call_agent_level += 1
-             if call_agent_level >= 2:
-                 call_agent = False
-
-         if not call_agent:
-             picked_tools_prompt += self.tool_RAG(
-                 message=message, rag_num=self.init_rag_num)
-         return picked_tools_prompt, call_agent_level
-
-     def initialize_conversation(self, message, conversation=None, history=None):
-         if conversation is None:
-             conversation = []
-
-         conversation = self.set_system_prompt(
-             conversation, self.prompt_multi_step)
-         if history is not None:
-             if len(history) == 0:
-                 conversation = []
-                 print("clear conversation successfully")
-             else:
-                 for i in range(len(history)):
-                     if history[i]['role'] == 'user':
-                         if i-1 >= 0 and history[i-1]['role'] == 'assistant':
-                             conversation.append(
-                                 {"role": "assistant", "content": history[i-1]['content']})
-                         conversation.append(
-                             {"role": "user", "content": history[i]['content']})
-                     if i == len(history)-1 and history[i]['role'] == 'assistant':
-                         conversation.append(
-                             {"role": "assistant", "content": history[i]['content']})
-
-         conversation.append({"role": "user", "content": message})
-
-         return conversation
-
-     def tool_RAG(self, message=None,
-                  picked_tool_names=None,
-                  existing_tools_prompt=[],
-                  rag_num=5,
-                  return_call_result=False):
-         extra_factor = 30  # Factor to retrieve more than rag_num
-         if picked_tool_names is None:
-             assert picked_tool_names is not None or message is not None
-             picked_tool_names = self.rag_infer(
-                 message, top_k=rag_num*extra_factor)
-
-         picked_tool_names_no_special = []
-         for tool in picked_tool_names:
-             if tool not in self.special_tools_name:
-                 picked_tool_names_no_special.append(tool)
-         picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
-         picked_tool_names = picked_tool_names_no_special[:rag_num]
-
-         picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
-         picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
-             picked_tools)
-         if return_call_result:
-             return picked_tools_prompt, picked_tool_names
-         return picked_tools_prompt
-
-     def add_special_tools(self, tools, call_agent=False):
-         if self.enable_finish:
-             tools.append(self.tooluniverse.get_one_tool_by_one_name(
-                 'Finish', return_prompt=True))
-             print("Finish tool is added")
-         if call_agent:
-             tools.append(self.tooluniverse.get_one_tool_by_one_name(
-                 'CallAgent', return_prompt=True))
-             print("CallAgent tool is added")
-         else:
-             if self.enable_rag:
-                 tools.append(self.tooluniverse.get_one_tool_by_one_name(
-                     'Tool_RAG', return_prompt=True))
-                 print("Tool_RAG tool is added")
-
-             if self.additional_default_tools is not None:
-                 for each_tool_name in self.additional_default_tools:
-                     tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
-                         each_tool_name, return_prompt=True)
-                     if tool_prompt is not None:
-                         print(f"{each_tool_name} tool is added")
-                         tools.append(tool_prompt)
-         return tools
-
-     def add_finish_tools(self, tools):
-         tools.append(self.tooluniverse.get_one_tool_by_one_name(
-             'Finish', return_prompt=True))
-         print("Finish tool is added")
-         return tools
-
-     def set_system_prompt(self, conversation, sys_prompt):
-         if len(conversation) == 0:
-             conversation.append(
-                 {"role": "system", "content": sys_prompt})
-         else:
-             conversation[0] = {"role": "system", "content": sys_prompt}
-         return conversation
-
-     def run_function_call(self, fcall_str,
-                           return_message=False,
-                           existing_tools_prompt=None,
-                           message_for_call_agent=None,
-                           call_agent=False,
-                           call_agent_level=None,
-                           temperature=None):
-
-         function_call_json, message = self.tooluniverse.extract_function_call_json(
-             fcall_str, return_message=return_message, verbose=False)
-         call_results = []
-         special_tool_call = ''
-         if function_call_json is not None:
-             if isinstance(function_call_json, list):
-                 for i in range(len(function_call_json)):
-                     print("\033[94mTool Call:\033[0m", function_call_json[i])
-                     if function_call_json[i]["name"] == 'Finish':
-                         special_tool_call = 'Finish'
-                         break
-                     elif function_call_json[i]["name"] == 'Tool_RAG':
-                         new_tools_prompt, call_result = self.tool_RAG(
-                             message=message,
-                             existing_tools_prompt=existing_tools_prompt,
-                             rag_num=self.step_rag_num,
-                             return_call_result=True)
-                         existing_tools_prompt += new_tools_prompt
-                     elif function_call_json[i]["name"] == 'CallAgent':
-                         if call_agent_level < 2 and call_agent:
-                             solution_plan = function_call_json[i]['arguments']['solution']
-                             full_message = (
-                                 message_for_call_agent +
-                                 "\nYou must follow the following plan to answer the question: " +
-                                 str(solution_plan)
-                             )
-                             call_result = self.run_multistep_agent(
-                                 full_message, temperature=temperature,
-                                 max_new_tokens=1024, max_token=99999,
-                                 call_agent=False, call_agent_level=call_agent_level)
-                             call_result = call_result.split(
-                                 '[FinalAnswer]')[-1].strip()
-                         else:
-                             call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
-                     else:
-                         call_result = self.tooluniverse.run_one_function(
-                             function_call_json[i])
-
-                     call_id = self.tooluniverse.call_id_gen()
-                     function_call_json[i]["call_id"] = call_id
-                     print("\033[94mTool Call Result:\033[0m", call_result)
-                     call_results.append({
-                         "role": "tool",
-                         "content": json.dumps({"content": call_result, "call_id": call_id})
-                     })
-         else:
-             call_results.append({
-                 "role": "tool",
-                 "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
-             })
-
-         revised_messages = [{
-             "role": "assistant",
-             "content": message.strip(),
-             "tool_calls": json.dumps(function_call_json)
-         }] + call_results
-
-         # Return the final result.
-         return revised_messages, existing_tools_prompt, special_tool_call
-
-     def run_function_call_stream(self, fcall_str,
-                                  return_message=False,
-                                  existing_tools_prompt=None,
-                                  message_for_call_agent=None,
-                                  call_agent=False,
-                                  call_agent_level=None,
-                                  temperature=None,
-                                  return_gradio_history=True):
-
-         function_call_json, message = self.tooluniverse.extract_function_call_json(
-             fcall_str, return_message=return_message, verbose=False)
-         call_results = []
-         special_tool_call = ''
-         if return_gradio_history:
-             gradio_history = []
-         if function_call_json is not None:
-             if isinstance(function_call_json, list):
-                 for i in range(len(function_call_json)):
-                     if function_call_json[i]["name"] == 'Finish':
-                         special_tool_call = 'Finish'
-                         break
-                     elif function_call_json[i]["name"] == 'Tool_RAG':
-                         new_tools_prompt, call_result = self.tool_RAG(
-                             message=message,
-                             existing_tools_prompt=existing_tools_prompt,
-                             rag_num=self.step_rag_num,
-                             return_call_result=True)
-                         existing_tools_prompt += new_tools_prompt
-                     elif function_call_json[i]["name"] == 'DirectResponse':
-                         call_result = function_call_json[i]['arguments']['respose']
-                         special_tool_call = 'DirectResponse'
-                     elif function_call_json[i]["name"] == 'RequireClarification':
-                         call_result = function_call_json[i]['arguments']['unclear_question']
-                         special_tool_call = 'RequireClarification'
-                     elif function_call_json[i]["name"] == 'CallAgent':
-                         if call_agent_level < 2 and call_agent:
-                             solution_plan = function_call_json[i]['arguments']['solution']
-                             full_message = (
-                                 message_for_call_agent +
-                                 "\nYou must follow the following plan to answer the question: " +
-                                 str(solution_plan)
-                             )
-                             sub_agent_task = "Sub TxAgent plan: " + \
-                                 str(solution_plan)
-                             # When streaming, yield responses as they arrive.
-                             call_result = yield from self.run_gradio_chat(
-                                 full_message, history=[], temperature=temperature,
-                                 max_new_tokens=1024, max_token=99999,
-                                 call_agent=False, call_agent_level=call_agent_level,
-                                 conversation=None,
-                                 sub_agent_task=sub_agent_task)
-
-                             call_result = call_result.split(
-                                 '[FinalAnswer]')[-1]
-                         else:
-                             call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
-                     else:
-                         call_result = self.tooluniverse.run_one_function(
-                             function_call_json[i])
-
-                     call_id = self.tooluniverse.call_id_gen()
-                     function_call_json[i]["call_id"] = call_id
-                     call_results.append({
-                         "role": "tool",
-                         "content": json.dumps({"content": call_result, "call_id": call_id})
-                     })
-                     if return_gradio_history and function_call_json[i]["name"] != 'Finish':
-                         if function_call_json[i]["name"] == 'Tool_RAG':
-                             gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
-                                 "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
-
-                         else:
-                             gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
-                                 "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
-         else:
-             call_results.append({
-                 "role": "tool",
-                 "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
-             })
-
-         revised_messages = [{
-             "role": "assistant",
-             "content": message.strip(),
-             "tool_calls": json.dumps(function_call_json)
-         }] + call_results
-
-         # Yield the final result.
-         if return_gradio_history:
-             return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
-         else:
-             return revised_messages, existing_tools_prompt, special_tool_call
-
-     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
-         if conversation[-1]['role'] == 'assistant':
-             conversation.append(
-                 {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
-         finish_tools_prompt = self.add_finish_tools([])
-
-         last_outputs_str = self.llm_infer(messages=conversation,
-                                           temperature=temperature,
-                                           tools=finish_tools_prompt,
-                                           output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
-                                           skip_special_tokens=True,
-                                           max_new_tokens=max_new_tokens, max_token=max_token)
-         print(last_outputs_str)
-         return last_outputs_str
-
-     def run_multistep_agent(self, message: str,
-                             temperature: float,
-                             max_new_tokens: int,
-                             max_token: int,
-                             max_round: int = 20,
-                             call_agent=False,
-                             call_agent_level=0) -> str:
-         """
-         Run the multi-step reasoning agent and return the final response.
-         Args:
-             message (str): The input message.
-             temperature (float): The temperature for generating the response.
-             max_new_tokens (int): The maximum number of new tokens to generate.
-         Returns:
-             str: The generated response.
-         """
-         print("\033[1;32;40mstart\033[0m")
-         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
-             call_agent, call_agent_level, message)
-         conversation = self.initialize_conversation(message)
-
-         outputs = []
-         last_outputs = []
-         next_round = True
-         function_call_messages = []
-         current_round = 0
-         token_overflow = False
-         enable_summary = False
-         last_status = {}
-
-         if self.enable_checker:
-             checker = ReasoningTraceChecker(message, conversation)
-         try:
-             while next_round and current_round < max_round:
-                 current_round += 1
-                 if len(outputs) > 0:
-                     function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
-                         last_outputs, return_message=True,
-                         existing_tools_prompt=picked_tools_prompt,
-                         message_for_call_agent=message,
-                         call_agent=call_agent,
-                         call_agent_level=call_agent_level,
-                         temperature=temperature)
-
-                     if special_tool_call == 'Finish':
-                         next_round = False
-                         conversation.extend(function_call_messages)
-                         if isinstance(function_call_messages[0]['content'], types.GeneratorType):
-                             function_call_messages[0]['content'] = next(
-                                 function_call_messages[0]['content'])
-                         return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
-
-                     if (self.enable_summary or token_overflow) and not call_agent:
-                         if token_overflow:
-                             print("token_overflow, using summary")
-                         enable_summary = True
-                     last_status = self.function_result_summary(
-                         conversation, status=last_status, enable_summary=enable_summary)
-
-                     if function_call_messages is not None:
-                         conversation.extend(function_call_messages)
-                         outputs.append(tool_result_format(
-                             function_call_messages))
-                     else:
-                         next_round = False
-                         conversation.extend(
-                             [{"role": "assistant", "content": ''.join(last_outputs)}])
-                         return ''.join(last_outputs).replace("</s>", "")
-                 if self.enable_checker:
-                     good_status, wrong_info = checker.check_conversation()
-                     if not good_status:
-                         next_round = False
-                         print(
-                             "Internal error in reasoning: " + wrong_info)
-                         break
-                 last_outputs = []
-                 outputs.append("### TxAgent:\n")
-                 last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
-                                                                   temperature=temperature,
-                                                                   tools=picked_tools_prompt,
-                                                                   skip_special_tokens=False,
-                                                                   max_new_tokens=max_new_tokens, max_token=max_token,
-                                                                   check_token_status=True)
-                 if last_outputs_str is None:
-                     next_round = False
-                     print(
-                         "The number of tokens exceeds the maximum limit.")
-                 else:
-                     last_outputs.append(last_outputs_str)
-             if max_round == current_round:
-                 print("The number of rounds exceeds the maximum limit!")
-                 if self.force_finish:
-                     return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
-                 else:
-                     return None
-
-         except Exception as e:
-             print(f"Error: {e}")
-             if self.force_finish:
-                 return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
-             else:
-                 return None
-
-     def build_logits_processor(self, messages, llm):
-         # Use the tokenizer from the LLM instance.
-         tokenizer = llm.get_tokenizer()
-         if self.avoid_repeat and len(messages) > 2:
-             assistant_messages = []
-             for i in range(1, len(messages) + 1):
-                 if messages[-i]['role'] == 'assistant':
-                     assistant_messages.append(messages[-i]['content'])
-                     if len(assistant_messages) == 2:
-                         break
-             forbidden_ids = [tokenizer.encode(
-                 msg, add_special_tokens=False) for msg in assistant_messages]
-             return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
-         else:
-             return None
-
-     def llm_infer(self, messages, temperature=0.7, tools=None,
-                   output_begin_string=None, max_new_tokens=2048,
-                   max_token=None, skip_special_tokens=True,
-                   model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
-
-         if model is None:
-             model = self.model
-
-         # Removed logits_processors for vLLM compatibility
-         sampling_params = SamplingParams(
-             temperature=temperature,
-             top_k=40,
-             top_p=0.95,
-             max_tokens=max_new_tokens,
-             stop=["</s>", "<|eot_id|>", "<|endoftext|>"],
-             seed=seed if seed is not None else self.seed,
-         )
-
-         prompt = self.chat_template.render(
-             messages=messages, tools=tools, add_generation_prompt=True
-         )
-
-         if output_begin_string is not None:
-             prompt += output_begin_string
-
-         token_overflow = False
-         if check_token_status and max_token is not None:
-             num_input_tokens = len(self.tokenizer.encode(
-                 prompt, return_tensors="pt")[0])
-             if num_input_tokens > max_token:
-                 import gc
-                 import torch
-                 torch.cuda.empty_cache()
-                 gc.collect()
-                 print("Number of input tokens before inference:", num_input_tokens)
-                 token_overflow = True
-                 return None, token_overflow
-
-         outputs = model.generate(prompt, sampling_params=sampling_params)
-
-         if isinstance(outputs, list) and len(outputs) > 0:
-             output = outputs[0].outputs[0].text
-         else:
-             output = "[Error: No output]"
-
-         print("\033[92m" + output + "\033[0m")
-
-         if check_token_status and max_token is not None:
-             return output, token_overflow
-
-         return output
-
-     def run_self_agent(self, message: str,
-                        temperature: float,
-                        max_new_tokens: int,
-                        max_token: int) -> str:
-
-         print("\033[1;32;40mstart self agent\033[0m")
-         conversation = []
-         conversation = self.set_system_prompt(conversation, self.self_prompt)
-         conversation.append({"role": "user", "content": message})
-         return self.llm_infer(messages=conversation,
-                               temperature=temperature,
-                               tools=None,
-                               max_new_tokens=max_new_tokens, max_token=max_token)
-
-     def run_chat_agent(self, message: str,
-                        temperature: float,
-                        max_new_tokens: int,
-                        max_token: int) -> str:
-
-         print("\033[1;32;40mstart chat agent\033[0m")
-         conversation = []
-         conversation = self.set_system_prompt(conversation, self.chat_prompt)
-         conversation.append({"role": "user", "content": message})
-         return self.llm_infer(messages=conversation,
-                               temperature=temperature,
-                               tools=None,
-                               max_new_tokens=max_new_tokens, max_token=max_token)
-
-     def run_format_agent(self, message: str,
-                          answer: str,
-                          temperature: float,
-                          max_new_tokens: int,
-                          max_token: int) -> str:
-
-         print("\033[1;32;40mstart format agent\033[0m")
-         if '[FinalAnswer]' in answer:
-             possible_final_answer = answer.split("[FinalAnswer]")[-1]
-         elif "\n\n" in answer:
-             possible_final_answer = answer.split("\n\n")[-1]
-         else:
-             possible_final_answer = answer.strip()
-         if len(possible_final_answer) == 1:
-             choice = possible_final_answer[0]
-             if choice in ['A', 'B', 'C', 'D', 'E']:
-                 return choice
-         elif len(possible_final_answer) > 1:
-             if possible_final_answer[1] == ':':
-                 choice = possible_final_answer[0]
-                 if choice in ['A', 'B', 'C', 'D', 'E']:
-                     print("choice", choice)
-                     return choice
-
-         conversation = []
-         format_prompt = "You are a helpful assistant that transforms the agent's answer into a final answer of 'A', 'B', 'C', or 'D'."
-         conversation = self.set_system_prompt(conversation, format_prompt)
-         conversation.append({"role": "user", "content": message +
-                              "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
-         return self.llm_infer(messages=conversation,
-                               temperature=temperature,
-                               tools=None,
-                               max_new_tokens=max_new_tokens, max_token=max_token)
-
-     def run_summary_agent(self, thought_calls: str,
-                           function_response: str,
-                           temperature: float,
-                           max_new_tokens: int,
-                           max_token: int) -> str:
-         print("\033[1;32;40mSummarized Tool Result:\033[0m")
-         generate_tool_result_summary_training_prompt = """Thought and function calls:
- {thought_calls}
-
- Function calls' responses:
- \"\"\"
- {function_response}
- \"\"\"
-
- Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
-
- Directly respond with the summarized sentence of the function calls' responses only.
-
- Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
- """.format(thought_calls=thought_calls, function_response=function_response)
-         conversation = []
-         conversation.append(
-             {"role": "user", "content": generate_tool_result_summary_training_prompt})
-         output = self.llm_infer(messages=conversation,
-                                 temperature=temperature,
-                                 tools=None,
-                                 max_new_tokens=max_new_tokens, max_token=max_token)
-
-         if '[' in output:
-             output = output.split('[')[0]
-         return output
-
-     def function_result_summary(self, input_list, status, enable_summary):
-         """
-         Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
-         Supports the 'length' and 'step' summary modes, and skips the last `summary_skip_last_k` groups.
-
-         Parameters:
-             input_list (list): A list of message dictionaries containing role and other information.
-             status (dict): Bookkeeping state carried over from the previous call (step counts, summarized index, etc.).
-             enable_summary (bool): Whether to summarize tool results on this pass.
-
-         Returns:
-             dict: The updated status dictionary; input_list is modified in place.
-         """
-         if 'tool_call_step' not in status:
-             status['tool_call_step'] = 0
-
-         for idx in range(len(input_list)):
-             pos_id = len(input_list)-idx-1
-             if input_list[pos_id]['role'] == 'assistant':
-                 if 'tool_calls' in input_list[pos_id]:
-                     if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
-                         status['tool_call_step'] += 1
-                 break
-
-         if 'step' in status:
-             status['step'] += 1
-         else:
-             status['step'] = 0
-
-         if not enable_summary:
-             return status
-
-         if 'summarized_index' not in status:
-             status['summarized_index'] = 0
-
-         if 'summarized_step' not in status:
-             status['summarized_step'] = 0
-
-         if 'previous_length' not in status:
-             status['previous_length'] = 0
-
-         if 'history' not in status:
-             status['history'] = []
-
-         function_response = ''
-         idx = 0
-         current_summarized_index = status['summarized_index']
-
-         status['history'].append(self.summary_mode == 'step' and status['summarized_step']
-                                  < status['step']-status['tool_call_step']-self.summary_skip_last_k)
-
-         idx = current_summarized_index
-         while idx < len(input_list):
-             if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
-
-                 if input_list[idx]['role'] == 'assistant':
-                     if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
-                         this_thought_calls = None
-                     else:
-                         if len(function_response) != 0:
-                             print("internal summary")
-                             status['summarized_step'] += 1
-                             result_summary = self.run_summary_agent(
-                                 thought_calls=this_thought_calls,
-                                 function_response=function_response,
-                                 temperature=0.1,
-                                 max_new_tokens=1024,
-                                 max_token=99999
-                             )
-
-                             input_list.insert(
-                                 last_call_idx+1, {'role': 'tool', 'content': result_summary})
-                             status['summarized_index'] = last_call_idx + 2
-                             idx += 1
-
-                         last_call_idx = idx
-                         this_thought_calls = input_list[idx]['content'] + \
-                             input_list[idx]['tool_calls']
-                         function_response = ''
-
-                 elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
-                     function_response += input_list[idx]['content']
-                     del input_list[idx]
-                     idx -= 1
-
-             else:
-                 break
-             idx += 1
-
-         if len(function_response) != 0:
-             status['summarized_step'] += 1
-             result_summary = self.run_summary_agent(
-                 thought_calls=this_thought_calls,
-                 function_response=function_response,
-                 temperature=0.1,
-                 max_new_tokens=1024,
-                 max_token=99999
-             )
-
-             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
-             for tool_call in tool_calls:
-                 del tool_call['call_id']
-             input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
-             input_list.insert(
-                 last_call_idx+1, {'role': 'tool', 'content': result_summary})
-             status['summarized_index'] = last_call_idx + 2
-
-         return status
-
-     # Following are Gradio related functions
-
-     # General update method that accepts any new arguments through kwargs
-     def update_parameters(self, **kwargs):
-         for key, value in kwargs.items():
-             if hasattr(self, key):
-                 setattr(self, key, value)
-
-         # Return the updated attributes
-         updated_attributes = {key: value for key,
-                               value in kwargs.items() if hasattr(self, key)}
-         return updated_attributes
-
-     def run_gradio_chat(self, message: str,
-                         history: list,
-                         temperature: float,
-                         max_new_tokens: int,
-                         max_token: int,
-                         call_agent: bool,
-                         conversation: gr.State,
-                         max_round: int = 20,
-                         seed: int = None,
-                         call_agent_level: int = 0,
-                         sub_agent_task: str = None) -> str:
-         """
-         Generate a streaming response for the Gradio chat interface.
-         Args:
-             message (str): The input message.
-             history (list): The conversation history used by ChatInterface.
-             temperature (float): The temperature for generating the response.
-             max_new_tokens (int): The maximum number of new tokens to generate.
-         Returns:
-             str: The generated response.
-         """
-         print("\033[1;32;40mstart\033[0m")
-         print("len(message)", len(message))
-         if len(message) <= 10:
-             yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
-             return "Please provide a valid message."
-         outputs = []
-         outputs_str = ''
-         last_outputs = []
-
-         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
-             call_agent,
-             call_agent_level,
-             message)
-
-         conversation = self.initialize_conversation(
-             message,
-             conversation=conversation,
-             history=history)
-         history = []
-
-         next_round = True
-         function_call_messages = []
-         current_round = 0
-         enable_summary = False
-         last_status = {}  # for summary
-         token_overflow = False
-         if self.enable_checker:
-             checker = ReasoningTraceChecker(
-                 message, conversation, init_index=len(conversation))
-
-         try:
-             while next_round and current_round < max_round:
-                 current_round += 1
-                 if len(last_outputs) > 0:
-                     function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
-                         last_outputs, return_message=True,
-                         existing_tools_prompt=picked_tools_prompt,
-                         message_for_call_agent=message,
-                         call_agent=call_agent,
-                         call_agent_level=call_agent_level,
-                         temperature=temperature)
-                     history.extend(current_gradio_history)
-                     if special_tool_call == 'Finish':
-                         yield history
-                         next_round = False
-                         conversation.extend(function_call_messages)
-                         return function_call_messages[0]['content']
-                     elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
-                         history.append(
-                             ChatMessage(role="assistant", content=history[-1].content))
-                         yield history
-                         next_round = False
-                         return history[-1].content
-                     if (self.enable_summary or token_overflow) and not call_agent:
-                         if token_overflow:
-                             print("token_overflow, using summary")
-                         enable_summary = True
-                     last_status = self.function_result_summary(
-                         conversation, status=last_status,
-                         enable_summary=enable_summary)
-                     if function_call_messages is not None:
-                         conversation.extend(function_call_messages)
-                         formatted_md_function_call_messages = tool_result_format(
-                             function_call_messages)
-                         yield history
-                     else:
-                         next_round = False
-                         conversation.extend(
-                             [{"role": "assistant", "content": ''.join(last_outputs)}])
-                         return ''.join(last_outputs).replace("</s>", "")
-                 if self.enable_checker:
-                     good_status, wrong_info = checker.check_conversation()
-                     if not good_status:
-                         next_round = False
-                         print("Internal error in reasoning: " + wrong_info)
-                         break
-                 last_outputs = []
-                 last_outputs_str, token_overflow = self.llm_infer(
-                     messages=conversation,
-                     temperature=temperature,
-                     tools=picked_tools_prompt,
-                     skip_special_tokens=False,
-                     max_new_tokens=max_new_tokens,
-                     max_token=max_token,
-                     seed=seed,
-                     check_token_status=True)
-                 last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-                 for each in history:
-                     if each.metadata is not None:
-                         each.metadata['status'] = 'done'
-                 if '[FinalAnswer]' in last_thought:
-                     final_thought, final_answer = last_thought.split(
-                         '[FinalAnswer]')
-                     history.append(
-                         ChatMessage(role="assistant",
-                                     content=final_thought.strip())
-                     )
-                     yield history
-                     history.append(
-                         ChatMessage(
-                             role="assistant", content="**Answer**:\n"+final_answer.strip())
-                     )
-                     yield history
-                 else:
-                     history.append(ChatMessage(
-                         role="assistant", content=last_thought))
-                     yield history
-
-                 last_outputs.append(last_outputs_str)
-
-             if next_round:
-                 if self.force_finish:
-                     last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                         conversation, temperature, max_new_tokens, max_token)
-                     for each in history:
-                         if each.metadata is not None:
-                             each.metadata['status'] = 'done'
-                     if '[FinalAnswer]' in last_outputs_str:
-                         final_thought, final_answer = last_outputs_str.split(
-                             '[FinalAnswer]')
-                         history.append(
-                             ChatMessage(role="assistant",
-                                         content=final_thought.strip())
-                         )
-                         yield history
-                         history.append(
-                             ChatMessage(
-                                 role="assistant", content="**Answer**:\n"+final_answer.strip())
-                         )
-                         yield history
-                 else:
-                     yield "The number of rounds exceeds the maximum limit!"
-
-         except Exception as e:
-             print(f"Error: {e}")
-             if self.force_finish:
-                 last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                     conversation,
-                     temperature,
-                     max_new_tokens,
-                     max_token)
-                 for each in history:
-                     if each.metadata is not None:
-                         each.metadata['status'] = 'done'
-                 if '[FinalAnswer]' in last_outputs_str or '"name": "Finish",' in last_outputs_str:
-                     final_thought, final_answer = last_outputs_str.split(
-                         '[FinalAnswer]')
-                     history.append(
-                         ChatMessage(role="assistant",
-                                     content=final_thought.strip())
-                     )
-                     yield history
-                     history.append(
-                         ChatMessage(
-                             role="assistant", content="**Answer**:\n"+final_answer.strip())
-                     )
-                     yield history
-             else:
-                 return None
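
For context, here is a minimal usage sketch of the class removed in this commit, reconstructed only from the signatures visible in the diff above. The model identifiers are placeholders, and the `txagent.txagent` import path simply mirrors the deleted file's location under src/; this is an assumption about how the class was typically driven, not a record of the project's actual entry point.

```python
# Hypothetical driver for the deleted TxAgent class; checkpoint names are placeholders.
from txagent.txagent import TxAgent

agent = TxAgent(
    model_name="YOUR-TXAGENT-MODEL",       # placeholder vLLM-compatible checkpoint
    rag_model_name="YOUR-TOOLRAG-MODEL",   # placeholder ToolRAG embedding model
    enable_rag=True,
    step_rag_num=10,
)
agent.init_model()  # loads the vLLM model, the ToolUniverse tools, and tool-description embeddings

answer = agent.run_multistep_agent(
    "Example biomedical question about a drug interaction.",
    temperature=0.3,
    max_new_tokens=1024,
    max_token=90000,
)
print(answer)
```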