Update src/txagent/txagent.py
Browse files — src/txagent/txagent.py (+15 −16)
src/txagent/txagent.py
CHANGED
@@ -44,7 +44,7 @@ class TxAgent:
|
|
44 |
self.rag_model_name = rag_model_name
|
45 |
self.tool_files_dict = tool_files_dict
|
46 |
self.model = None
|
47 |
-
self.rag_model = ToolRAGModel(rag_model_name) if enable_rag else None
|
48 |
self.tooluniverse = None
|
49 |
self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."
|
50 |
self.self_prompt = "Strictly follow the instruction."
|
@@ -66,7 +66,7 @@ class TxAgent:
|
|
66 |
|
67 |
def init_model(self):
|
68 |
self.load_models()
|
69 |
-
if self.enable_rag:
|
70 |
self.load_tooluniverse()
|
71 |
self.load_tool_desc_embedding()
|
72 |
|
@@ -79,8 +79,7 @@ class TxAgent:
|
|
79 |
if model_name == self.model_name:
|
80 |
return f"The model {model_name} is already loaded."
|
81 |
self.model_name = model_name
|
82 |
-
|
83 |
-
self.model = LLM(model=self.model_name, enforce_eager=True) # MODIFIED: Force no torch.compile
|
84 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
85 |
self.tokenizer = self.model.get_tokenizer()
|
86 |
return f"Model {model_name} loaded successfully."
|
@@ -106,7 +105,7 @@ class TxAgent:
|
|
106 |
|
107 |
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
|
108 |
picked_tools_prompt = []
|
109 |
-
if not self.enable_rag:
|
110 |
return picked_tools_prompt, call_agent_level
|
111 |
picked_tools_prompt = self.add_special_tools(
|
112 |
picked_tools_prompt, call_agent=call_agent)
|
@@ -164,18 +163,18 @@ class TxAgent:
|
|
164 |
return picked_tools_prompt
|
165 |
|
166 |
def add_special_tools(self, tools, call_agent=False):
|
167 |
-
if not self.enable_rag and not self.enable_finish:
|
168 |
return tools
|
169 |
-
if self.enable_finish:
|
170 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
171 |
logger.info("Finish tool is added")
|
172 |
-
if call_agent:
|
173 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
|
174 |
logger.info("CallAgent tool is added")
|
175 |
-
elif self.enable_rag:
|
176 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
|
177 |
logger.info("Tool_RAG tool is added")
|
178 |
-
if self.additional_default_tools is not None:
|
179 |
for each_tool_name in self.additional_default_tools:
|
180 |
tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
|
181 |
if tool_prompt is not None:
|
@@ -184,7 +183,7 @@ class TxAgent:
|
|
184 |
return tools
|
185 |
|
186 |
def add_finish_tools(self, tools):
|
187 |
-
if not self.enable_finish:
|
188 |
return tools
|
189 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
190 |
logger.info("Finish tool is added")
|
@@ -204,7 +203,7 @@ class TxAgent:
|
|
204 |
call_agent=False,
|
205 |
call_agent_level=None,
|
206 |
temperature=None):
|
207 |
-
if not self.enable_rag:
|
208 |
return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], ''
|
209 |
function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
|
210 |
call_results = []
|
@@ -270,7 +269,7 @@ class TxAgent:
|
|
270 |
call_agent_level=None,
|
271 |
temperature=None,
|
272 |
return_gradio_history=True):
|
273 |
-
if not self.enable_rag:
|
274 |
gradio_history = [] if return_gradio_history else None
|
275 |
return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], '', gradio_history
|
276 |
function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
|
@@ -452,7 +451,7 @@ class TxAgent:
|
|
452 |
if len(assistant_messages) == 2:
|
453 |
break
|
454 |
forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
|
455 |
-
return [NoRepeatSentenceProcessor(forbidden_ids, 3)]
|
456 |
return None
|
457 |
|
458 |
def llm_infer(self, messages, temperature=0.1, tools=None,
|
@@ -661,7 +660,7 @@ Generate one summarized sentence about "function calls' responses" with necessar
|
|
661 |
call_agent, call_agent_level, message)
|
662 |
conversation = self.initialize_conversation(
|
663 |
message, conversation=conversation, history=history)
|
664 |
-
history = []
|
665 |
next_round = True
|
666 |
function_call_messages = []
|
667 |
current_round = 0
|
@@ -674,7 +673,7 @@ Generate one summarized sentence about "function calls' responses" with necessar
|
|
674 |
while next_round and current_round < max_round:
|
675 |
current_round += 1
|
676 |
logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
|
677 |
-
if last_outputs and self.enable_rag:
|
678 |
function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
|
679 |
last_outputs, return_message=True,
|
680 |
existing_tools_prompt=picked_tools_prompt,
|
|
|
44 |
self.rag_model_name = rag_model_name
|
45 |
self.tool_files_dict = tool_files_dict
|
46 |
self.model = None
|
47 |
+
self.rag_model = ToolRAGModel(rag_model_name) if enable_rag else None
|
48 |
self.tooluniverse = None
|
49 |
self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."
|
50 |
self.self_prompt = "Strictly follow the instruction."
|
|
|
66 |
|
67 |
def init_model(self):
|
68 |
self.load_models()
|
69 |
+
if self.enable_rag:
|
70 |
self.load_tooluniverse()
|
71 |
self.load_tool_desc_embedding()
|
72 |
|
|
|
79 |
if model_name == self.model_name:
|
80 |
return f"The model {model_name} is already loaded."
|
81 |
self.model_name = model_name
|
82 |
+
self.model = LLM(model=self.model_name, enforce_eager=True)
|
|
|
83 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
84 |
self.tokenizer = self.model.get_tokenizer()
|
85 |
return f"Model {model_name} loaded successfully."
|
|
|
105 |
|
106 |
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
|
107 |
picked_tools_prompt = []
|
108 |
+
if not self.enable_rag:
|
109 |
return picked_tools_prompt, call_agent_level
|
110 |
picked_tools_prompt = self.add_special_tools(
|
111 |
picked_tools_prompt, call_agent=call_agent)
|
|
|
163 |
return picked_tools_prompt
|
164 |
|
165 |
def add_special_tools(self, tools, call_agent=False):
|
166 |
+
if not self.enable_rag and not self.enable_finish:
|
167 |
return tools
|
168 |
+
if self.enable_finish and self.tooluniverse: # MODIFIED: Check tooluniverse
|
169 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
170 |
logger.info("Finish tool is added")
|
171 |
+
if call_agent and self.tooluniverse: # MODIFIED: Check tooluniverse
|
172 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
|
173 |
logger.info("CallAgent tool is added")
|
174 |
+
elif self.enable_rag and self.tooluniverse: # MODIFIED: Check tooluniverse
|
175 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
|
176 |
logger.info("Tool_RAG tool is added")
|
177 |
+
if self.additional_default_tools is not None and self.tooluniverse: # MODIFIED: Check tooluniverse
|
178 |
for each_tool_name in self.additional_default_tools:
|
179 |
tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
|
180 |
if tool_prompt is not None:
|
|
|
183 |
return tools
|
184 |
|
185 |
def add_finish_tools(self, tools):
|
186 |
+
if not self.enable_finish or not self.tooluniverse: # MODIFIED: Check tooluniverse
|
187 |
return tools
|
188 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
189 |
logger.info("Finish tool is added")
|
|
|
203 |
call_agent=False,
|
204 |
call_agent_level=None,
|
205 |
temperature=None):
|
206 |
+
if not self.enable_rag:
|
207 |
return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], ''
|
208 |
function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
|
209 |
call_results = []
|
|
|
269 |
call_agent_level=None,
|
270 |
temperature=None,
|
271 |
return_gradio_history=True):
|
272 |
+
if not self.enable_rag:
|
273 |
gradio_history = [] if return_gradio_history else None
|
274 |
return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], '', gradio_history
|
275 |
function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
|
|
|
451 |
if len(assistant_messages) == 2:
|
452 |
break
|
453 |
forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
|
454 |
+
return [NoRepeatSentenceProcessor(forbidden_ids, 3)]
|
455 |
return None
|
456 |
|
457 |
def llm_infer(self, messages, temperature=0.1, tools=None,
|
|
|
660 |
call_agent, call_agent_level, message)
|
661 |
conversation = self.initialize_conversation(
|
662 |
message, conversation=conversation, history=history)
|
663 |
+
history = [] if not history else history # MODIFIED: Simplify history
|
664 |
next_round = True
|
665 |
function_call_messages = []
|
666 |
current_round = 0
|
|
|
673 |
while next_round and current_round < max_round:
|
674 |
current_round += 1
|
675 |
logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
|
676 |
+
if last_outputs and self.enable_rag:
|
677 |
function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
|
678 |
last_outputs, return_message=True,
|
679 |
existing_tools_prompt=picked_tools_prompt,
|