File size: 4,552 Bytes
1155704 8c814cd 47f0902 9aeb1dd 3f69fbe 9c9d2f8 16f16a5 537f975 16f16a5 537f975 6604d0d 1729ddc 4b98818 3f76413 537f975 3f76413 4b98818 537f975 0cec600 1155704 0cec600 537f975 47f0902 0cec600 47f0902 0cec600 537f975 47f0902 91d1d93 47f0902 0cec600 537f975 3f69fbe 91d1d93 537f975 6309d92 537f975 6309d92 91d1d93 537f975 84b4115 537f975 84b4115 537f975 bb37713 537f975 84b4115 91d1d93 6309d92 84b4115 6309d92 84b4115 91d1d93 84b4115 537f975 6309d92 537f975 84b4115 6309d92 537f975 6309d92 3f69fbe 537f975 3f69fbe 9c9d2f8 3f69fbe 9c9d2f8 91d1d93 9c9d2f8 3f69fbe 9c9d2f8 3f69fbe 537f975 9c9d2f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import sys
import random
import importlib
import inspect
from multiprocessing import freeze_support

# === Environment
# Set these before gradio / txagent (and the numeric libraries they pull in)
# are imported: MKL selects its threading layer from MKL_THREADING_LAYER when
# it initialises, so assigning it after torch/numpy have loaded MKL is too late.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gradio as gr

# === Put the bundled src/ directory first on sys.path so the local
# txagent package is found ahead of any installed copy.
current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, os.path.join(current_dir, "src"))

# === Import txagent.txagent, then reload it so that a copy already imported
# earlier in this interpreter is re-executed from disk.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# === Debug print: show which file TxAgent came from and whether it exposes
# the chat entry point the UI relies on.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# === Model config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definitions loaded from data/new_tool.json next to this script.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}

# === Example prompts (each inner list is one gr.Examples row for the textbox)
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]
# === UI creation
def _to_message_dict(message):
    """Convert one chat message to the dict form used by ``gr.Chatbot(type="messages")``.

    Args:
        message: Either a dict with ``"role"`` and ``"content"`` keys, or an
            object exposing ``.role`` and ``.content`` attributes (for example
            a ``gr.ChatMessage``).

    Returns:
        A new dict with ``"role"`` and ``"content"``, plus ``"metadata"`` when
        the source message carries a non-empty one, so titled messages keep
        their special rendering in the Chatbot instead of being flattened.

    Raises:
        KeyError: a dict message lacks ``"role"`` or ``"content"``.
        AttributeError: a non-dict message lacks ``.role`` or ``.content``.
    """
    if isinstance(message, dict):
        role = message["role"]
        content = message["content"]
        metadata = message.get("metadata")
    else:
        role = message.role
        content = message.content
        metadata = getattr(message, "metadata", None)
    formatted = {"role": role, "content": content}
    if metadata:
        formatted["metadata"] = metadata
    return formatted


def create_ui(agent):
    """Build the TxAgent chat interface.

    Args:
        agent: An initialised TxAgent; the UI only calls its
            ``run_gradio_chat`` generator.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")
        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        # Count-valued settings use step=1: Gradio otherwise derives a step
        # from the range, which can be fractional (e.g. for 1-50 rounds).
        max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, step=1, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, step=1, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        conversation_state = gr.State([])
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        # === Core handler (streaming generator)
        def handle_chat(message, history, temp, new_tokens, total_tokens,
                        use_multi_agent, conversation, rounds):
            """Stream the agent's reply as successive full chat histories.

            Each yielded value is a list of ``{"role", "content"[, "metadata"]}``
            dicts, which is what the messages-type Chatbot expects.
            """
            # Blank submission: leave the chat as-is instead of querying the model.
            if not message or not message.strip():
                yield history
                return
            # Gradio sliders deliver floats; token and round limits are counts.
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temp,
                max_new_tokens=int(new_tokens),
                max_token=int(total_tokens),
                call_agent=use_multi_agent,
                conversation=conversation,
                max_round=int(rounds),
            )
            for update in generator:
                yield [_to_message_dict(m) for m in update]

        # === Trigger handlers: the Send button and Enter in the textbox share
        # one input list so the two paths cannot drift apart.
        chat_inputs = [
            message_input, chatbot, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation_state, max_round,
        ]
        send_button.click(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
    return demo
# === Startup
if __name__ == "__main__":
    # Required first thing in __main__ for multiprocessing in frozen
    # (e.g. PyInstaller) Windows executables; a no-op elsewhere.
    freeze_support()
    try:
        # Verify the chat entry point before building the agent and loading
        # model weights, so a wrong/stale TxAgent import fails immediately
        # rather than after the expensive init_model() call.
        if not hasattr(TxAgent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=["DirectResponse", "RequireClarification"]
        )
        agent.init_model()
        demo = create_ui(agent)
        demo.launch(show_error=True)
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and non-zero exit
        # status are preserved.
        print(f"❌ Application failed to start: {e}", file=sys.stderr)
        raise
|