Ali2206 commited on
Commit
dffc0b0
Β·
verified Β·
1 Parent(s): c84492a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -28
app.py CHANGED
@@ -6,11 +6,11 @@ from txagent import TxAgent
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
 
9
- # Configuration - Using remote Hugging Face models
10
  CONFIG = {
11
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
12
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
13
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
14
  "tool_files": {
15
  "new_tool": "./data/new_tool.json"
16
  }
@@ -35,26 +35,19 @@ def prepare_tool_files():
35
 
36
  def load_embeddings(agent):
37
  embedding_path = CONFIG["embedding_filename"]
38
- if os.path.exists(embedding_path):
39
- logger.info("βœ… Loading pre-generated embeddings file")
40
- try:
41
  embeddings = torch.load(embedding_path)
42
  agent.rag_model.tool_desc_embedding = embeddings
43
- return
44
- except Exception as e:
45
- logger.error(f"Failed to load embeddings: {e}")
46
-
47
- logger.info("Generating tool embeddings...")
48
- try:
49
- tools = agent.tooluniverse.get_all_tools()
50
- descriptions = [tool["description"] for tool in tools]
51
- embeddings = agent.rag_model.generate_embeddings(descriptions)
52
- torch.save(embeddings, embedding_path)
53
- agent.rag_model.tool_desc_embedding = embeddings
54
- logger.info(f"Embeddings saved to {embedding_path}")
55
  except Exception as e:
56
- logger.error(f"Failed to generate embeddings: {e}")
57
- raise
58
 
59
  class TxAgentApp:
60
  def __init__(self):
@@ -66,9 +59,9 @@ class TxAgentApp:
66
  return "βœ… Already initialized"
67
 
68
  try:
69
- logger.info("Initializing TxAgent with remote models...")
70
 
71
- # Initialize without local_files_only parameter
72
  self.agent = TxAgent(
73
  CONFIG["model_name"],
74
  CONFIG["rag_model_name"],
@@ -80,14 +73,18 @@ class TxAgentApp:
80
  additional_default_tools=["DirectResponse", "RequireClarification"]
81
  )
82
 
83
- logger.info("Loading models from Hugging Face Hub...")
 
84
  self.agent.init_model()
85
 
86
- logger.info("Preparing embeddings...")
87
- load_embeddings(self.agent)
 
 
88
 
89
  self.is_initialized = True
90
- return "βœ… TxAgent initialized successfully (using remote models)"
 
91
  except Exception as e:
92
  logger.error(f"Initialization failed: {str(e)}")
93
  return f"❌ Initialization failed: {str(e)}"
@@ -126,7 +123,6 @@ def create_interface():
126
  ) as demo:
127
  gr.Markdown("""
128
  # 🧠 TxAgent: Therapeutic Reasoning AI
129
- ### (Running with remote Hugging Face models)
130
  """)
131
 
132
  with gr.Row():
@@ -168,7 +164,14 @@ if __name__ == "__main__":
168
  try:
169
  logger.info("Starting application...")
170
 
171
- # Prepare local tool files
 
 
 
 
 
 
 
172
  prepare_tool_files()
173
 
174
  # Launch interface
@@ -179,5 +182,5 @@ if __name__ == "__main__":
179
  share=False
180
  )
181
  except Exception as e:
182
- logger.error(f"Fatal error: {str(e)}")
183
  raise
 
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
 
9
+ # Configuration - Using your existing embedding file
10
  CONFIG = {
11
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
12
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
13
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt", # Your exact filename
14
  "tool_files": {
15
  "new_tool": "./data/new_tool.json"
16
  }
 
35
 
36
  def load_embeddings(agent):
37
  embedding_path = CONFIG["embedding_filename"]
38
+ try:
39
+ if os.path.exists(embedding_path):
40
+ logger.info(f"βœ… Loading existing embeddings from {embedding_path}")
41
  embeddings = torch.load(embedding_path)
42
  agent.rag_model.tool_desc_embedding = embeddings
43
+ return True
44
+ else:
45
+ logger.error(f"❌ Embedding file not found at {embedding_path}")
46
+ logger.info("Please ensure the embedding file is in the root directory")
47
+ return False
 
 
 
 
 
 
 
48
  except Exception as e:
49
+ logger.error(f"Failed to load embeddings: {str(e)}")
50
+ return False
51
 
52
  class TxAgentApp:
53
  def __init__(self):
 
59
  return "βœ… Already initialized"
60
 
61
  try:
62
+ logger.info("Initializing TxAgent...")
63
 
64
+ # Initialize TxAgent
65
  self.agent = TxAgent(
66
  CONFIG["model_name"],
67
  CONFIG["rag_model_name"],
 
73
  additional_default_tools=["DirectResponse", "RequireClarification"]
74
  )
75
 
76
+ # Initialize models
77
+ logger.info("Loading models...")
78
  self.agent.init_model()
79
 
80
+ # Load embeddings
81
+ logger.info("Loading embeddings...")
82
+ if not load_embeddings(self.agent):
83
+ return "❌ Failed to load embeddings - check logs"
84
 
85
  self.is_initialized = True
86
+ return "βœ… TxAgent initialized successfully"
87
+
88
  except Exception as e:
89
  logger.error(f"Initialization failed: {str(e)}")
90
  return f"❌ Initialization failed: {str(e)}"
 
123
  ) as demo:
124
  gr.Markdown("""
125
  # 🧠 TxAgent: Therapeutic Reasoning AI
 
126
  """)
127
 
128
  with gr.Row():
 
164
  try:
165
  logger.info("Starting application...")
166
 
167
+ # Verify embedding file exists
168
+ if not os.path.exists(CONFIG["embedding_filename"]):
169
+ logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
170
+ logger.info("Please ensure the file is in the root directory")
171
+ else:
172
+ logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")
173
+
174
+ # Prepare tool files
175
  prepare_tool_files()
176
 
177
  # Launch interface
 
182
  share=False
183
  )
184
  except Exception as e:
185
+ logger.error(f"Application failed to start: {str(e)}")
186
  raise