Update src/txagent/txagent.py

src/txagent/txagent.py  CHANGED  (+76 / -304)

@@ -1,348 +1,120 @@
Before (old file, 348 lines shown; removed lines are marked with "-"; lines cut off in the diff view are left truncated):

  import os
- import sys
  import json
- import gc
- import numpy as np
- from vllm import LLM, SamplingParams
- from jinja2 import Template
- from typing import List, Dict, Optional, Union, Tuple, Generator
- import types
- from tooluniverse import ToolUniverse
- from .toolrag import ToolRAGModel
- import torch
  import logging
-

- # Configure logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger("TxAgent")

- from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
-
  class TxAgent:
      def __init__(self,
                   model_name: str,
                   rag_model_name: str,
                   tool_files_dict: Optional[Dict] = None,
-
-                  enable_rag: bool = False,
-                  enable_summary: bool = False,
-                  init_rag_num: int = 0,
-                  step_rag_num: int = 0,
-                  summary_mode: str = 'step',
-                  summary_skip_last_k: int = 0,
-                  summary_context_length: Optional[int] = None,
                   force_finish: bool = True,
-
-
-
-
-                  additional_default_tools: Optional[List] = None):
-         """
-         Initialize the TxAgent with specified configuration.
-         """
          self.model_name = model_name
-         self.tokenizer = None
          self.rag_model_name = rag_model_name
          self.tool_files_dict = tool_files_dict or {}
-         self.
-         self.rag_model = ToolRAGModel(rag_model_name)
-         self.tooluniverse = None
-         self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
-         self.self_prompt = "Strictly follow the instruction."
-         self.chat_prompt = "You are a helpful assistant for user chat."
-         self.enable_finish = enable_finish
-         self.enable_rag = enable_rag
-         self.enable_summary = enable_summary
-         self.summary_mode = summary_mode
-         self.summary_skip_last_k = summary_skip_last_k
-         self.summary_context_length = summary_context_length
-         self.init_rag_num = init_rag_num
-         self.step_rag_num = step_rag_num
          self.force_finish = force_finish
-         self.avoid_repeat = avoid_repeat
-         self.seed = seed
          self.enable_checker = enable_checker
-         self.
-

-     def init_model(self)
-         """Initialize both the main model and
          self.load_models()
-         self.
-         logger.info("Model
-
-     def load_models(self, model_name: Optional[str] = None) -> str:
-         """
-         Load the specified model or the default model if none specified.
-         """
-         if model_name is not None:
-             if model_name == self.model_name:
-                 return f"The model {model_name} is already loaded."
-             self.model_name = model_name

          try:
-
-
-
-
-
-                 max_num_seqs=512,
-                 gpu_memory_utilization=0.95,
-                 trust_remote_code=True,
              )
-
-             self.
-
-
-
              )
-
          except Exception as e:
-             logger.error("Failed to load model:
              raise RuntimeError(f"Failed to load model: {str(e)}")

-     def
-         """Load
          try:
-
-             self.
-
-             self.
-
-             logger.info("
          except Exception as e:
-             logger.error("Failed to load
-             raise RuntimeError(f"Failed to load
-
-     def run_multistep_agent(self,
-                             message: str,
-                             temperature: float,
-                             max_new_tokens: int,
-                             max_token: int,
-                             max_round: int = 5,
-                             call_agent: bool = False,
-                             call_agent_level: int = 0) -> Optional[str]:
-         """
-         Run multi-step reasoning with the agent.
-         """
-         logger.info("Starting multistep agent for message: %s", message[:100])
-         picked_tools_prompt = []
-         call_agent_level = 0
-         if call_agent:
-             call_agent_level += 1
-             if call_agent_level >= 2:
-                 call_agent = False

-
-
-
-
-
-         last_outputs = []
-         next_round = True
-         current_round = 0
-         token_overflow = False
-         enable_summary = False
-         last_status = {}
-
-         while next_round and current_round < max_round:
-             current_round += 1
-             if len(outputs) > 0:
-                 function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
-                     last_outputs,
-                     return_message=True,
-                     existing_tools_prompt=picked_tools_prompt,
-                     message_for_call_agent=message,
-                     call_agent=call_agent,
-                     call_agent_level=call_agent_level,
-                     temperature=temperature
-                 )
-
-                 if special_tool_call == 'Finish':
-                     next_round = False
-                     conversation.extend(function_call_messages)
-                     content = function_call_messages[0]['content']
-                     if content is None:
-                         return "❌ No content returned after Finish tool call."
-                     return content.split('[FinalAnswer]')[-1]
-
-                 if (self.enable_summary or token_overflow) and not call_agent:
-                     enable_summary = True
-                 last_status = self.function_result_summary(
-                     conversation, status=last_status, enable_summary=enable_summary)
-
-                 if function_call_messages:
-                     conversation.extend(function_call_messages)
-                     outputs.append(tool_result_format(function_call_messages))
-                 else:
-                     next_round = False
-                     conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
-                     return ''.join(last_outputs).replace("</s>", "")
-
-             last_outputs = []
-             outputs.append("### TxAgent:\n")
-             last_outputs_str, token_overflow = self.llm_infer(
-                 messages=conversation,
-                 temperature=temperature,
-                 tools=picked_tools_prompt,
-                 skip_special_tokens=False,
-                 max_new_tokens=2048,
-                 max_token=131072,
-                 check_token_status=True)

-
-
-
-
-
-

-             last_outputs.append(last_outputs_str)
-
-         if max_round == current_round:
-             logger.warning("Max rounds exceeded")
-             if self.force_finish:
-                 return self.get_answer_based_on_unfinished_reasoning(
-                     conversation, temperature, max_new_tokens, max_token)
-             return None
-
-     def run_function_call(self,
-                           fcall_str: str,
-                           return_message: bool = False,
-                           existing_tools_prompt: Optional[List] = None,
-                           message_for_call_agent: Optional[str] = None,
-                           call_agent: bool = False,
-                           call_agent_level: Optional[int] = None,
-                           temperature: Optional[float] = None) -> Tuple[List[Dict], List, str]:
-         """
-         Execute function calls from the model's output.
-         """
-         try:
-             function_call_json, message = self.tooluniverse.extract_function_call_json(
-                 fcall_str, return_message=return_message, verbose=False)
          except Exception as e:
-             logger.error("
-
-             message = fcall_str

-
-
-
-
-         for i in range(len(function_call_json)):
-             logger.info("Tool Call: %s", function_call_json[i])
-             if function_call_json[i]["name"] == 'Finish':
-                 special_tool_call = 'Finish'
-                 break
-             elif function_call_json[i]["name"] == 'CallAgent':
-                 if call_agent_level is not None and call_agent_level < 2 and call_agent:
-                     solution_plan = function_call_json[i]['arguments']['solution']
-                     full_message = (
-                         (message_for_call_agent or "") +
-                         "\nYou must follow the following plan to answer the question: " +
-                         str(solution_plan))
-                     call_result = self.run_multistep_agent(
-                         full_message,
-                         temperature=temperature,
-                         max_new_tokens=512,
-                         max_token=131072,
-                         call_agent=False,
-                         call_agent_level=call_agent_level
-                     )
-                     if call_result is None:
-                         call_result = "⚠️ No content returned from sub-agent."
-                     else:
-                         call_result = call_result.split('[FinalAnswer]')[-1].strip()
-                 else:
-                     call_result = "Error: CallAgent disabled."
-             else:
-                 call_result = self.tooluniverse.run_one_function(function_call_json[i])
-
-             call_id = self.tooluniverse.call_id_gen()
-             function_call_json[i]["call_id"] = call_id
-             logger.info("Tool Call Result: %s", call_result)
-             call_results.append({
-                 "role": "tool",
-                 "content": json.dumps({
-                     "tool_name": function_call_json[i]["name"],
-                     "content": call_result,
-                     "call_id": call_id
-                 })
-             })

-
-
-
-
-             }] + call_results
-
-         return revised_messages, existing_tools_prompt or [], special_tool_call
-
-     def llm_infer(self,
-                   messages: List[Dict],
-                   temperature: float = 0.1,
-                   tools: Optional[List] = None,
-                   output_begin_string: Optional[str] = None,
-                   max_new_tokens: int = 512,
-                   max_token: int = 131072,
-                   skip_special_tokens: bool = True,
-                   model: Optional[LLM] = None,
-                   check_token_status: bool = False) -> Union[str, Tuple[str, bool]]:
-         """
-         Perform inference using the LLM.
-         """
-         model = model or self.model
-         tokenizer = self.tokenizer
-
-         sampling_params = SamplingParams(
-             temperature=temperature,
-             max_tokens=max_new_tokens,
-             seed=self.seed,
-         )
-
-         prompt = self.chat_template.render(
-             messages=messages, tools=tools, add_generation_prompt=True)
-         if output_begin_string is not None:
-             prompt += output_begin_string
-
-         token_overflow = False
-         if check_token_status and max_token is not None:
-             num_input_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
-             logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
-             if num_input_tokens > max_token:
-                 torch.cuda.empty_cache()
-                 gc.collect()
-                 logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
-                 return (None, True) if check_token_status else None
-
-         try:
-             output = model.generate(prompt, sampling_params=sampling_params)
-             output_text = output[0].outputs[0].text
-             output_tokens = len(tokenizer.encode(output_text, add_special_tokens=False))
-             logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
-
-             if skip_special_tokens:
-                 output_text = output_text.replace("</s>", "").strip()
-
-             torch.cuda.empty_cache()
-             gc.collect()
-
-             return (output_text, token_overflow) if check_token_status else output_text
-         except Exception as e:
-             logger.error("Inference failed: %s", str(e))
-             raise RuntimeError(f"Inference failed: {str(e)}")

-     def cleanup(self)
-         """Clean up resources
          if hasattr(self, 'model'):
              del self.model
          if hasattr(self, 'rag_model'):
              del self.rag_model
-         if hasattr(self, 'tooluniverse'):
-             del self.tooluniverse
          torch.cuda.empty_cache()
-         gc.collect()
          logger.info("TxAgent resources cleaned up")

      def __del__(self):
After (new file, 120 lines shown; added lines are marked with "+"):

  import os
  import json
  import logging
+ import torch
+ from typing import List, Dict, Optional, Union
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from sentence_transformers import SentenceTransformer

+ # Configure logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger("TxAgent")

  class TxAgent:
      def __init__(self,
                   model_name: str,
                   rag_model_name: str,
                   tool_files_dict: Optional[Dict] = None,
+                  use_vllm: bool = False,
                   force_finish: bool = True,
+                  enable_checker: bool = True,
+                  step_rag_num: int = 4,
+                  seed: Optional[int] = None):
+
          self.model_name = model_name
          self.rag_model_name = rag_model_name
          self.tool_files_dict = tool_files_dict or {}
+         self.use_vllm = use_vllm
          self.force_finish = force_finish
          self.enable_checker = enable_checker
+         self.step_rag_num = step_rag_num
+         self.seed = seed
+
+         self.model = None
+         self.tokenizer = None
+         self.rag_model = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         logger.info(f"Initializing TxAgent with model: {model_name} on device: {self.device}")

+     def init_model(self):
+         """Initialize both the main model and RAG model."""
          self.load_models()
+         self.load_rag_model()
+         logger.info("Model initialization complete")

+     def load_models(self):
+         """Load the main LLM model."""
          try:
+             logger.info(f"Loading model: {self.model_name}")
+
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_name,
+                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
              )
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 torch_dtype=torch.float16,
+                 device_map="auto",
+                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
              )
+
+             logger.info(f"Successfully loaded model on {self.device}")
+
          except Exception as e:
+             logger.error(f"Failed to load model: {str(e)}")
              raise RuntimeError(f"Failed to load model: {str(e)}")

+     def load_rag_model(self):
+         """Load the RAG model."""
          try:
+             logger.info(f"Loading RAG model: {self.rag_model_name}")
+             self.rag_model = SentenceTransformer(
+                 self.rag_model_name,
+                 device=str(self.device)
+             )
+             logger.info("RAG model loaded successfully")
          except Exception as e:
+             logger.error(f"Failed to load RAG model: {str(e)}")
+             raise RuntimeError(f"Failed to load RAG model: {str(e)}")

+     def process_document(self, file_path: str) -> Dict:
+         """Process a document and return analysis results."""
+         try:
+             # Extract text (implement your extraction logic)
+             text = self.extract_text(file_path)

+             # Process with LLM (implement your processing logic)
+             result = self.analyze_text(text)
+
+             return {
+                 "status": "success",
+                 "analysis": result,
+                 "model": self.model_name
+             }

          except Exception as e:
+             logger.error(f"Document processing failed: {str(e)}")
+             raise RuntimeError(f"Document processing failed: {str(e)}")

+     def extract_text(self, file_path: str) -> str:
+         """Extract text from various file formats."""
+         # Implement your text extraction logic here
+         return "Sample extracted text"

+     def analyze_text(self, text: str) -> str:
+         """Analyze extracted text using the LLM."""
+         # Implement your text analysis logic here
+         return "Sample analysis result"

+     def cleanup(self):
+         """Clean up resources."""
          if hasattr(self, 'model'):
              del self.model
          if hasattr(self, 'rag_model'):
              del self.rag_model
          torch.cuda.empty_cache()
          logger.info("TxAgent resources cleaned up")

      def __del__(self):
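The new version swaps the vLLM-backed, tool-calling agent for a plain transformers + sentence-transformers loader with stubbed document-processing methods. Below is a minimal usage sketch, assuming the package is importable as txagent.txagent; the model IDs are placeholders, and the SimpleTxAgent subclass with its extract_text override is hypothetical, only illustrating where real extraction logic would replace the "Sample extracted text" stub.

# Hypothetical usage sketch of the rewritten TxAgent API; model names are placeholders.
from txagent.txagent import TxAgent

class SimpleTxAgent(TxAgent):
    def extract_text(self, file_path: str) -> str:
        # Minimal stand-in for real extraction logic: read the file as UTF-8 text.
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

agent = SimpleTxAgent(
    model_name="some-org/some-causal-lm",                      # placeholder model ID
    rag_model_name="sentence-transformers/all-MiniLM-L6-v2",   # placeholder embedding model
)
agent.init_model()                      # loads tokenizer, causal LM, and SentenceTransformer
report = agent.process_document("notes.txt")
print(report["status"], report["analysis"])
agent.cleanup()                         # frees model objects and empties the CUDA cache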