Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
from txagent import TxAgent
|
|
|
4 |
|
5 |
# ========== Configuration ==========
|
6 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
@@ -17,7 +18,8 @@ MODEL_CONFIG = {
|
|
17 |
'force_finish': True,
|
18 |
'enable_checker': True,
|
19 |
'step_rag_num': 10,
|
20 |
-
'seed': 100
|
|
|
21 |
}
|
22 |
}
|
23 |
|
@@ -55,7 +57,7 @@ class TxAgentApplication:
|
|
55 |
return "Model already initialized"
|
56 |
|
57 |
try:
|
58 |
-
# Initialize the agent
|
59 |
self.agent = TxAgent(
|
60 |
MODEL_CONFIG['model_name'],
|
61 |
MODEL_CONFIG['rag_model_name'],
|
@@ -67,11 +69,18 @@ class TxAgentApplication:
|
|
67 |
try:
|
68 |
self.agent.init_model()
|
69 |
except Exception as e:
|
70 |
-
# Handle specific tool embedding error
|
71 |
if "No such file or directory" in str(e) and "tool_embedding" in str(e):
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
self.is_initialized = True
|
77 |
self.initialization_error = None
|
@@ -81,6 +90,25 @@ class TxAgentApplication:
|
|
81 |
self.initialization_error = str(e)
|
82 |
return f"Initialization failed: {str(e)}"
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
def chat(self, message, chat_history):
|
85 |
if not self.is_initialized:
|
86 |
if self.initialization_error:
|
@@ -189,5 +217,6 @@ if __name__ == "__main__":
|
|
189 |
try:
|
190 |
interface.queue().launch(**launch_params)
|
191 |
except Exception as e:
|
192 |
-
print(f"Error launching
|
193 |
-
|
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
from txagent import TxAgent
|
4 |
+
import torch
|
5 |
|
6 |
# ========== Configuration ==========
|
7 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
18 |
'force_finish': True,
|
19 |
'enable_checker': True,
|
20 |
'step_rag_num': 10,
|
21 |
+
'seed': 100,
|
22 |
+
'enable_rag': False # Disable RAG until we resolve the embedding issue
|
23 |
}
|
24 |
}
|
25 |
|
|
|
57 |
return "Model already initialized"
|
58 |
|
59 |
try:
|
60 |
+
# Initialize the agent with RAG disabled initially
|
61 |
self.agent = TxAgent(
|
62 |
MODEL_CONFIG['model_name'],
|
63 |
MODEL_CONFIG['rag_model_name'],
|
|
|
69 |
try:
|
70 |
self.agent.init_model()
|
71 |
except Exception as e:
|
|
|
72 |
if "No such file or directory" in str(e) and "tool_embedding" in str(e):
|
73 |
+
# Try to generate embeddings if missing
|
74 |
+
try:
|
75 |
+
self._generate_missing_embeddings()
|
76 |
+
# Retry initialization
|
77 |
+
self.agent.init_model()
|
78 |
+
except Exception as gen_e:
|
79 |
+
return (f"Error: Could not generate missing embeddings. "
|
80 |
+
f"Please ensure all model files are properly downloaded. "
|
81 |
+
f"Technical details: {str(gen_e)}")
|
82 |
+
else:
|
83 |
+
raise
|
84 |
|
85 |
self.is_initialized = True
|
86 |
self.initialization_error = None
|
|
|
90 |
self.initialization_error = str(e)
|
91 |
return f"Initialization failed: {str(e)}"
|
92 |
|
93 |
+
def _generate_missing_embeddings(self):
|
94 |
+
"""Attempt to generate missing tool embeddings"""
|
95 |
+
if not hasattr(self.agent, 'rag_model'):
|
96 |
+
raise ValueError("RAG model not initialized")
|
97 |
+
|
98 |
+
# Get the tools from the tool universe
|
99 |
+
tools = self.agent.tooluniverse.get_all_tools()
|
100 |
+
tool_descriptions = [tool['description'] for tool in tools]
|
101 |
+
|
102 |
+
# Generate embeddings using the RAG model
|
103 |
+
embeddings = self.agent.rag_model.generate_embeddings(tool_descriptions)
|
104 |
+
|
105 |
+
# Save the embeddings
|
106 |
+
embedding_path = f"{MODEL_CONFIG['rag_model_name']}_tool_embedding_e27fb393f3144ec28f620f33d4d79911.pt"
|
107 |
+
torch.save(embeddings, embedding_path)
|
108 |
+
|
109 |
+
# Update the RAG model to use the new embeddings
|
110 |
+
self.agent.rag_model.tool_desc_embedding = embeddings
|
111 |
+
|
112 |
def chat(self, message, chat_history):
|
113 |
if not self.is_initialized:
|
114 |
if self.initialization_error:
|
|
|
217 |
try:
|
218 |
interface.queue().launch(**launch_params)
|
219 |
except Exception as e:
|
220 |
+
print(f"Error launching with queue: {e}")
|
221 |
+
print("Falling back to non-queued launch")
|
222 |
+
interface.launch(**launch_params)
|