import os
import sys
import random
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
# === Fix path to include src/txagent
# Prepend the in-repo "src" directory so `import txagent` resolves to the
# local copy rather than any installed package of the same name.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
# === Import and reload to ensure correct file
# The explicit reload guards against a stale module object if an installed
# `txagent` was imported earlier in this process.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent
# === Debug print
# Startup diagnostics: confirm which file TxAgent came from and that the
# streaming chat entry point exists before building any UI.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# === Environment
current_dir = os.path.abspath(os.path.dirname(__file__))
# NOTE(review): presumably set to avoid MKL/OpenMP threading-layer clashes
# under multiprocessing — confirm this is still needed on the target host.
os.environ["MKL_THREADING_LAYER"] = "GNU"
# Silences the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# === Model config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definitions passed to TxAgent, keyed by tool-set name.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}
# === Example prompts
# Seed questions shown in the Gradio Examples widget (each entry is a
# one-element row matching the single textbox input).
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]
# === Helper: Add collapsible formatting
def format_collapsible_response(content):
    """Wrap *content* in a collapsible HTML ``<details>`` element.

    The chat UI renders assistant turns as HTML, so the answer body is
    hidden behind a bold "Answer" summary toggle.
    """
    details_style = "border: 1px solid #ccc; padding: 8px; margin-top: 8px;"
    summary_style = "font-weight: bold;"
    return (
        f"<details style='{details_style}'>"
        f"<summary style='{summary_style}'>Answer</summary>"
        f"<div style='margin-top: 8px;'>{content}</div></details>"
    )
# === UI creation
def create_ui(agent):
    """Build the Gradio Blocks UI for TxAgent.

    Parameters
    ----------
    agent :
        An initialized TxAgent exposing ``run_gradio_chat`` as a streaming
        generator that yields lists of chat messages.

    Returns
    -------
    gr.Blocks
        The assembled demo, ready for ``.launch()``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")
        # --- Generation controls
        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # --- Chat widgets and per-session conversation state
        conversation_state = gr.State([])
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        # === Core handler (streaming generator)
        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Stream agent updates, wrapping assistant turns as collapsible HTML."""
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                # NOTE(review): TxAgent takes `max_token` (singular) here —
                # confirm against run_gradio_chat's signature.
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round
            )
            for update in generator:
                formatted_messages = []
                for m in update:
                    # Updates may be dicts or message objects; normalize to
                    # the dict shape that `type="messages"` Chatbot expects.
                    role = m["role"] if isinstance(m, dict) else getattr(m, "role", "assistant")
                    content = m["content"] if isinstance(m, dict) else getattr(m, "content", "")
                    if role == "assistant":
                        content = format_collapsible_response(content)
                    formatted_messages.append({"role": role, "content": content})
                yield formatted_messages

        # === Trigger handlers
        # Fix: the click/submit listeners previously duplicated identical
        # fn/inputs/outputs argument lists; wire them from one shared spec so
        # the two triggers cannot drift apart.
        chat_inputs = [message_input, chatbot, temperature, max_new_tokens,
                       max_tokens, multi_agent, conversation_state, max_round]
        for trigger in (send_button.click, message_input.submit):
            trigger(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)
        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
    return demo
# === Startup
if __name__ == "__main__":
    # Required for frozen Windows executables that use multiprocessing;
    # a no-op elsewhere.
    freeze_support()
    try:
        # NOTE(review): constructor kwargs mirror TxAgent's expected config —
        # confirm against txagent.txagent.TxAgent.__init__.
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]  # Removed DirectResponse/RequireClarification to avoid errors
        )
        agent.init_model()
        # Fail fast with a clear message if a stale/incompatible TxAgent
        # class was loaded (mirrors the startup debug check above).
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")
        demo = create_ui(agent)
        demo.launch(show_error=True)
    except Exception as e:
        # Surface the failure on the console, then re-raise so the process
        # exits non-zero instead of silently dying with a half-built UI.
        print(f"❌ Application failed to start: {e}")
        raise