Ali2206 committed
Commit 99cd953 · verified · 1 Parent(s): 28d0fa1

Update app.py

Files changed (1)
  1. app.py +7 -2
app.py CHANGED
@@ -12,7 +12,8 @@ from tooluniverse import ToolUniverse
 # Patch PyTorch to allow loading old numpy pickles
 torch.serialization.add_safe_globals([
     numpy.core.multiarray._reconstruct,
-    numpy.ndarray
+    numpy.ndarray,
+    numpy.dtype
 ])

 logging.basicConfig(
@@ -63,7 +64,11 @@ def patch_embedding_loading():
         tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
         if len(tools) != len(self.tool_desc_embedding):
             logger.warning("Tool count mismatch.")
-            self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
+            if len(self.tool_desc_embedding) > len(tools):
+                self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
+            else:
+                padding = self.tool_desc_embedding[-1].unsqueeze(0).repeat(len(tools) - len(self.tool_desc_embedding), 1)
+                self.tool_desc_embedding = torch.cat([self.tool_desc_embedding, padding], dim=0)
         return True
     except Exception as e:
         logger.error(f"Embedding load failed: {e}")