Ali2206 commited on
Commit
2eb317a
·
verified ·
1 Parent(s): c186977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -42
app.py CHANGED
@@ -115,24 +115,28 @@ def patch_embedding_loading():
115
  """Monkey-patch the embedding loading functionality"""
116
  try:
117
  from txagent.toolrag import ToolRAGModel
118
-
119
  original_load = ToolRAGModel.load_tool_desc_embedding
120
-
121
  def patched_load(self, tooluniverse):
122
  try:
123
  if not os.path.exists(CONFIG["embedding_filename"]):
124
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
125
  return False
126
-
127
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
128
-
129
- tools = tooluniverse.get_all_tools()
 
 
 
 
130
  current_count = len(tools)
131
  embedding_count = len(self.tool_desc_embedding)
132
-
133
  if current_count != embedding_count:
134
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
135
-
136
  if current_count < embedding_count:
137
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
138
  logger.info(f"Truncated embeddings to match {current_count} tools")
@@ -141,51 +145,20 @@ def patch_embedding_loading():
141
  padding = [last_embedding] * (current_count - embedding_count)
142
  self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
143
  logger.info(f"Padded embeddings to match {current_count} tools")
144
-
145
  return True
146
-
147
  except Exception as e:
148
  logger.error(f"Failed to load embeddings: {str(e)}")
149
  return False
150
-
151
  ToolRAGModel.load_tool_desc_embedding = patched_load
152
  logger.info("Successfully patched embedding loading")
153
-
154
  except Exception as e:
155
  logger.error(f"Failed to patch embedding loading: {str(e)}")
156
  raise
157
 
158
- def prepare_tool_files():
159
- """Ensure tool files exist and are populated"""
160
- os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
161
- if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
162
- logger.info("Generating tool list using ToolUniverse...")
163
- tu = ToolUniverse()
164
- tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
165
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
166
- json.dump(tools, f, indent=2)
167
- logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
168
-
169
- def create_agent():
170
- """Create and initialize the TxAgent"""
171
- # Apply the embedding patch before creating the agent
172
- patch_embedding_loading()
173
- prepare_tool_files()
174
-
175
- # Initialize the agent
176
- agent = TxAgent(
177
- CONFIG["model_name"],
178
- CONFIG["rag_model_name"],
179
- tool_files_dict=CONFIG["tool_files"],
180
- force_finish=True,
181
- enable_checker=True,
182
- step_rag_num=10,
183
- seed=100,
184
- additional_default_tools=['DirectResponse', 'RequireClarification']
185
- )
186
- agent.init_model()
187
- return agent
188
-
189
  def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
190
  init_rag_num, step_rag_num, skip_last_k,
191
  summary_mode, summary_skip_last_k, summary_context_length,
 
115
  """Monkey-patch the embedding loading functionality"""
116
  try:
117
  from txagent.toolrag import ToolRAGModel
118
+
119
  original_load = ToolRAGModel.load_tool_desc_embedding
120
+
121
  def patched_load(self, tooluniverse):
122
  try:
123
  if not os.path.exists(CONFIG["embedding_filename"]):
124
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
125
  return False
126
+
127
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
128
+ if self.tool_desc_embedding is None:
129
+ logger.error("Embedding is None, aborting.")
130
+ return False
131
+
132
+ # Ensure tools is a list (in case it's a generator)
133
+ tools = list(tooluniverse.get_all_tools()) if hasattr(tooluniverse, 'get_all_tools') else []
134
  current_count = len(tools)
135
  embedding_count = len(self.tool_desc_embedding)
136
+
137
  if current_count != embedding_count:
138
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
139
+
140
  if current_count < embedding_count:
141
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
142
  logger.info(f"Truncated embeddings to match {current_count} tools")
 
145
  padding = [last_embedding] * (current_count - embedding_count)
146
  self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
147
  logger.info(f"Padded embeddings to match {current_count} tools")
148
+
149
  return True
150
+
151
  except Exception as e:
152
  logger.error(f"Failed to load embeddings: {str(e)}")
153
  return False
154
+
155
  ToolRAGModel.load_tool_desc_embedding = patched_load
156
  logger.info("Successfully patched embedding loading")
157
+
158
  except Exception as e:
159
  logger.error(f"Failed to patch embedding loading: {str(e)}")
160
  raise
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
163
  init_rag_num, step_rag_num, skip_last_k,
164
  summary_mode, summary_skip_last_k, summary_context_length,