Ali2206 commited on
Commit
9fcd791
·
verified ·
1 Parent(s): c10dd55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import gradio as gr
2
  import logging
3
  import os
 
4
 
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
7
 
8
  tx_app = None
9
- TOOL_CACHE_PATH = "/home/user/.cache/tool_embeddings_done" # flag file for skip
10
 
11
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
12
  global tx_app
@@ -44,7 +46,7 @@ def respond(message, chat_history, temperature, max_new_tokens, max_tokens, mult
44
  logger.error(f"Respond error: {e}")
45
  yield chat_history + [("", f"⚠️ Error: {e}")]
46
 
47
- # === Define Gradio interface ===
48
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
49
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
50
 
@@ -74,15 +76,15 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
74
  chatbot
75
  )
76
 
77
- # === Safe model init block for vLLM + Hugging Face ===
78
  if __name__ == "__main__":
79
- import multiprocessing
80
  multiprocessing.set_start_method("spawn", force=True)
81
 
82
  from txagent import TxAgent
83
  from importlib.resources import files
 
84
 
85
- logger.info("🔥 Initializing TxAgent inside __main__...")
86
 
87
  tool_files = {
88
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
@@ -91,7 +93,6 @@ if __name__ == "__main__":
91
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
92
  }
93
 
94
- # Initialize agent
95
  tx_app = TxAgent(
96
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
97
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -109,18 +110,14 @@ if __name__ == "__main__":
109
  seed=42,
110
  enable_checker=True,
111
  enable_chat=False,
112
- additional_default_tools=["DirectResponse", "RequireClarification"]
 
113
  )
114
 
115
- # ✅ Only do tool embedding the first time
116
  if not os.path.exists(TOOL_CACHE_PATH):
117
- logger.info("🔧 First run: running full model + embedding")
118
- tx_app.init_model() # runs full setup
119
  os.makedirs(os.path.dirname(TOOL_CACHE_PATH), exist_ok=True)
120
  with open(TOOL_CACHE_PATH, "w") as f:
121
  f.write("done")
122
  else:
123
- logger.info("⚡️ Skipping tool embedding (cached)...")
124
- tx_app.init_model(skip_tool_embedding=True) # assumes this param is supported
125
-
126
- logger.info("✅ TxAgent is ready!")
 
1
  import gradio as gr
2
  import logging
3
  import os
4
+ import multiprocessing
5
 
6
+ # Configure logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  tx_app = None
11
+ TOOL_CACHE_PATH = "/home/user/.cache/tool_embeddings_done"
12
 
13
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
14
  global tx_app
 
46
  logger.error(f"Respond error: {e}")
47
  yield chat_history + [("", f"⚠️ Error: {e}")]
48
 
49
+ # === Gradio UI ===
50
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
51
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
52
 
 
76
  chatbot
77
  )
78
 
79
+ # === Safe model init ===
80
  if __name__ == "__main__":
 
81
  multiprocessing.set_start_method("spawn", force=True)
82
 
83
  from txagent import TxAgent
84
  from importlib.resources import files
85
+ from patched_tooluniverse import PatchedToolUniverse
86
 
87
+ logger.info("🔥 Initializing patched TxAgent with tool embedding skip")
88
 
89
  tool_files = {
90
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
 
93
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
94
  }
95
 
 
96
  tx_app = TxAgent(
97
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
98
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
110
  seed=42,
111
  enable_checker=True,
112
  enable_chat=False,
113
+ additional_default_tools=["DirectResponse", "RequireClarification"],
114
+ tooluniverse_class=PatchedToolUniverse # ✅ Custom patch!
115
  )
116
 
 
117
  if not os.path.exists(TOOL_CACHE_PATH):
118
+ tx_app.init_model()
 
119
  os.makedirs(os.path.dirname(TOOL_CACHE_PATH), exist_ok=True)
120
  with open(TOOL_CACHE_PATH, "w") as f:
121
  f.write("done")
122
  else:
123
+ tx_app.init_model(skip_tool_embedding=True) # only if this param is supported