Update app.py
Browse files
app.py
CHANGED
@@ -11,73 +11,60 @@ from tooluniverse import ToolUniverse
|
|
11 |
CONFIG = {
|
12 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
13 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
14 |
-
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.
|
15 |
"local_dir": "./models",
|
16 |
"tool_files": {
|
17 |
"new_tool": "./data/new_tool.json"
|
18 |
}
|
19 |
}
|
20 |
|
21 |
-
# Logging
|
22 |
logging.basicConfig(level=logging.INFO)
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
def prepare_tool_files():
|
26 |
os.makedirs("./data", exist_ok=True)
|
27 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
28 |
-
logger.info("Generating tool list using ToolUniverse...")
|
29 |
tu = ToolUniverse()
|
30 |
tools = tu.get_all_tools()
|
31 |
with open(CONFIG["tool_files"]["new_tool"], "w") as f:
|
32 |
json.dump(tools, f, indent=2)
|
33 |
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
34 |
|
35 |
-
|
36 |
def download_model_files():
|
37 |
os.makedirs(CONFIG["local_dir"], exist_ok=True)
|
38 |
-
|
39 |
|
40 |
snapshot_download(
|
41 |
repo_id=CONFIG["model_name"],
|
42 |
local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
|
43 |
resume_download=True
|
44 |
)
|
45 |
-
|
46 |
snapshot_download(
|
47 |
repo_id=CONFIG["rag_model_name"],
|
48 |
local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
|
49 |
resume_download=True
|
50 |
)
|
51 |
|
52 |
-
try:
|
53 |
-
hf_hub_download(
|
54 |
-
repo_id=CONFIG["rag_model_name"],
|
55 |
-
filename=CONFIG["embedding_filename"],
|
56 |
-
local_dir=CONFIG["local_dir"],
|
57 |
-
resume_download=True
|
58 |
-
)
|
59 |
-
print("Embeddings file downloaded successfully")
|
60 |
-
except Exception as e:
|
61 |
-
print(f"Could not download embeddings file: {e}")
|
62 |
-
print("Will attempt to generate it instead")
|
63 |
-
|
64 |
def generate_embeddings(agent):
|
65 |
-
|
66 |
-
|
67 |
-
if os.path.exists(embedding_path):
|
68 |
-
print("Embeddings file already exists")
|
69 |
-
return
|
70 |
-
|
71 |
-
print("Generating missing tool embeddings...")
|
72 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
tools = agent.tooluniverse.get_all_tools()
|
74 |
-
|
75 |
-
embeddings = agent.rag_model.generate_embeddings(
|
76 |
-
torch.save(embeddings,
|
77 |
agent.rag_model.tool_desc_embedding = embeddings
|
78 |
-
|
|
|
79 |
except Exception as e:
|
80 |
-
|
81 |
raise
|
82 |
|
83 |
class TxAgentApp:
|
@@ -124,7 +111,6 @@ class TxAgentApp:
|
|
124 |
max_round=30
|
125 |
):
|
126 |
response += chunk
|
127 |
-
|
128 |
return history + [(message, response)]
|
129 |
except Exception as e:
|
130 |
return history + [(message, f"Error: {str(e)}")]
|
|
|
11 |
CONFIG = {
|
12 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
13 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
14 |
+
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
|
15 |
"local_dir": "./models",
|
16 |
"tool_files": {
|
17 |
"new_tool": "./data/new_tool.json"
|
18 |
}
|
19 |
}
|
20 |
|
21 |
+
# Logging
|
22 |
logging.basicConfig(level=logging.INFO)
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
def prepare_tool_files():
|
26 |
os.makedirs("./data", exist_ok=True)
|
27 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
|
|
28 |
tu = ToolUniverse()
|
29 |
tools = tu.get_all_tools()
|
30 |
with open(CONFIG["tool_files"]["new_tool"], "w") as f:
|
31 |
json.dump(tools, f, indent=2)
|
32 |
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
33 |
|
|
|
34 |
def download_model_files():
|
35 |
os.makedirs(CONFIG["local_dir"], exist_ok=True)
|
36 |
+
logger.info("Downloading model files...")
|
37 |
|
38 |
snapshot_download(
|
39 |
repo_id=CONFIG["model_name"],
|
40 |
local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
|
41 |
resume_download=True
|
42 |
)
|
|
|
43 |
snapshot_download(
|
44 |
repo_id=CONFIG["rag_model_name"],
|
45 |
local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
|
46 |
resume_download=True
|
47 |
)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def generate_embeddings(agent):
|
50 |
+
"""Generates and assigns embeddings manually"""
|
51 |
+
path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
|
|
|
|
|
|
|
|
|
|
|
52 |
try:
|
53 |
+
if os.path.exists(path):
|
54 |
+
logger.info("Embedding file already exists, loading...")
|
55 |
+
agent.rag_model.tool_desc_embedding = torch.load(path)
|
56 |
+
return
|
57 |
+
|
58 |
+
logger.info("Generating embeddings and saving...")
|
59 |
tools = agent.tooluniverse.get_all_tools()
|
60 |
+
descs = [t["description"] for t in tools]
|
61 |
+
embeddings = agent.rag_model.generate_embeddings(descs)
|
62 |
+
torch.save(embeddings, path)
|
63 |
agent.rag_model.tool_desc_embedding = embeddings
|
64 |
+
logger.info("✅ Embeddings generated and saved")
|
65 |
+
|
66 |
except Exception as e:
|
67 |
+
logger.error(f"Failed to generate embeddings: {e}")
|
68 |
raise
|
69 |
|
70 |
class TxAgentApp:
|
|
|
111 |
max_round=30
|
112 |
):
|
113 |
response += chunk
|
|
|
114 |
return history + [(message, response)]
|
115 |
except Exception as e:
|
116 |
return history + [(message, f"Error: {str(e)}")]
|