Update app.py
Browse files
app.py
CHANGED
@@ -30,7 +30,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
33 |
-
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
|
34 |
"tool_files": {
|
35 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
36 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
@@ -40,16 +39,6 @@ CONFIG = {
|
|
40 |
}
|
41 |
}
|
42 |
|
43 |
-
def generate_tool_embeddings(agent):
|
44 |
-
tu = ToolUniverse(tool_files=CONFIG["tool_files"])
|
45 |
-
tu.load_tools()
|
46 |
-
embedding_tensor = agent.rag_model.load_tool_desc_embedding(tu)
|
47 |
-
if embedding_tensor is not None:
|
48 |
-
torch.save(embedding_tensor, CONFIG["embedding_filename"])
|
49 |
-
logger.info(f"Saved new embedding tensor to {CONFIG['embedding_filename']}")
|
50 |
-
else:
|
51 |
-
logger.warning("Embedding generation returned None")
|
52 |
-
|
53 |
def prepare_tool_files():
|
54 |
os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
|
55 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
@@ -74,8 +63,6 @@ def create_agent():
|
|
74 |
seed=42,
|
75 |
additional_default_tools=["DirectResponse", "RequireClarification"]
|
76 |
)
|
77 |
-
if not os.path.exists(CONFIG["embedding_filename"]):
|
78 |
-
generate_tool_embeddings(agent)
|
79 |
agent.init_model()
|
80 |
return agent
|
81 |
except Exception as e:
|
@@ -88,7 +75,7 @@ def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_ag
|
|
88 |
|
89 |
message = msg.strip()
|
90 |
chat_history.append({"role": "user", "content": message})
|
91 |
-
formatted_history = chat_history
|
92 |
|
93 |
try:
|
94 |
response_generator = agent.run_gradio_chat(
|
@@ -101,20 +88,18 @@ def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_ag
|
|
101 |
conversation=conversation,
|
102 |
max_round=max_round,
|
103 |
seed=42,
|
104 |
-
call_agent_level=
|
105 |
sub_agent_task=None
|
106 |
)
|
107 |
|
108 |
collected = ""
|
109 |
for chunk in response_generator:
|
110 |
-
if isinstance(chunk,
|
111 |
-
for msg in chunk:
|
112 |
-
if isinstance(msg, dict) and "content" in msg:
|
113 |
-
collected += msg["content"]
|
114 |
-
elif isinstance(chunk, dict) and "content" in chunk:
|
115 |
collected += chunk["content"]
|
116 |
elif isinstance(chunk, str):
|
117 |
collected += chunk
|
|
|
|
|
118 |
|
119 |
chat_history.append({"role": "assistant", "content": collected or "⚠️ No content returned."})
|
120 |
|
@@ -145,7 +130,8 @@ def main():
|
|
145 |
global agent
|
146 |
agent = create_agent()
|
147 |
demo = create_demo(agent)
|
148 |
-
|
|
|
149 |
|
150 |
if __name__ == "__main__":
|
151 |
main()
|
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
|
|
33 |
"tool_files": {
|
34 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
35 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
|
|
39 |
}
|
40 |
}
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def prepare_tool_files():
|
43 |
os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
|
44 |
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
|
|
|
63 |
seed=42,
|
64 |
additional_default_tools=["DirectResponse", "RequireClarification"]
|
65 |
)
|
|
|
|
|
66 |
agent.init_model()
|
67 |
return agent
|
68 |
except Exception as e:
|
|
|
75 |
|
76 |
message = msg.strip()
|
77 |
chat_history.append({"role": "user", "content": message})
|
78 |
+
formatted_history = [(m["role"], m["content"]) for m in chat_history if "role" in m and "content" in m]
|
79 |
|
80 |
try:
|
81 |
response_generator = agent.run_gradio_chat(
|
|
|
88 |
conversation=conversation,
|
89 |
max_round=max_round,
|
90 |
seed=42,
|
91 |
+
call_agent_level=None,
|
92 |
sub_agent_task=None
|
93 |
)
|
94 |
|
95 |
collected = ""
|
96 |
for chunk in response_generator:
|
97 |
+
if isinstance(chunk, dict) and "content" in chunk:
|
|
|
|
|
|
|
|
|
98 |
collected += chunk["content"]
|
99 |
elif isinstance(chunk, str):
|
100 |
collected += chunk
|
101 |
+
elif chunk is not None:
|
102 |
+
collected += str(chunk)
|
103 |
|
104 |
chat_history.append({"role": "assistant", "content": collected or "⚠️ No content returned."})
|
105 |
|
|
|
130 |
global agent
|
131 |
agent = create_agent()
|
132 |
demo = create_demo(agent)
|
133 |
+
print("Exiting after embedding generation. Please restart the Space manually.")
|
134 |
+
exit()
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
main()
|