Update app.py
Browse files
app.py
CHANGED
@@ -115,24 +115,28 @@ def patch_embedding_loading():
|
|
115 |
"""Monkey-patch the embedding loading functionality"""
|
116 |
try:
|
117 |
from txagent.toolrag import ToolRAGModel
|
118 |
-
|
119 |
original_load = ToolRAGModel.load_tool_desc_embedding
|
120 |
-
|
121 |
def patched_load(self, tooluniverse):
|
122 |
try:
|
123 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
124 |
logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
|
125 |
return False
|
126 |
-
|
127 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
130 |
current_count = len(tools)
|
131 |
embedding_count = len(self.tool_desc_embedding)
|
132 |
-
|
133 |
if current_count != embedding_count:
|
134 |
logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
|
135 |
-
|
136 |
if current_count < embedding_count:
|
137 |
self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
|
138 |
logger.info(f"Truncated embeddings to match {current_count} tools")
|
@@ -141,51 +145,20 @@ def patch_embedding_loading():
|
|
141 |
padding = [last_embedding] * (current_count - embedding_count)
|
142 |
self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
|
143 |
logger.info(f"Padded embeddings to match {current_count} tools")
|
144 |
-
|
145 |
return True
|
146 |
-
|
147 |
except Exception as e:
|
148 |
logger.error(f"Failed to load embeddings: {str(e)}")
|
149 |
return False
|
150 |
-
|
151 |
ToolRAGModel.load_tool_desc_embedding = patched_load
|
152 |
logger.info("Successfully patched embedding loading")
|
153 |
-
|
154 |
except Exception as e:
|
155 |
logger.error(f"Failed to patch embedding loading: {str(e)}")
|
156 |
raise
|
157 |
|
158 |
-
def prepare_tool_files():
|
159 |
-
"""Ensure tool files exist and are populated"""
|
160 |
-
os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
|
161 |
-
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
162 |
-
logger.info("Generating tool list using ToolUniverse...")
|
163 |
-
tu = ToolUniverse()
|
164 |
-
tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
|
165 |
-
with open(CONFIG["tool_files"]["new_tool"], "w") as f:
|
166 |
-
json.dump(tools, f, indent=2)
|
167 |
-
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
168 |
-
|
169 |
-
def create_agent():
|
170 |
-
"""Create and initialize the TxAgent"""
|
171 |
-
# Apply the embedding patch before creating the agent
|
172 |
-
patch_embedding_loading()
|
173 |
-
prepare_tool_files()
|
174 |
-
|
175 |
-
# Initialize the agent
|
176 |
-
agent = TxAgent(
|
177 |
-
CONFIG["model_name"],
|
178 |
-
CONFIG["rag_model_name"],
|
179 |
-
tool_files_dict=CONFIG["tool_files"],
|
180 |
-
force_finish=True,
|
181 |
-
enable_checker=True,
|
182 |
-
step_rag_num=10,
|
183 |
-
seed=100,
|
184 |
-
additional_default_tools=['DirectResponse', 'RequireClarification']
|
185 |
-
)
|
186 |
-
agent.init_model()
|
187 |
-
return agent
|
188 |
-
|
189 |
def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
|
190 |
init_rag_num, step_rag_num, skip_last_k,
|
191 |
summary_mode, summary_skip_last_k, summary_context_length,
|
|
|
115 |
"""Monkey-patch the embedding loading functionality"""
|
116 |
try:
|
117 |
from txagent.toolrag import ToolRAGModel
|
118 |
+
|
119 |
original_load = ToolRAGModel.load_tool_desc_embedding
|
120 |
+
|
121 |
def patched_load(self, tooluniverse):
|
122 |
try:
|
123 |
if not os.path.exists(CONFIG["embedding_filename"]):
|
124 |
logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
|
125 |
return False
|
126 |
+
|
127 |
self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
|
128 |
+
if self.tool_desc_embedding is None:
|
129 |
+
logger.error("Embedding is None, aborting.")
|
130 |
+
return False
|
131 |
+
|
132 |
+
# Ensure tools is a list (in case it's a generator)
|
133 |
+
tools = list(tooluniverse.get_all_tools()) if hasattr(tooluniverse, 'get_all_tools') else []
|
134 |
current_count = len(tools)
|
135 |
embedding_count = len(self.tool_desc_embedding)
|
136 |
+
|
137 |
if current_count != embedding_count:
|
138 |
logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
|
139 |
+
|
140 |
if current_count < embedding_count:
|
141 |
self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
|
142 |
logger.info(f"Truncated embeddings to match {current_count} tools")
|
|
|
145 |
padding = [last_embedding] * (current_count - embedding_count)
|
146 |
self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
|
147 |
logger.info(f"Padded embeddings to match {current_count} tools")
|
148 |
+
|
149 |
return True
|
150 |
+
|
151 |
except Exception as e:
|
152 |
logger.error(f"Failed to load embeddings: {str(e)}")
|
153 |
return False
|
154 |
+
|
155 |
ToolRAGModel.load_tool_desc_embedding = patched_load
|
156 |
logger.info("Successfully patched embedding loading")
|
157 |
+
|
158 |
except Exception as e:
|
159 |
logger.error(f"Failed to patch embedding loading: {str(e)}")
|
160 |
raise
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
|
163 |
init_rag_num, step_rag_num, skip_last_k,
|
164 |
summary_mode, summary_skip_last_k, summary_context_length,
|