Update app.py
Browse files
app.py
CHANGED
@@ -12,7 +12,8 @@ from tooluniverse import ToolUniverse
|
|
12 |
# Patch PyTorch to allow loading old numpy pickles
|
13 |
torch.serialization.add_safe_globals([
|
14 |
numpy.core.multiarray._reconstruct,
|
15 |
-
numpy.ndarray
|
|
|
16 |
])
|
17 |
|
18 |
logging.basicConfig(
|
@@ -63,7 +64,11 @@ def patch_embedding_loading():
|
|
63 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
64 |
if len(tools) != len(self.tool_desc_embedding):
|
65 |
logger.warning("Tool count mismatch.")
|
66 |
-
self.tool_desc_embedding
|
|
|
|
|
|
|
|
|
67 |
return True
|
68 |
except Exception as e:
|
69 |
logger.error(f"Embedding load failed: {e}")
|
|
|
12 |
# Patch PyTorch to allow loading old numpy pickles
|
13 |
torch.serialization.add_safe_globals([
|
14 |
numpy.core.multiarray._reconstruct,
|
15 |
+
numpy.ndarray,
|
16 |
+
numpy.dtype
|
17 |
])
|
18 |
|
19 |
logging.basicConfig(
|
|
|
64 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
65 |
if len(tools) != len(self.tool_desc_embedding):
|
66 |
logger.warning("Tool count mismatch.")
|
67 |
+
if len(self.tool_desc_embedding) > len(tools):
|
68 |
+
self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
|
69 |
+
else:
|
70 |
+
padding = self.tool_desc_embedding[-1].unsqueeze(0).repeat(len(tools) - len(self.tool_desc_embedding), 1)
|
71 |
+
self.tool_desc_embedding = torch.cat([self.tool_desc_embedding, padding], dim=0)
|
72 |
return True
|
73 |
except Exception as e:
|
74 |
logger.error(f"Embedding load failed: {e}")
|