Ali2206 commited on
Commit
b8c0ae3
·
verified ·
1 Parent(s): 9cb1bd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import gradio as gr
3
  from txagent import TxAgent
 
4
 
5
  # ========== Configuration ==========
6
  current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -17,7 +18,8 @@ MODEL_CONFIG = {
17
  'force_finish': True,
18
  'enable_checker': True,
19
  'step_rag_num': 10,
20
- 'seed': 100
 
21
  }
22
  }
23
 
@@ -55,7 +57,7 @@ class TxAgentApplication:
55
  return "Model already initialized"
56
 
57
  try:
58
- # Initialize the agent
59
  self.agent = TxAgent(
60
  MODEL_CONFIG['model_name'],
61
  MODEL_CONFIG['rag_model_name'],
@@ -67,11 +69,18 @@ class TxAgentApplication:
67
  try:
68
  self.agent.init_model()
69
  except Exception as e:
70
- # Handle specific tool embedding error
71
  if "No such file or directory" in str(e) and "tool_embedding" in str(e):
72
- return ("Error: Missing tool embedding file. "
73
- "Please ensure the RAG model files are properly downloaded.")
74
- raise
 
 
 
 
 
 
 
 
75
 
76
  self.is_initialized = True
77
  self.initialization_error = None
@@ -81,6 +90,25 @@ class TxAgentApplication:
81
  self.initialization_error = str(e)
82
  return f"Initialization failed: {str(e)}"
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def chat(self, message, chat_history):
85
  if not self.is_initialized:
86
  if self.initialization_error:
@@ -189,5 +217,6 @@ if __name__ == "__main__":
189
  try:
190
  interface.queue().launch(**launch_params)
191
  except Exception as e:
192
- print(f"Error launching interface: {e}")
193
- interface.launch(**launch_params) # Fallback without queue
 
 
1
  import os
2
  import gradio as gr
3
  from txagent import TxAgent
4
+ import torch
5
 
6
  # ========== Configuration ==========
7
  current_dir = os.path.dirname(os.path.abspath(__file__))
 
18
  'force_finish': True,
19
  'enable_checker': True,
20
  'step_rag_num': 10,
21
+ 'seed': 100,
22
+ 'enable_rag': False # Disable RAG until we resolve the embedding issue
23
  }
24
  }
25
 
 
57
  return "Model already initialized"
58
 
59
  try:
60
+ # Initialize the agent with RAG disabled initially
61
  self.agent = TxAgent(
62
  MODEL_CONFIG['model_name'],
63
  MODEL_CONFIG['rag_model_name'],
 
69
  try:
70
  self.agent.init_model()
71
  except Exception as e:
 
72
  if "No such file or directory" in str(e) and "tool_embedding" in str(e):
73
+ # Try to generate embeddings if missing
74
+ try:
75
+ self._generate_missing_embeddings()
76
+ # Retry initialization
77
+ self.agent.init_model()
78
+ except Exception as gen_e:
79
+ return (f"Error: Could not generate missing embeddings. "
80
+ f"Please ensure all model files are properly downloaded. "
81
+ f"Technical details: {str(gen_e)}")
82
+ else:
83
+ raise
84
 
85
  self.is_initialized = True
86
  self.initialization_error = None
 
90
  self.initialization_error = str(e)
91
  return f"Initialization failed: {str(e)}"
92
 
93
+ def _generate_missing_embeddings(self):
94
+ """Attempt to generate missing tool embeddings"""
95
+ if not hasattr(self.agent, 'rag_model'):
96
+ raise ValueError("RAG model not initialized")
97
+
98
+ # Get the tools from the tool universe
99
+ tools = self.agent.tooluniverse.get_all_tools()
100
+ tool_descriptions = [tool['description'] for tool in tools]
101
+
102
+ # Generate embeddings using the RAG model
103
+ embeddings = self.agent.rag_model.generate_embeddings(tool_descriptions)
104
+
105
+ # Save the embeddings
106
+ embedding_path = f"{MODEL_CONFIG['rag_model_name']}_tool_embedding_e27fb393f3144ec28f620f33d4d79911.pt"
107
+ torch.save(embeddings, embedding_path)
108
+
109
+ # Update the RAG model to use the new embeddings
110
+ self.agent.rag_model.tool_desc_embedding = embeddings
111
+
112
  def chat(self, message, chat_history):
113
  if not self.is_initialized:
114
  if self.initialization_error:
 
217
  try:
218
  interface.queue().launch(**launch_params)
219
  except Exception as e:
220
+ print(f"Error launching with queue: {e}")
221
+ print("Falling back to non-queued launch")
222
+ interface.launch(**launch_params)