File size: 4,193 Bytes
1155704 8c814cd 47f0902 9aeb1dd 3f69fbe 9c9d2f8 16f16a5 91d1d93 16f16a5 91d1d93 6604d0d 1729ddc 4b98818 91d1d93 3f76413 91d1d93 3f76413 4b98818 84b4115 0cec600 1155704 0cec600 91d1d93 0cec600 91d1d93 0cec600 91d1d93 47f0902 0cec600 47f0902 0cec600 47f0902 91d1d93 47f0902 0cec600 91d1d93 3f69fbe 91d1d93 6309d92 91d1d93 84b4115 91d1d93 9c9d2f8 bb37713 84b4115 91d1d93 6309d92 84b4115 6309d92 84b4115 91d1d93 84b4115 6309d92 84b4115 6309d92 84b4115 6309d92 91d1d93 6309d92 3f69fbe 9c9d2f8 3f69fbe 9c9d2f8 91d1d93 9c9d2f8 3f69fbe 9c9d2f8 3f69fbe 91d1d93 9c9d2f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import os

# === Env vars
# Native libraries (e.g. MKL, Hugging Face tokenizers) read these settings
# from the environment. Set them before importing gradio/txagent, which may
# load such libraries, so the values are seen when those libraries initialize.
# NOTE(review): assumes nothing imported above this point has loaded MKL.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import random
import importlib
import inspect
from multiprocessing import freeze_support

import gradio as gr
from gradio import ChatMessage

# === Path fix
# Directory containing this script. Its "src" subdirectory holds the local
# txagent package and is put first on sys.path so it wins over any installed copy.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "src"))

# === Reload to avoid stale module
# Re-execute txagent.txagent in case an outdated copy is cached in
# sys.modules (e.g. when this file is re-run inside a long-lived interpreter).
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# === Debug
# Show which source file TxAgent came from and whether it exposes the chat
# entry point used by the UI below.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# === UI text
# Header HTML, intro blurb and footer disclaimer. Each is rendered via
# gr.Markdown in create_ui.
DESCRIPTION = (
    "\n"
    '<h1 style="text-align: center;">TxAgent: AI for Therapeutic Reasoning</h1>'
    "\n"
)
INTRO = (
    "Ask biomedical or therapeutic questions. "
    "Results are powered by tools and reasoning."
)
LICENSE = "DISCLAIMER: " "THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE."
# === Model & tool config
# Model identifiers handed to TxAgent as model_name / rag_model_name.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"

# Extra tool definition files, keyed by tool-set name; passed to TxAgent as
# tool_files_dict. The path is resolved relative to this script's directory.
new_tool_files = dict(
    new_tool=os.path.join(current_dir, "data", "new_tool.json"),
)

# Sample prompts for gr.Examples. Each example row holds one value per input
# component, and the examples feed a single input (the message textbox).
question_examples = [
    [question]
    for question in (
        "Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?",
        "What treatment options exist for HER2+ breast cancer resistant to trastuzumab?",
    )
]
# === Gradio UI
def create_ui(agent):
    """Build the TxAgent Gradio chat interface.

    Args:
        agent: An initialized TxAgent instance exposing ``run_gradio_chat``.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        # Generation settings, forwarded to agent.run_gradio_chat on every turn.
        temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)

        # Per-session value passed to the agent as ``conversation``. It is kept
        # separate from the message history shown in the chatbot.
        conversation_state = gr.State([])
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)
        send_btn = gr.Button("Send", variant="primary")

        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Stream the agent's reply into the chatbot.

            Gradio streams outputs only when the event handler is itself a
            generator function. Returning the agent's generator from a plain
            function would hand the generator object to the Chatbot as its
            value, so this handler re-yields each update instead.
            """
            result = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
            )
            if inspect.isgenerator(result):
                # Each yielded item is expected to be a messages-format history.
                yield from result
            else:
                # Non-streaming implementation: emit its single final value.
                yield result

        # The Send button and pressing Enter in the textbox trigger the same
        # handler with the same inputs, in the order handle_chat expects.
        chat_inputs = [
            message_input,
            chatbot,
            temperature,
            max_new_tokens,
            max_tokens,
            multi_agent,
            conversation_state,
            max_round,
        ]
        send_btn.click(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)

        gr.Examples(
            examples=question_examples,
            inputs=message_input,
        )
        gr.Markdown(LICENSE)
    return demo
# === App startup
if __name__ == "__main__":
    # Required for multiprocessing in frozen Windows executables; has no
    # effect otherwise.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=["DirectResponse", "RequireClarification"],
        )
        # Verify the UI entry point exists before init_model(), which presumably
        # loads the model weights. A mismatched txagent build then fails fast
        # instead of after a long initialization.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")
        agent.init_model()
        demo = create_ui(agent)
        demo.launch(show_error=True)
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and the non-zero
        # exit status are preserved.
        print(f"🚨 Startup error: {e}", file=sys.stderr)
        raise
|