|
import html
import importlib
import inspect
import json
import os
import sys
from multiprocessing import freeze_support

import gradio as gr
|
|
|
|
|
|
# Resolve this file's directory once, as an absolute path, so the sys.path
# entry below and the data-file paths built later do not depend on the cwd.
current_dir = os.path.abspath(os.path.dirname(__file__))

# Set threading-related environment variables BEFORE importing TxAgent.
# MKL reads MKL_THREADING_LAYER when the library is first loaded, so if
# TxAgent transitively imports an MKL-backed library (presumably torch —
# not visible from this file) setting it afterwards can be too late.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Put the bundled ``src/`` directory first so its ``txagent`` package wins
# over any installed copy.
sys.path.insert(0, os.path.join(current_dir, "src"))

import txagent.txagent  # noqa: E402  (must follow the sys.path tweak above)

# Re-execute the module. Presumably this lets a long-lived interpreter (e.g.
# a hot-reload dev server that keeps sys.modules cached) pick up source edits;
# in a fresh process it simply runs the module a second time.
importlib.reload(txagent.txagent)

from txagent.txagent import TxAgent  # noqa: E402

# Startup diagnostics: show which TxAgent implementation was actually loaded
# and whether it exposes the generator the UI depends on.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
|
|
|
|
|
|
# Model identifiers handed to TxAgent: the main reasoning model and the
# tool-retrieval (RAG) model.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"

# Extra tool definition files, passed to TxAgent as ``tool_files_dict``.
new_tool_files = dict(
    new_tool=os.path.join(current_dir, "data", "new_tool.json"),
)

# Prompts offered through gr.Examples. Each row is a one-element list because
# the examples populate a single input component (the message textbox).
question_examples = [
    [question]
    for question in (
        "Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?",
        "What treatment options exist for HER2+ breast cancer resistant to trastuzumab?",
    )
]
|
|
|
|
|
|
def format_collapsible(content, title="Answer"):
    """Render *content* as a collapsible HTML ``<details>`` block.

    Dicts and lists are pretty-printed as JSON, falling back to ``str()``
    when they cannot be serialised. Any other value is converted with
    ``str()``. The resulting text is HTML-escaped, so angle brackets and
    ampersands in model or tool output are shown literally inside the
    ``<pre>`` instead of being interpreted as markup.

    Args:
        content: Value to display.
        title: Text of the ``<summary>`` line. Defaults to ``"Answer"``.

    Returns:
        An HTML string beginning with ``<details`` and ending with
        ``</details>``.
    """
    if isinstance(content, (dict, list)):
        try:
            # ensure_ascii=False keeps non-ASCII text readable instead of \uXXXX.
            formatted = json.dumps(content, indent=2, ensure_ascii=False)
        except (TypeError, ValueError, RecursionError):
            # TypeError: non-serialisable value; ValueError: circular
            # reference; RecursionError: pathologically deep nesting.
            formatted = str(content)
    else:
        formatted = str(content)

    return (
        "<details style='border: 1px solid #ccc; padding: 8px; margin-top: 8px;'>"
        f"<summary style='font-weight: bold;'>{html.escape(title, quote=False)}</summary>"
        f"<pre style='white-space: pre-wrap;'>{html.escape(formatted, quote=False)}</pre>"
        "</details>"
    )
|
|
|
|
|
|
def _to_chat_message(message):
    """Convert one agent message into a Gradio ``type="messages"`` dict.

    Accepts a dict with ``role``/``content`` keys or any object exposing
    ``role``/``content`` attributes. Missing fields default to
    ``"assistant"`` and ``""`` respectively. Assistant content is wrapped in
    a collapsible HTML block via ``format_collapsible``.
    """
    if isinstance(message, dict):
        role = message.get("role", "assistant")
        content = message.get("content", "")
    else:
        role = getattr(message, "role", "assistant")
        content = getattr(message, "content", "")

    # NOTE(review): if run_gradio_chat yields the full history, earlier
    # assistant turns already carry format_collapsible output (which starts
    # with "<details"); skip those so answers are not nested once per turn.
    already_wrapped = isinstance(content, str) and content.startswith("<details")
    if role == "assistant" and not already_wrapped:
        content = format_collapsible(content)

    return {"role": role, "content": content}


def create_ui(agent):
    """Build the Gradio Blocks app that chats with *agent*.

    Args:
        agent: An initialised TxAgent, or any object exposing a compatible
            ``run_gradio_chat`` generator method.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")

        # Generation controls.
        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # Passed to the agent as ``conversation``. This UI never writes it back
        # (it is not an output of any event).
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Stream the agent's reply into the chatbot, one update at a time."""
            # Do not spend a full agent run on an empty or whitespace-only prompt.
            if not message or not message.strip():
                yield history
                return

            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                # Slider values may arrive as floats, but these are counts.
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation,
                max_round=int(max_round),
            )

            for update in generator:
                yield [_to_chat_message(m) for m in update]

        # Order must match handle_chat's positional parameters.
        inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")

    return demo
|
|
|
|
|
|
if __name__ == "__main__":
    # Needed only when running as a frozen Windows executable; otherwise a no-op.
    freeze_support()

    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[],
        )

        # Fail fast: verify the UI's required method exists before paying
        # for model initialisation.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("TxAgent missing run_gradio_chat")

        agent.init_model()

        demo = create_ui(agent)
        # NOTE: 0.0.0.0 listens on every network interface, not just localhost.
        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

    except Exception as e:
        # Report on stderr, then re-raise so the traceback and exit status are kept.
        print(f"❌ App failed to start: {e}", file=sys.stderr)
        raise
|
|
|