Update app.py
Browse files
app.py
CHANGED
@@ -30,7 +30,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
33 |
-
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
|
34 |
"tool_files": {
|
35 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
36 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
@@ -53,22 +52,18 @@ def prepare_tool_files():
|
|
53 |
|
54 |
def create_agent():
|
55 |
prepare_tool_files()
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
return agent
|
69 |
-
except Exception as e:
|
70 |
-
logger.error(f"Agent initialization failed: {e}")
|
71 |
-
raise
|
72 |
|
73 |
def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
|
74 |
if not isinstance(msg, str) or len(msg.strip()) <= 10:
|
@@ -131,7 +126,7 @@ def main():
|
|
131 |
global agent
|
132 |
agent = create_agent()
|
133 |
demo = create_demo(agent)
|
134 |
-
demo.queue(
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
-
main()
|
|
|
30 |
CONFIG = {
|
31 |
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
32 |
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
|
|
33 |
"tool_files": {
|
34 |
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
35 |
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
|
|
52 |
|
53 |
def create_agent():
|
54 |
prepare_tool_files()
|
55 |
+
agent = TxAgent(
|
56 |
+
CONFIG["model_name"],
|
57 |
+
CONFIG["rag_model_name"],
|
58 |
+
tool_files_dict=CONFIG["tool_files"],
|
59 |
+
force_finish=True,
|
60 |
+
enable_checker=True,
|
61 |
+
step_rag_num=10,
|
62 |
+
seed=42,
|
63 |
+
additional_default_tools=["DirectResponse", "RequireClarification"]
|
64 |
+
)
|
65 |
+
agent.init_model()
|
66 |
+
return agent
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
|
69 |
if not isinstance(msg, str) or len(msg.strip()) <= 10:
|
|
|
126 |
global agent
|
127 |
agent = create_agent()
|
128 |
demo = create_demo(agent)
|
129 |
+
demo.queue().launch(share=True)
|
130 |
|
131 |
if __name__ == "__main__":
|
132 |
+
main()
|