Update src/txagent/txagent.py
src/txagent/txagent.py  (+31 −12)
@@ -73,7 +73,7 @@ class TxAgent:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=
+        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=2048, gpu_memory_utilization=0.8)
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         logger.info("Model %s loaded successfully", self.model_name)
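The rewritten constructor caps the context window and reserves GPU headroom at load time. Below is a minimal standalone sketch of the same call, assuming vLLM's `LLM` API; the model id and prompt are placeholders, not taken from this diff:

```python
# Sketch only: exercises the constructor arguments from the hunk above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    dtype="float16",
    max_model_len=2048,           # cap prompt + generation at 2048 positions
    gpu_memory_utilization=0.8,   # leave ~20% of VRAM for other allocations
)
tokenizer = llm.get_tokenizer()   # same tokenizer the agent caches

params = SamplingParams(temperature=0.1, max_tokens=128)
outputs = llm.generate(["What is aspirin commonly used for?"], params)
print(outputs[0].outputs[0].text)
```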
@@ -176,8 +176,14 @@ class TxAgent:
                            call_agent=False,
                            call_agent_level=None,
                            temperature=None):
-        function_call_json, message = self.tooluniverse.extract_function_call_json(
-            fcall_str, return_message=return_message, verbose=False)
+        try:
+            function_call_json, message = self.tooluniverse.extract_function_call_json(
+                fcall_str, return_message=return_message, verbose=False)
+        except Exception as e:
+            logger.error("Tool call parsing failed: %s", e)
+            function_call_json = []
+            message = fcall_str
+
         call_results = []
         special_tool_call = ''
         if function_call_json:
@@ -197,7 +203,7 @@ class TxAgent:
                 )
                 call_result = self.run_multistep_agent(
                     full_message, temperature=temperature,
-                    max_new_tokens=512, max_token=
+                    max_new_tokens=512, max_token=2048,
                     call_agent=False, call_agent_level=call_agent_level)
                 if call_result is None:
                     call_result = "⚠️ No content returned from sub-agent."
@@ -217,7 +223,7 @@ class TxAgent:
         else:
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "Invalid function call
+                "content": json.dumps({"content": "Invalid or no function call detected."})
             })
 
         revised_messages = [{
@@ -235,8 +241,14 @@ class TxAgent:
                         call_agent_level=None,
                         temperature=None,
                         return_gradio_history=True):
-        function_call_json, message = self.tooluniverse.extract_function_call_json(
-            fcall_str, return_message=return_message, verbose=False)
+        try:
+            function_call_json, message = self.tooluniverse.extract_function_call_json(
+                fcall_str, return_message=return_message, verbose=False)
+        except Exception as e:
+            logger.error("Tool call parsing failed: %s", e)
+            function_call_json = []
+            message = fcall_str
+
         call_results = []
         special_tool_call = ''
         if return_gradio_history:
@@ -264,7 +276,7 @@ class TxAgent:
                 sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                 call_result = yield from self.run_gradio_chat(
                     full_message, history=[], temperature=temperature,
-                    max_new_tokens=512, max_token=
+                    max_new_tokens=512, max_token=2048,
                     call_agent=False, call_agent_level=call_agent_level,
                     conversation=None, sub_agent_task=sub_agent_task)
                 if call_result is not None and isinstance(call_result, str):
@@ -287,7 +299,7 @@ class TxAgent:
         else:
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "Invalid function call
+                "content": json.dumps({"content": "Invalid or no function call detected."})
             })
 
         revised_messages = [{
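Both handlers above apply the same defensive pattern: if `extract_function_call_json` raises, the agent logs the error and falls back to an empty call list plus the raw string, which then flows into the "Invalid or no function call detected." tool message instead of crashing the turn. A self-contained sketch of that pattern, with a hypothetical `json.loads`-based parser standing in for the ToolUniverse call:

```python
import json
import logging

logger = logging.getLogger(__name__)

def parse_tool_calls(fcall_str):
    """Never raise: return ([], raw_text) on any parse failure.

    The json.loads body is a stand-in for
    tooluniverse.extract_function_call_json, not its real implementation.
    """
    try:
        calls = json.loads(fcall_str)
        if not isinstance(calls, list):
            calls = [calls]
        return calls, fcall_str
    except Exception as e:
        logger.error("Tool call parsing failed: %s", e)
        return [], fcall_str

calls, message = parse_tool_calls('[{"name": "get_drug_info", "arguments": {}}]')
assert calls[0]["name"] == "get_drug_info"

calls, message = parse_tool_calls("model emitted plain text, not JSON")
assert calls == []  # degraded gracefully instead of raising
```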
@@ -300,6 +312,13 @@ class TxAgent:
         return revised_messages, existing_tools_prompt, special_tool_call
 
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
+        # Truncate conversation to fit within max_token
+        tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
+        if len(tokenized) > max_token - 100:
+            logger.warning("Truncating conversation to fit max_token=%d", max_token)
+            while len(tokenized) > max_token - 100 and len(conversation) > 1:
+                conversation.pop(1)  # Keep system prompt and latest message
+                tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
         if conversation[-1]['role'] == 'assistant':
             conversation.append(
                 {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
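The new guard drops the oldest non-system turns until the serialized conversation fits under `max_token` minus a safety margin; `conversation.pop(1)` always spares the system prompt at index 0 and the most recent turns at the tail. One caveat: tokenizing `json.dumps(conversation)` measures the JSON serialization, which only approximates the length the chat template will actually produce. A standalone sketch of the loop, assuming any Hugging Face tokenizer:

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

def truncate_conversation(conversation, max_token, margin=100):
    """Pop oldest non-system turns until the serialized chat fits the budget."""
    def n_tokens():
        return len(tokenizer.encode(json.dumps(conversation),
                                    add_special_tokens=False))

    while n_tokens() > max_token - margin and len(conversation) > 1:
        conversation.pop(1)  # index 0 (system prompt) and the tail survive
    return conversation

convo = [{"role": "system", "content": "You are TxAgent."}]
convo += [{"role": "user", "content": "filler " * 80}] * 20
print(len(truncate_conversation(convo, max_token=512)))  # far fewer than 21
```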
@@ -401,7 +420,7 @@ class TxAgent:
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=512,
-                  max_token=
+                  max_token=2048, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
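With these defaults, every `llm_infer` call gets a 512-token generation budget inside a 2048-token total context, matching the `max_model_len` set at load time. A hypothetical call site against the signature above; the construction and loading steps are assumptions, not shown in this diff:

```python
# Hypothetical usage; TxAgent(...) and load_models() details are assumed.
agent = TxAgent(model_name="meta-llama/Llama-3.1-8B-Instruct")
agent.load_models()

messages = [
    {"role": "system", "content": "You are a helpful biomedical assistant."},
    {"role": "user", "content": "Summarize aspirin's major drug interactions."},
]

answer = agent.llm_infer(
    messages,
    temperature=0.1,
    max_new_tokens=512,  # generation budget
    max_token=2048,      # total context budget; matches max_model_len above
)
print(answer)
```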
@@ -550,7 +569,7 @@ Summarize the function calls' responses in one sentence with all necessary information
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=
+                max_token=2048)
             input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
             status['summarized_index'] = last_call_idx + 2
             idx += 1
@@ -572,7 +591,7 @@ Summarize the function calls' responses in one sentence with all necessary information
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=
+                max_token=2048)
             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
             for tool_call in tool_calls:
                 del tool_call['call_id']