Update src/txagent/txagent.py
src/txagent/txagent.py  CHANGED  (+10 -10)
@@ -73,7 +73,7 @@ class TxAgent:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=
+        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=1024, gpu_memory_utilization=0.8)
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         logger.info("Model %s loaded successfully", self.model_name)
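Note on the new loader arguments: max_model_len caps the combined prompt-plus-generation length vLLM will accept per request, and gpu_memory_utilization limits the fraction of GPU memory vLLM may reserve for weights and KV cache. A minimal standalone sketch of how these arguments behave (the model id is a placeholder, not taken from this repo):

    # Minimal sketch of the vLLM arguments used above; the model id is a placeholder.
    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
              dtype="float16",
              max_model_len=1024,            # prompt + generated tokens per request
              gpu_memory_utilization=0.8)    # share of GPU memory for weights + KV cache
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)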
@@ -197,7 +197,7 @@ class TxAgent:
                 )
             call_result = self.run_multistep_agent(
                 full_message, temperature=temperature,
-                max_new_tokens=
+                max_new_tokens=512, max_token=1024,
                 call_agent=False, call_agent_level=call_agent_level)
             if call_result is None:
                 call_result = "⚠️ No content returned from sub-agent."
@@ -264,7 +264,7 @@ class TxAgent:
             sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
             call_result = yield from self.run_gradio_chat(
                 full_message, history=[], temperature=temperature,
-                max_new_tokens=
+                max_new_tokens=512, max_token=1024,
                 call_agent=False, call_agent_level=call_agent_level,
                 conversation=None, sub_agent_task=sub_agent_task)
             if call_result is not None and isinstance(call_result, str):
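Both sub-agent entry points (run_multistep_agent and run_gradio_chat) now pin max_new_tokens=512 and max_token=1024. Judging by the names and the max_model_len=1024 set at load time, max_new_tokens bounds the tokens generated per call while max_token appears to bound the overall prompt budget; the helper below is an assumed illustration of that interaction, not code from this diff.

    # Assumed illustration: generation has to fit inside the 1024-token window
    # together with the prompt, so the effective cap can be smaller than 512.
    def effective_new_tokens(prompt_tokens, max_new_tokens=512, max_model_len=1024):
        return max(0, min(max_new_tokens, max_model_len - prompt_tokens))

    print(effective_new_tokens(prompt_tokens=700))  # -> 324, not the full 512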
@@ -356,7 +356,7 @@ class TxAgent:
             if (self.enable_summary or token_overflow) and not call_agent:
                 enable_summary = True
             last_status = self.function_result_summary(
-
+                conversation, status=last_status, enable_summary=enable_summary)
 
             if function_call_messages:
                 conversation.extend(function_call_messages)
@@ -400,8 +400,8 @@ class TxAgent:
         return None
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
-                  output_begin_string=None, max_new_tokens=
-                  max_token=
+                  output_begin_string=None, max_new_tokens=512,
+                  max_token=1024, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
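llm_infer now defaults to max_new_tokens=512, max_token=1024, and skip_special_tokens=True. Its body is not part of this diff, but with vLLM these values would typically be forwarded to a SamplingParams object roughly as sketched below (assumed mapping):

    # Assumed mapping of llm_infer's keyword defaults onto vLLM sampling options.
    from vllm import SamplingParams

    sampling_params = SamplingParams(
        temperature=0.1,            # llm_infer default
        max_tokens=512,             # max_new_tokens: cap on generated tokens
        skip_special_tokens=True,   # drop special tokens from the decoded text
        seed=None,                  # forwarded seed, if any
        stop=None,                  # terminators would map to stop strings here
    )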
@@ -549,8 +549,8 @@ Summarize the function calls' responses in one sentence with all necessary infor
                     thought_calls=this_thought_calls,
                     function_response=function_response,
                     temperature=0.1,
-                    max_new_tokens=
-                    max_token=
+                    max_new_tokens=512,
+                    max_token=1024)
                 input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                 status['summarized_index'] = last_call_idx + 2
                 idx += 1
@@ -571,8 +571,8 @@ Summarize the function calls' responses in one sentence with all necessary infor
                     thought_calls=this_thought_calls,
                     function_response=function_response,
                     temperature=0.1,
-                    max_new_tokens=
-                    max_token=
+                    max_new_tokens=512,
+                    max_token=1024)
                 tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
             for tool_call in tool_calls:
                 del tool_call['call_id']
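Both function_result_summary call sites now summarize tool responses with max_new_tokens=512 and max_token=1024, in line with the 1024-token context configured at load time. A small assumed helper showing the kind of budget check these limits imply (not code from the repo):

    # Assumed helper: check whether a tool response still fits the 1024-token
    # prompt budget before asking the model to summarize it.
    def exceeds_budget(tokenizer, text, max_token=1024):
        return len(tokenizer.encode(text)) > max_token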
|