Update src/txagent/txagent.py
src/txagent/txagent.py  +72 -20  CHANGED
@@ -74,10 +74,20 @@ class TxAgent:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name

-        self.model = LLM(
+        self.model = LLM(
+            model=self.model_name,
+            dtype="float16",
+            max_model_len=131072,
+            max_num_batched_tokens=32768,  # Increased for A100 80GB
+            gpu_memory_utilization=0.9,  # Higher utilization for better performance
+            trust_remote_code=True
+        )
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
-        logger.info(
+        logger.info(
+            "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
+            self.model_name, 131072, 32768, 0.9
+        )
         return f"Model {model_name} loaded successfully."

     def load_tooluniverse(self):
@@ -204,7 +214,7 @@ class TxAgent:
         )
         call_result = self.run_multistep_agent(
             full_message, temperature=temperature,
-            max_new_tokens=512, max_token=
+            max_new_tokens=512, max_token=131072,
             call_agent=False, call_agent_level=call_agent_level)
         if call_result is None:
             call_result = "⚠️ No content returned from sub-agent."
@@ -277,7 +287,7 @@ class TxAgent:
         sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
         call_result = yield from self.run_gradio_chat(
             full_message, history=[], temperature=temperature,
-            max_new_tokens=512, max_token=
+            max_new_tokens=512, max_token=131072,
             call_agent=False, call_agent_level=call_agent_level,
             conversation=None, sub_agent_task=sub_agent_task)
         if call_result is not None and isinstance(call_result, str):
@@ -387,7 +397,7 @@ class TxAgent:
             tools=picked_tools_prompt,
             skip_special_tokens=False,
             max_new_tokens=2048,
-            max_token=
+            max_token=131072,
             check_token_status=True)
         if last_outputs_str is None:
             logger.warning("Token limit exceeded")
@@ -410,7 +420,7 @@ class TxAgent:

     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=512,
-                  max_token=
+                  max_token=131072, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
@@ -430,21 +440,23 @@ class TxAgent:

         if check_token_status and max_token is not None:
             token_overflow = False
-            num_input_tokens = len(self.tokenizer.encode(prompt,
+            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
+            logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
             if num_input_tokens > max_token:
                 torch.cuda.empty_cache()
                 gc.collect()
-                logger.
+                logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
                 return None, True

         output = model.generate(prompt, sampling_params=sampling_params)
-
-
+        output_text = output[0].outputs[0].text
+        output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
+        logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
         torch.cuda.empty_cache()
         gc.collect()
         if check_token_status and max_token is not None:
-            return
-            return
+            return output_text, token_overflow
+        return output_text

     def run_self_agent(self, message: str,
                        temperature: float,
@@ -514,7 +526,7 @@ Function calls' responses:
 \"\"\"
 {function_response}
 \"\"\"
-Summarize the function calls' responses in one sentence with all necessary information.
+Summarize the function calls' l responses in one sentence with all necessary information.
 """
         conversation = [{"role": "user", "content": prompt}]
         output = self.llm_infer(
@@ -559,7 +571,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=
+                max_token=131072)
             input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
             status['summarized_index'] = last_call_idx + 2
             idx += 1
@@ -581,7 +593,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=
+                max_token=131072)
             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
             for tool_call in tool_calls:
                 del tool_call['call_id']
@@ -603,10 +615,10 @@ Summarize the function calls' responses in one sentence with all necessary infor
     def run_gradio_chat(self, message: str,
                         history: list,
                         temperature: float,
-                        max_new_tokens: 2048,
-                        max_token:
-                        call_agent: bool,
-                        conversation: gr.State,
+                        max_new_tokens: int = 2048,
+                        max_token: int = 131072,
+                        call_agent: bool = False,
+                        conversation: gr.State = None,
                         max_round: int = 5,
                         seed: int = None,
                         call_agent_level: int = 0,
@@ -755,4 +767,44 @@ Summarize the function calls' responses in one sentence with all necessary infor
             logger.info("Forced final answer after error: %s", final_answer[:100])
             yield history
             return final_answer
-        return error_msg
+        return error_msg
+
+    def run_gradio_chat_batch(self, messages: List[str],
+                              temperature: float,
+                              max_new_tokens: int = 2048,
+                              max_token: int = 131072,
+                              call_agent: bool = False,
+                              conversation: List = None,
+                              max_round: int = 5,
+                              seed: int = None,
+                              call_agent_level: int = 0):
+        """Run batch inference for multiple messages."""
+        logger.info("Starting batch chat for %d messages", len(messages))
+        batch_results = []
+
+        for message in messages:
+            # Initialize conversation for each message
+            conv = self.initialize_conversation(message, conversation, history=None)
+            picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+                call_agent, call_agent_level, message)
+
+            # Run single inference for simplicity (extend for multi-round if needed)
+            output, token_overflow = self.llm_infer(
+                messages=conv,
+                temperature=temperature,
+                tools=picked_tools_prompt,
+                max_new_tokens=max_new_tokens,
+                max_token=max_token,
+                skip_special_tokens=False,
+                seed=seed,
+                check_token_status=True
+            )
+
+            if output is None:
+                logger.warning("Token limit exceeded for message: %s", message[:100])
+                batch_results.append("Token limit exceeded.")
+            else:
+                batch_results.append(output)
+
+        logger.info("Batch chat completed for %d messages", len(messages))
+        return batch_results
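A note on the engine settings in the first hunk: `model`, `dtype`, `max_model_len`, `max_num_batched_tokens`, `gpu_memory_utilization`, and `trust_remote_code` are standard keyword arguments of vLLM's `LLM` constructor, so the new load call can be exercised outside TxAgent. Below is a minimal standalone sketch of the same configuration; the model name, prompt, and sampling settings are placeholders for illustration, not values taken from the file.

```python
from vllm import LLM, SamplingParams

MODEL_NAME = "your/model-id"  # placeholder; TxAgent receives the real name via load_models()

# Mirrors the engine arguments added in the diff above.
llm = LLM(
    model=MODEL_NAME,
    dtype="float16",
    max_model_len=131072,          # 128K-token context window
    max_num_batched_tokens=32768,  # larger scheduling budget, per the "A100 80GB" comment
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
)

tokenizer = llm.get_tokenizer()  # same tokenizer handle that TxAgent caches
params = SamplingParams(temperature=0.1, max_tokens=64)
outputs = llm.generate("Hello from vLLM", sampling_params=params)
print(outputs[0].outputs[0].text)
```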
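The `llm_infer` hunk also changes the method's return contract: with `check_token_status=True` it now returns an `(output_text, token_overflow)` tuple, and `(None, True)` when the encoded prompt exceeds `max_token`; without the flag it returns the text alone. A minimal sketch of a call site written against that contract, assuming `agent` is an already-initialized TxAgent (the conversation content is illustrative):

```python
# Hedged sketch: consuming the patched llm_infer return contract.
conversation = [{"role": "user", "content": "Summarize the last tool response."}]

output, token_overflow = agent.llm_infer(
    messages=conversation,
    temperature=0.1,
    max_new_tokens=512,
    max_token=131072,
    check_token_status=True,
)

if output is None and token_overflow:
    # The prompt exceeded max_token; the method frees GPU memory and returns early.
    answer = "Token limit exceeded."
else:
    answer = output
```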
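Finally, a usage sketch for the new `run_gradio_chat_batch` entry point. As written it runs one `llm_infer` call per message rather than the multi-round tool loop of `run_gradio_chat` (the in-code comment flags multi-round handling as a possible extension) and returns a list of outputs, substituting "Token limit exceeded." where a prompt overflowed. The `agent` instance and the example questions below are illustrative assumptions, not taken from the diff.

```python
# Hedged sketch: exercising run_gradio_chat_batch on a small list of prompts.
# Assumes `agent` is a TxAgent whose model has already been loaded via load_models().
questions = [
    "List the major drug-drug interactions of warfarin.",
    "Which transporters are involved in metformin clearance?",
]

results = agent.run_gradio_chat_batch(
    messages=questions,
    temperature=0.3,
    max_new_tokens=1024,
    max_token=131072,
    call_agent=False,
    seed=100,
)

for question, answer in zip(questions, results):
    print(f"Q: {question}\nA: {answer[:200]}\n")
```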