# File size: 5,404 Bytes
# 1bb8be7 d2631b6 896d369 d2631b6 ccb72b9 d2631b6 ccb72b9 d2631b6 ccb72b9 1bb8be7 d2631b6 ccb72b9 d2631b6 ccb72b9 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 896d369 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 896d369 1bb8be7 d2631b6 896d369 1bb8be7 896d369 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 896d369 1bb8be7 d2631b6 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import os
import sys
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
import json
# Fix path to include src
# Put the "src" directory that sits next to this file at the front of
# sys.path, so the `txagent` package imported below is looked up there first.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
# Reload TxAgent from txagent.py
# The module is imported and then immediately reloaded, which re-executes
# txagent/txagent.py before the TxAgent class is bound in this module.
# NOTE(review): presumably a guard against a stale cached module — confirm it
# is still needed.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent
# Debug info
# Print which file the TxAgent class came from and whether it exposes
# run_gradio_chat, the method the chat handler in create_ui() calls.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# Env vars
# Absolute directory of this file; used below to locate data/new_tool.json.
current_dir = os.path.abspath(os.path.dirname(__file__))
# NOTE(review): both variables are set *after* txagent has been imported
# above; confirm that nothing imported there reads them at import time.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Model config
# These identifiers are passed to TxAgent(...) in the __main__ block as
# model_name / rag_model_name.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Maps a tool-set name to its JSON file; passed to TxAgent as tool_files_dict.
new_tool_files = {
"new_tool": os.path.join(current_dir, "data", "new_tool.json")
}
# Sample questions
# Shown through gr.Examples in create_ui(); each example is a one-element list
# because the examples feed a single input component (the message textbox).
question_examples = [
["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]
# Helper: format assistant responses in collapsible panels
def format_collapsible(content, tool_name=None):
    """Wrap ``content`` in a collapsible HTML ``<details>`` panel.

    Args:
        content: The message payload. A string holding a JSON object or array
            is parsed first; any other string is shown unchanged. A dict with
            a list under ``"results"`` is rendered as a bulleted, per-result
            listing; any other dict or list is pretty-printed as JSON; every
            other value is converted with ``str()``.
        tool_name: Optional panel title. Falls back to ``"Answer"`` when it is
            ``None`` or empty.

    Returns:
        An HTML string consisting of a single ``<details>`` element.
    """
    import html  # local import so the block does not rely on file-level imports

    # Only adopt the parsed value when it is structured data. Parsing scalars
    # would silently rewrite plain answers such as "true" -> "True" or
    # "null" -> "None".
    if isinstance(content, str):
        try:
            parsed = json.loads(content)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = None
        if isinstance(parsed, (dict, list)):
            content = parsed

    if isinstance(content, dict) and isinstance(content.get("results"), list):
        lines = []
        for i, result in enumerate(content["results"], 1):
            lines.append(f"\n🔹 **Result {i}:**")
            if not isinstance(result, dict):
                # Entries are not guaranteed to be mappings; show them as-is
                # instead of failing on .items().
                lines.append(f"- {result}")
                continue
            for key, value in result.items():
                key_str = (
                    str(key).replace("openfda.", "").replace("_", " ").capitalize()
                )
                if isinstance(value, list):
                    # List items may be numbers or nested objects, not only
                    # strings, so stringify each one before joining.
                    val_str = ", ".join(str(item) for item in value)
                else:
                    val_str = str(value)
                lines.append(f"- **{key_str}**: {val_str}")
        formatted = "\n".join(lines).strip()
    elif isinstance(content, (dict, list)):
        # Display only: keep non-ASCII text readable and fall back to str()
        # for values json cannot serialise rather than raising.
        formatted = json.dumps(content, indent=2, ensure_ascii=False, default=str)
    else:
        formatted = str(content)

    # The title is a plain-text label, so escape it before it lands in markup.
    # NOTE(review): the panel body is inserted unescaped so that any markup in
    # the answer keeps rendering as before; this relies on the consuming
    # component sanitising HTML — confirm.
    title = html.escape(str(tool_name or "Answer"))
    return (
        "<details style='border: 1px solid #aaa; border-radius: 8px; padding: 10px; margin: 12px 0; background-color: #f8f8f8;'>"
        f"<summary style='font-weight: bold; font-size: 16px; color: #333;'>{title}</summary>"
        f"<div style='white-space: pre-wrap; font-family: sans-serif; color: #222; padding-top: 6px;'>{formatted}</div>"
        "</details>"
    )
# === UI setup
def create_ui(agent):
    """Build the Gradio Blocks page that chats with ``agent``.

    The page holds a title, a chatbot pane, a textbox, a Send button, the
    sample questions and a disclaimer. Both the button click and the textbox
    submit event run the same streaming handler.

    Args:
        agent: Object exposing ``run_gradio_chat(...)``, which is iterated for
            successive lists of chat messages.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    page_css = "body { background-color: #f5f5f5; font-family: sans-serif; }"
    with gr.Blocks(css=page_css) as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("<p style='text-align: center;'>Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.</p>")

        conversation_state = gr.State([])
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        def handle_chat(message, history, conversation):
            """Stream agent updates, wrapping assistant messages in panels.

            NOTE(review): ``history`` is the chatbot's current value, which
            this handler itself fills with HTML-wrapped assistant messages —
            confirm run_gradio_chat tolerates that. ``conversation`` is never
            written back as an output, so the state only changes if the agent
            mutates the list in place.
            """
            stream = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=conversation,
                max_round=30,
            )
            for snapshot in stream:
                rendered = []
                for msg in snapshot:
                    # Messages may be plain dicts or objects with attributes.
                    if isinstance(msg, dict):
                        role = msg["role"]
                        content = msg["content"]
                        tool_name = msg.get("tool_name")
                    else:
                        role = getattr(msg, "role", "assistant")
                        content = getattr(msg, "content", "")
                        tool_name = getattr(msg, "tool_name", None)
                    # Only assistant output is wrapped; other roles pass through.
                    shown = format_collapsible(content, tool_name) if role == "assistant" else content
                    rendered.append({"role": role, "content": shown})
                yield rendered

        chat_inputs = [message_input, chatbot, conversation_state]
        # Same handler for the button click and the textbox submit event.
        for register in (send_button.click, message_input.submit):
            register(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("<p style='font-size: 12px; text-align: center; color: gray;'>This demo is for research purposes only and does not provide medical advice.</p>")
    return demo
# === Entry point
if __name__ == "__main__":
    # Script entry point: build the agent, load its models, then serve the UI.
    freeze_support()
    try:
        agent_options = {
            "model_name": model_name,
            "rag_model_name": rag_model_name,
            "tool_files_dict": new_tool_files,
            "force_finish": True,
            "enable_checker": True,
            "step_rag_num": 10,
            "seed": 100,
            "additional_default_tools": [],
        }
        agent = TxAgent(**agent_options)
        agent.init_model()

        # The chat handler depends on this method, so refuse to start without it.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("TxAgent missing run_gradio_chat")

        demo = create_ui(agent)
        queued_demo = demo.queue()
        queued_demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True,
        )
    except Exception as e:
        # Report the failure on stdout, then re-raise so the traceback and
        # non-zero exit status are preserved.
        print(f"\u274c App failed to start: {e}")
        raise