Ali2206 committed
Commit cf95a11 · verified · 1 Parent(s): 6f88317

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +76 -304
src/txagent/txagent.py CHANGED
@@ -1,348 +1,120 @@
 import os
-import sys
 import json
-import gc
-import numpy as np
-from vllm import LLM, SamplingParams
-from jinja2 import Template
-from typing import List, Dict, Optional, Union, Tuple, Generator
-import types
-from tooluniverse import ToolUniverse
-from .toolrag import ToolRAGModel
-import torch
 import logging
-from datetime import datetime
+import torch
+from typing import List, Dict, Optional, Union
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 
-# Configure logging with a more specific logger name
+# Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger("TxAgent")
 
-from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
-
 class TxAgent:
     def __init__(self,
                  model_name: str,
                  rag_model_name: str,
                  tool_files_dict: Optional[Dict] = None,
-                 enable_finish: bool = True,
-                 enable_rag: bool = False,
-                 enable_summary: bool = False,
-                 init_rag_num: int = 0,
-                 step_rag_num: int = 0,
-                 summary_mode: str = 'step',
-                 summary_skip_last_k: int = 0,
-                 summary_context_length: Optional[int] = None,
+                 use_vllm: bool = False,
                  force_finish: bool = True,
-                 avoid_repeat: bool = True,
-                 seed: Optional[int] = None,
-                 enable_checker: bool = False,
-                 enable_chat: bool = False,
-                 additional_default_tools: Optional[List] = None):
-        """
-        Initialize the TxAgent with specified configuration.
-        """
+                 enable_checker: bool = True,
+                 step_rag_num: int = 4,
+                 seed: Optional[int] = None):
+
         self.model_name = model_name
-        self.tokenizer = None
         self.rag_model_name = rag_model_name
         self.tool_files_dict = tool_files_dict or {}
-        self.model = None
-        self.rag_model = ToolRAGModel(rag_model_name)
-        self.tooluniverse = None
-        self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
-        self.self_prompt = "Strictly follow the instruction."
-        self.chat_prompt = "You are a helpful assistant for user chat."
-        self.enable_finish = enable_finish
-        self.enable_rag = enable_rag
-        self.enable_summary = enable_summary
-        self.summary_mode = summary_mode
-        self.summary_skip_last_k = summary_skip_last_k
-        self.summary_context_length = summary_context_length
-        self.init_rag_num = init_rag_num
-        self.step_rag_num = step_rag_num
+        self.use_vllm = use_vllm
         self.force_finish = force_finish
-        self.avoid_repeat = avoid_repeat
-        self.seed = seed
         self.enable_checker = enable_checker
-        self.additional_default_tools = additional_default_tools or []
-        logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
+        self.step_rag_num = step_rag_num
+        self.seed = seed
+
+        self.model = None
+        self.tokenizer = None
+        self.rag_model = None
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        logger.info(f"Initializing TxAgent with model: {model_name} on device: {self.device}")
 
-    def init_model(self) -> None:
-        """Initialize both the main model and tool universe."""
+    def init_model(self):
+        """Initialize both the main model and RAG model."""
         self.load_models()
-        self.load_tooluniverse()
-        logger.info("Model and tools initialized successfully")
-
-    def load_models(self, model_name: Optional[str] = None) -> str:
-        """
-        Load the specified model or the default model if none specified.
-        """
-        if model_name is not None:
-            if model_name == self.model_name:
-                return f"The model {model_name} is already loaded."
-            self.model_name = model_name
+        self.load_rag_model()
+        logger.info("Model initialization complete")
+
+    def load_models(self):
+        """Load the main LLM model."""
         try:
-            self.model = LLM(
-                model=self.model_name,
-                dtype="float16",
-                max_model_len=131072,
-                max_num_batched_tokens=65536,
-                max_num_seqs=512,
-                gpu_memory_utilization=0.95,
-                trust_remote_code=True,
+            logger.info(f"Loading model: {self.model_name}")
+
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            self.tokenizer = self.model.get_tokenizer()
-            self.chat_template = Template(self.tokenizer.chat_template)
-            logger.info(
-                "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d",
-                self.model_name, 131072, 65536
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            return f"Model {model_name} loaded successfully."
+
+            logger.info(f"Successfully loaded model on {self.device}")
+
         except Exception as e:
-            logger.error("Failed to load model: %s", str(e))
+            logger.error(f"Failed to load model: {str(e)}")
             raise RuntimeError(f"Failed to load model: {str(e)}")
 
-    def load_tooluniverse(self) -> None:
-        """Load and initialize the tool universe with specified tools."""
+    def load_rag_model(self):
+        """Load the RAG model."""
         try:
-            self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
-            self.tooluniverse.load_tools()
-            special_tools = self.tooluniverse.prepare_tool_prompts(
-                self.tooluniverse.tool_category_dicts["special_tools"])
-            self.special_tools_name = [tool['name'] for tool in special_tools]
-            logger.info("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
+            logger.info(f"Loading RAG model: {self.rag_model_name}")
+            self.rag_model = SentenceTransformer(
+                self.rag_model_name,
+                device=str(self.device)
+            )
+            logger.info("RAG model loaded successfully")
         except Exception as e:
-            logger.error("Failed to load tools: %s", str(e))
-            raise RuntimeError(f"Failed to load tools: {str(e)}")
-
-    def run_multistep_agent(self,
-                            message: str,
-                            temperature: float,
-                            max_new_tokens: int,
-                            max_token: int,
-                            max_round: int = 5,
-                            call_agent: bool = False,
-                            call_agent_level: int = 0) -> Optional[str]:
-        """
-        Run multi-step reasoning with the agent.
-        """
-        logger.info("Starting multistep agent for message: %s", message[:100])
-        picked_tools_prompt = []
-        call_agent_level = 0
-        if call_agent:
-            call_agent_level += 1
-            if call_agent_level >= 2:
-                call_agent = False
-
-        conversation = []
-        conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
-        conversation.append({"role": "user", "content": message})
-
-        outputs = []
-        last_outputs = []
-        next_round = True
-        current_round = 0
-        token_overflow = False
-        enable_summary = False
-        last_status = {}
-
-        while next_round and current_round < max_round:
-            current_round += 1
-            if len(outputs) > 0:
-                function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
-                    last_outputs,
-                    return_message=True,
-                    existing_tools_prompt=picked_tools_prompt,
-                    message_for_call_agent=message,
-                    call_agent=call_agent,
-                    call_agent_level=call_agent_level,
-                    temperature=temperature
-                )
-
-                if special_tool_call == 'Finish':
-                    next_round = False
-                    conversation.extend(function_call_messages)
-                    content = function_call_messages[0]['content']
-                    if content is None:
-                        return "❌ No content returned after Finish tool call."
-                    return content.split('[FinalAnswer]')[-1]
-
-                if (self.enable_summary or token_overflow) and not call_agent:
-                    enable_summary = True
-                last_status = self.function_result_summary(
-                    conversation, status=last_status, enable_summary=enable_summary)
-
-                if function_call_messages:
-                    conversation.extend(function_call_messages)
-                    outputs.append(tool_result_format(function_call_messages))
-                else:
-                    next_round = False
-                    conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
-                    return ''.join(last_outputs).replace("</s>", "")
-
-            last_outputs = []
-            outputs.append("### TxAgent:\n")
-            last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation,
-                temperature=temperature,
-                tools=picked_tools_prompt,
-                skip_special_tokens=False,
-                max_new_tokens=2048,
-                max_token=131072,
-                check_token_status=True)
-
-            if last_outputs_str is None:
-                logger.warning("Token limit exceeded")
-                if self.force_finish:
-                    return self.get_answer_based_on_unfinished_reasoning(
-                        conversation, temperature, max_new_tokens, max_token)
-                return " Token limit exceeded."
-
-            last_outputs.append(last_outputs_str)
-
-            if max_round == current_round:
-                logger.warning("Max rounds exceeded")
-                if self.force_finish:
-                    return self.get_answer_based_on_unfinished_reasoning(
-                        conversation, temperature, max_new_tokens, max_token)
-
-        return None
-
-    def run_function_call(self,
-                          fcall_str: str,
-                          return_message: bool = False,
-                          existing_tools_prompt: Optional[List] = None,
-                          message_for_call_agent: Optional[str] = None,
-                          call_agent: bool = False,
-                          call_agent_level: Optional[int] = None,
-                          temperature: Optional[float] = None) -> Tuple[List[Dict], List, str]:
-        """
-        Execute function calls from the model's output.
-        """
-        try:
-            function_call_json, message = self.tooluniverse.extract_function_call_json(
-                fcall_str, return_message=return_message, verbose=False)
+            logger.error(f"Failed to load RAG model: {str(e)}")
+            raise RuntimeError(f"Failed to load RAG model: {str(e)}")
+
+    def process_document(self, file_path: str) -> Dict:
+        """Process a document and return analysis results."""
+        try:
+            # Extract text (implement your extraction logic)
+            text = self.extract_text(file_path)
+
+            # Process with LLM (implement your processing logic)
+            result = self.analyze_text(text)
+
+            return {
+                "status": "success",
+                "analysis": result,
+                "model": self.model_name
+            }
         except Exception as e:
-            logger.error("Tool call parsing failed: %s", e)
-            function_call_json = []
-            message = fcall_str
-
-        call_results = []
-        special_tool_call = ''
-        if function_call_json:
-            if isinstance(function_call_json, list):
-                for i in range(len(function_call_json)):
-                    logger.info("Tool Call: %s", function_call_json[i])
-                    if function_call_json[i]["name"] == 'Finish':
-                        special_tool_call = 'Finish'
-                        break
-                    elif function_call_json[i]["name"] == 'CallAgent':
-                        if call_agent_level is not None and call_agent_level < 2 and call_agent:
-                            solution_plan = function_call_json[i]['arguments']['solution']
-                            full_message = (
-                                (message_for_call_agent or "") +
-                                "\nYou must follow the following plan to answer the question: " +
-                                str(solution_plan))
-                            call_result = self.run_multistep_agent(
-                                full_message,
-                                temperature=temperature,
-                                max_new_tokens=512,
-                                max_token=131072,
-                                call_agent=False,
-                                call_agent_level=call_agent_level
-                            )
-                            if call_result is None:
-                                call_result = "⚠️ No content returned from sub-agent."
-                            else:
-                                call_result = call_result.split('[FinalAnswer]')[-1].strip()
-                        else:
-                            call_result = "Error: CallAgent disabled."
-                    else:
-                        call_result = self.tooluniverse.run_one_function(function_call_json[i])
-
-                    call_id = self.tooluniverse.call_id_gen()
-                    function_call_json[i]["call_id"] = call_id
-                    logger.info("Tool Call Result: %s", call_result)
-                    call_results.append({
-                        "role": "tool",
-                        "content": json.dumps({
-                            "tool_name": function_call_json[i]["name"],
-                            "content": call_result,
-                            "call_id": call_id
-                        })
-                    })
-
-        revised_messages = [{
-            "role": "assistant",
-            "content": message.strip(),
-            "tool_calls": json.dumps(function_call_json)
-        }] + call_results
-
-        return revised_messages, existing_tools_prompt or [], special_tool_call
-
-    def llm_infer(self,
-                  messages: List[Dict],
-                  temperature: float = 0.1,
-                  tools: Optional[List] = None,
-                  output_begin_string: Optional[str] = None,
-                  max_new_tokens: int = 512,
-                  max_token: int = 131072,
-                  skip_special_tokens: bool = True,
-                  model: Optional[LLM] = None,
-                  check_token_status: bool = False) -> Union[str, Tuple[str, bool]]:
-        """
-        Perform inference using the LLM.
-        """
-        model = model or self.model
-        tokenizer = self.tokenizer
-
-        sampling_params = SamplingParams(
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            seed=self.seed,
-        )
-
-        prompt = self.chat_template.render(
-            messages=messages, tools=tools, add_generation_prompt=True)
-        if output_begin_string is not None:
-            prompt += output_begin_string
-
-        token_overflow = False
-        if check_token_status and max_token is not None:
-            num_input_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
-            logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
-            if num_input_tokens > max_token:
-                torch.cuda.empty_cache()
-                gc.collect()
-                logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
-                return (None, True) if check_token_status else None
-
-        try:
-            output = model.generate(prompt, sampling_params=sampling_params)
-            output_text = output[0].outputs[0].text
-            output_tokens = len(tokenizer.encode(output_text, add_special_tokens=False))
-            logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
-
-            if skip_special_tokens:
-                output_text = output_text.replace("</s>", "").strip()
-
-            torch.cuda.empty_cache()
-            gc.collect()
-
-            return (output_text, token_overflow) if check_token_status else output_text
-        except Exception as e:
-            logger.error("Inference failed: %s", str(e))
-            raise RuntimeError(f"Inference failed: {str(e)}")
-
-    def cleanup(self) -> None:
-        """Clean up resources and clear memory."""
+            logger.error(f"Document processing failed: {str(e)}")
+            raise RuntimeError(f"Document processing failed: {str(e)}")
+
+    def extract_text(self, file_path: str) -> str:
+        """Extract text from various file formats."""
+        # Implement your text extraction logic here
+        return "Sample extracted text"
+
+    def analyze_text(self, text: str) -> str:
+        """Analyze extracted text using the LLM."""
+        # Implement your text analysis logic here
+        return "Sample analysis result"
+
+    def cleanup(self):
+        """Clean up resources."""
         if hasattr(self, 'model'):
             del self.model
         if hasattr(self, 'rag_model'):
             del self.rag_model
-        if hasattr(self, 'tooluniverse'):
-            del self.tooluniverse
         torch.cuda.empty_cache()
-        gc.collect()
         logger.info("TxAgent resources cleaned up")
 
     def __del__(self):
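
For reference, a minimal usage sketch of the refactored class after this commit. The model names are illustrative placeholders (not pinned by the commit), and the import path assumes the repository's src/txagent/txagent.py layout:

    from txagent.txagent import TxAgent

    agent = TxAgent(
        model_name="your-org/your-causal-lm",                     # placeholder
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder
    )
    agent.init_model()  # loads tokenizer + LLM, then the SentenceTransformer RAG model
    result = agent.process_document("example.pdf")  # analysis is stubbed in this version
    print(result["status"], result["analysis"])
    agent.cleanup()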
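
The commit deliberately leaves extract_text and analyze_text as stubs. One possible analyze_text body, shown purely as an illustration of how the loaded tokenizer and model could be wired together; the prompt wording and generation settings are assumptions, not part of the commit:

    # Illustrative drop-in body for the analyze_text stub inside TxAgent.
    def analyze_text(self, text: str) -> str:
        """Analyze extracted text using the loaded LLM."""
        prompt = f"Summarize the key findings of this document:\n\n{text}"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)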
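
Similarly, the SentenceTransformer RAG model and step_rag_num are loaded but not yet used anywhere in this version of the file. A hedged sketch of the top-k retrieval they would typically support; the method name retrieve_tools and its arguments are assumptions, while util ships with sentence-transformers:

    from sentence_transformers import util  # List comes from typing, already imported

    # Illustrative method for TxAgent: rank tool descriptions against a query.
    def retrieve_tools(self, query: str, tool_descriptions: List[str]) -> List[int]:
        """Return indices of the step_rag_num descriptions most similar to the query."""
        query_emb = self.rag_model.encode(query, convert_to_tensor=True)
        doc_embs = self.rag_model.encode(tool_descriptions, convert_to_tensor=True)
        scores = util.cos_sim(query_emb, doc_embs)[0]  # cosine similarities, shape (N,)
        top_k = min(self.step_rag_num, len(tool_descriptions))
        return scores.topk(top_k).indices.tolist()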