Ali2206 commited on
Commit
a52dfd6
·
verified ·
1 Parent(s): e0a0615

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -30,7 +30,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
33
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
34
  "tool_files": {
35
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
36
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
@@ -53,22 +52,18 @@ def prepare_tool_files():
53
 
54
  def create_agent():
55
  prepare_tool_files()
56
- try:
57
- agent = TxAgent(
58
- CONFIG["model_name"],
59
- CONFIG["rag_model_name"],
60
- tool_files_dict=CONFIG["tool_files"],
61
- force_finish=True,
62
- enable_checker=True,
63
- step_rag_num=10,
64
- seed=42,
65
- additional_default_tools=["DirectResponse", "RequireClarification"]
66
- )
67
- agent.init_model()
68
- return agent
69
- except Exception as e:
70
- logger.error(f"Agent initialization failed: {e}")
71
- raise
72
 
73
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
74
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
@@ -131,7 +126,7 @@ def main():
131
  global agent
132
  agent = create_agent()
133
  demo = create_demo(agent)
134
- demo.queue(concurrency_count=1, max_size=20).launch(share=True)
135
 
136
  if __name__ == "__main__":
137
- main()
 
30
  CONFIG = {
31
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
32
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
33
  "tool_files": {
34
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
35
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
 
52
 
53
  def create_agent():
54
  prepare_tool_files()
55
+ agent = TxAgent(
56
+ CONFIG["model_name"],
57
+ CONFIG["rag_model_name"],
58
+ tool_files_dict=CONFIG["tool_files"],
59
+ force_finish=True,
60
+ enable_checker=True,
61
+ step_rag_num=10,
62
+ seed=42,
63
+ additional_default_tools=["DirectResponse", "RequireClarification"]
64
+ )
65
+ agent.init_model()
66
+ return agent
 
 
 
 
67
 
68
  def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
69
  if not isinstance(msg, str) or len(msg.strip()) <= 10:
 
126
  global agent
127
  agent = create_agent()
128
  demo = create_demo(agent)
129
+ demo.queue().launch(share=True)
130
 
131
  if __name__ == "__main__":
132
+ main()