Update app.py
Browse files
app.py
CHANGED
@@ -30,6 +30,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
|
|
33 |
"tool_files": {
|
34 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
35 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
@@ -39,6 +40,16 @@ CONFIG = {
|
|
39 |
}
|
40 |
}
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def prepare_tool_files():
|
43 |
os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
|
44 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
@@ -63,6 +74,10 @@ def create_agent():
|
|
63 |
seed=42,
|
64 |
additional_default_tools=["DirectResponse", "RequireClarification"]
|
65 |
)
|
|
|
|
|
|
|
|
|
66 |
agent.init_model()
|
67 |
return agent
|
68 |
except Exception as e:
|
@@ -130,8 +145,7 @@ def main():
|
|
130 |
global agent
|
131 |
agent = create_agent()
|
132 |
demo = create_demo(agent)
|
133 |
-
|
134 |
-
exit()
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
main()
|
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
33 |
+
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
|
34 |
"tool_files": {
|
35 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
36 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
|
|
40 |
}
|
41 |
}
|
42 |
|
43 |
+
def generate_tool_embeddings(agent):
|
44 |
+
tu = ToolUniverse(tool_files=CONFIG["tool_files"])
|
45 |
+
tu.load_tools()
|
46 |
+
embedding_tensor = agent.rag_model.generate_tool_desc_embedding(tu)
|
47 |
+
if embedding_tensor is not None:
|
48 |
+
torch.save(embedding_tensor, CONFIG["embedding_filename"])
|
49 |
+
logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
|
50 |
+
else:
|
51 |
+
logger.warning("Embedding generation returned None")
|
52 |
+
|
53 |
def prepare_tool_files():
|
54 |
os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
|
55 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
|
|
74 |
seed=42,
|
75 |
additional_default_tools=["DirectResponse", "RequireClarification"]
|
76 |
)
|
77 |
+
if not os.path.exists(CONFIG["embedding_filename"]):
|
78 |
+
generate_tool_embeddings(agent)
|
79 |
+
else:
|
80 |
+
logger.info("Embedding file found. Skipping embedding generation.")
|
81 |
agent.init_model()
|
82 |
return agent
|
83 |
except Exception as e:
|
|
|
145 |
global agent
|
146 |
agent = create_agent()
|
147 |
demo = create_demo(agent)
|
148 |
+
demo.launch(share=True)
|
|
|
149 |
|
150 |
if __name__ == "__main__":
|
151 |
main()
|