Ali2206 commited on
Commit
0151c98
·
verified ·
1 Parent(s): ad3299d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -2,11 +2,16 @@ import os
2
  import json
3
  import torch
4
  import logging
 
5
  import gradio as gr
 
6
  from importlib.resources import files
7
  from txagent import TxAgent
8
  from tooluniverse import ToolUniverse
9
 
 
 
 
10
  logging.basicConfig(
11
  level=logging.INFO,
12
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@@ -48,6 +53,9 @@ def patch_embedding_loading():
48
  if not os.path.exists(CONFIG["embedding_filename"]):
49
  return False
50
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
 
 
 
51
 
52
  tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
53
  if len(tools) != len(self.tool_desc_embedding):
@@ -75,6 +83,11 @@ def create_agent():
75
  patch_embedding_loading()
76
  prepare_tool_files()
77
  try:
 
 
 
 
 
78
  agent = TxAgent(
79
  CONFIG["model_name"],
80
  CONFIG["rag_model_name"],
@@ -83,7 +96,7 @@ def create_agent():
83
  enable_checker=True,
84
  step_rag_num=10,
85
  seed=42,
86
- additional_default_tools=["DirectResponse", "RequireClarification"]
87
  )
88
  agent.init_model()
89
  return agent
@@ -91,7 +104,6 @@ def create_agent():
91
  logger.error(f"Agent initialization failed: {e}")
92
  raise
93
 
94
- # ✅ FIXED: Proper message formatting
95
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
96
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
97
  return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
@@ -147,4 +159,4 @@ def main():
147
  demo.launch(share=False)
148
 
149
  if __name__ == "__main__":
150
- main()
 
2
  import json
3
  import torch
4
  import logging
5
+ import numpy
6
  import gradio as gr
7
+ import torch.serialization
8
  from importlib.resources import files
9
  from txagent import TxAgent
10
  from tooluniverse import ToolUniverse
11
 
12
+ # Patch PyTorch to allow loading old numpy pickles
13
+ torch.serialization.add_safe_globals([numpy.core.multiarray._reconstruct])
14
+
15
  logging.basicConfig(
16
  level=logging.INFO,
17
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 
53
  if not os.path.exists(CONFIG["embedding_filename"]):
54
  return False
55
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
56
+ if self.tool_desc_embedding is None:
57
+ logger.error("Tool embedding file could not be loaded.")
58
+ return False
59
 
60
  tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
61
  if len(tools) != len(self.tool_desc_embedding):
 
83
  patch_embedding_loading()
84
  prepare_tool_files()
85
  try:
86
+ tu = ToolUniverse()
87
+ tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
88
+ available_tool_names = [t["name"] for t in tools]
89
+ additional_default_tools = [t for t in ["DirectResponse", "RequireClarification"] if t in available_tool_names]
90
+
91
  agent = TxAgent(
92
  CONFIG["model_name"],
93
  CONFIG["rag_model_name"],
 
96
  enable_checker=True,
97
  step_rag_num=10,
98
  seed=42,
99
+ additional_default_tools=additional_default_tools
100
  )
101
  agent.init_model()
102
  return agent
 
104
  logger.error(f"Agent initialization failed: {e}")
105
  raise
106
 
 
107
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
108
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
109
  return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
 
159
  demo.launch(share=False)
160
 
161
  if __name__ == "__main__":
162
+ main()