Ali2206 commited on
Commit
7b04f0b
·
verified ·
1 Parent(s): 3427062

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +15 -16
src/txagent/txagent.py CHANGED
@@ -44,7 +44,7 @@ class TxAgent:
44
  self.rag_model_name = rag_model_name
45
  self.tool_files_dict = tool_files_dict
46
  self.model = None
47
- self.rag_model = ToolRAGModel(rag_model_name) if enable_rag else None # MODIFIED: Skip RAG model if disabled
48
  self.tooluniverse = None
49
  self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."
50
  self.self_prompt = "Strictly follow the instruction."
@@ -66,7 +66,7 @@ class TxAgent:
66
 
67
  def init_model(self):
68
  self.load_models()
69
- if self.enable_rag: # MODIFIED: Only load tools if RAG enabled
70
  self.load_tooluniverse()
71
  self.load_tool_desc_embedding()
72
 
@@ -79,8 +79,7 @@ class TxAgent:
79
  if model_name == self.model_name:
80
  return f"The model {model_name} is already loaded."
81
  self.model_name = model_name
82
-
83
- self.model = LLM(model=self.model_name, enforce_eager=True) # MODIFIED: Force no torch.compile
84
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
85
  self.tokenizer = self.model.get_tokenizer()
86
  return f"Model {model_name} loaded successfully."
@@ -106,7 +105,7 @@ class TxAgent:
106
 
107
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
108
  picked_tools_prompt = []
109
- if not self.enable_rag: # MODIFIED: No tools if RAG disabled
110
  return picked_tools_prompt, call_agent_level
111
  picked_tools_prompt = self.add_special_tools(
112
  picked_tools_prompt, call_agent=call_agent)
@@ -164,18 +163,18 @@ class TxAgent:
164
  return picked_tools_prompt
165
 
166
  def add_special_tools(self, tools, call_agent=False):
167
- if not self.enable_rag and not self.enable_finish: # MODIFIED: No tools if RAG disabled
168
  return tools
169
- if self.enable_finish:
170
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
171
  logger.info("Finish tool is added")
172
- if call_agent:
173
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
174
  logger.info("CallAgent tool is added")
175
- elif self.enable_rag:
176
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
177
  logger.info("Tool_RAG tool is added")
178
- if self.additional_default_tools is not None:
179
  for each_tool_name in self.additional_default_tools:
180
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
181
  if tool_prompt is not None:
@@ -184,7 +183,7 @@ class TxAgent:
184
  return tools
185
 
186
  def add_finish_tools(self, tools):
187
- if not self.enable_finish:
188
  return tools
189
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
190
  logger.info("Finish tool is added")
@@ -204,7 +203,7 @@ class TxAgent:
204
  call_agent=False,
205
  call_agent_level=None,
206
  temperature=None):
207
- if not self.enable_rag: # MODIFIED: Skip function calls if RAG disabled
208
  return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], ''
209
  function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
210
  call_results = []
@@ -270,7 +269,7 @@ class TxAgent:
270
  call_agent_level=None,
271
  temperature=None,
272
  return_gradio_history=True):
273
- if not self.enable_rag: # MODIFIED: Skip function calls if RAG disabled
274
  gradio_history = [] if return_gradio_history else None
275
  return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], '', gradio_history
276
  function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
@@ -452,7 +451,7 @@ class TxAgent:
452
  if len(assistant_messages) == 2:
453
  break
454
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
455
- return [NoRepeatSentenceProcessor(forbidden_ids, 3)] # MODIFIED: Stricter penalty
456
  return None
457
 
458
  def llm_infer(self, messages, temperature=0.1, tools=None,
@@ -661,7 +660,7 @@ Generate one summarized sentence about "function calls' responses" with necessar
661
  call_agent, call_agent_level, message)
662
  conversation = self.initialize_conversation(
663
  message, conversation=conversation, history=history)
664
- history = []
665
  next_round = True
666
  function_call_messages = []
667
  current_round = 0
@@ -674,7 +673,7 @@ Generate one summarized sentence about "function calls' responses" with necessar
674
  while next_round and current_round < max_round:
675
  current_round += 1
676
  logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
677
- if last_outputs and self.enable_rag: # MODIFIED: Skip function calls if RAG disabled
678
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
679
  last_outputs, return_message=True,
680
  existing_tools_prompt=picked_tools_prompt,
 
44
  self.rag_model_name = rag_model_name
45
  self.tool_files_dict = tool_files_dict
46
  self.model = None
47
+ self.rag_model = ToolRAGModel(rag_model_name) if enable_rag else None
48
  self.tooluniverse = None
49
  self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."
50
  self.self_prompt = "Strictly follow the instruction."
 
66
 
67
  def init_model(self):
68
  self.load_models()
69
+ if self.enable_rag:
70
  self.load_tooluniverse()
71
  self.load_tool_desc_embedding()
72
 
 
79
  if model_name == self.model_name:
80
  return f"The model {model_name} is already loaded."
81
  self.model_name = model_name
82
+ self.model = LLM(model=self.model_name, enforce_eager=True)
 
83
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
84
  self.tokenizer = self.model.get_tokenizer()
85
  return f"Model {model_name} loaded successfully."
 
105
 
106
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
107
  picked_tools_prompt = []
108
+ if not self.enable_rag:
109
  return picked_tools_prompt, call_agent_level
110
  picked_tools_prompt = self.add_special_tools(
111
  picked_tools_prompt, call_agent=call_agent)
 
163
  return picked_tools_prompt
164
 
165
  def add_special_tools(self, tools, call_agent=False):
166
+ if not self.enable_rag and not self.enable_finish:
167
  return tools
168
+ if self.enable_finish and self.tooluniverse: # MODIFIED: Check tooluniverse
169
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
170
  logger.info("Finish tool is added")
171
+ if call_agent and self.tooluniverse: # MODIFIED: Check tooluniverse
172
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
173
  logger.info("CallAgent tool is added")
174
+ elif self.enable_rag and self.tooluniverse: # MODIFIED: Check tooluniverse
175
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
176
  logger.info("Tool_RAG tool is added")
177
+ if self.additional_default_tools is not None and self.tooluniverse: # MODIFIED: Check tooluniverse
178
  for each_tool_name in self.additional_default_tools:
179
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
180
  if tool_prompt is not None:
 
183
  return tools
184
 
185
  def add_finish_tools(self, tools):
186
+ if not self.enable_finish or not self.tooluniverse: # MODIFIED: Check tooluniverse
187
  return tools
188
  tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
189
  logger.info("Finish tool is added")
 
203
  call_agent=False,
204
  call_agent_level=None,
205
  temperature=None):
206
+ if not self.enable_rag:
207
  return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], ''
208
  function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
209
  call_results = []
 
269
  call_agent_level=None,
270
  temperature=None,
271
  return_gradio_history=True):
272
+ if not self.enable_rag:
273
  gradio_history = [] if return_gradio_history else None
274
  return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], '', gradio_history
275
  function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
 
451
  if len(assistant_messages) == 2:
452
  break
453
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
454
+ return [NoRepeatSentenceProcessor(forbidden_ids, 3)]
455
  return None
456
 
457
  def llm_infer(self, messages, temperature=0.1, tools=None,
 
660
  call_agent, call_agent_level, message)
661
  conversation = self.initialize_conversation(
662
  message, conversation=conversation, history=history)
663
+ history = [] if not history else history # MODIFIED: Simplify history
664
  next_round = True
665
  function_call_messages = []
666
  current_round = 0
 
673
  while next_round and current_round < max_round:
674
  current_round += 1
675
  logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
676
+ if last_outputs and self.enable_rag:
677
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
678
  last_outputs, return_message=True,
679
  existing_tools_prompt=picked_tools_prompt,