# NOTE(review): removed non-Python residue that preceded this file
# (a "File size" banner, a row of git-blame commit hashes, and a run of
# line numbers left over from a hosting-page extraction). They were not
# part of the source and made the module unparseable.
import os
import sys
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
import json
import logging
# === Fix path to include src/txagent
# Prepend ./src so the local txagent package shadows any installed copy.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
# === Import and reload to ensure correct file
# reload() guards against a stale txagent module cached by an earlier import
# (e.g. when this script is re-run inside a long-lived interpreter).
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent
# === Debug info
# Print provenance so a mismatched install is obvious in the startup log.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# === Logging
logging.basicConfig(level=logging.INFO)
# === Environment
current_dir = os.path.abspath(os.path.dirname(__file__))
# MKL/GNU threading layers conflict under multiprocessing; force GNU.
os.environ["MKL_THREADING_LAYER"] = "GNU"
# Silence the HuggingFace tokenizers fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# === Model config
# HuggingFace model IDs for the main LLM and the tool-retrieval (RAG) model.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definitions loaded from ./data/new_tool.json at agent startup.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}
# === Example prompts
# Seed questions shown as clickable examples under the chat box.
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]
# === Extract tool name and format output
def extract_tool_name_and_clean_content(msg):
    """Pull the tool name out of a chat message and normalize its content.

    Accepts either a dict-style message or an object exposing ``content``
    and ``tool_calls`` attributes.

    Returns:
        tuple[str, str]: ``("Tool: <name>", content)`` where *content* is
        always a string — dict/list payloads are pretty-printed as JSON and
        missing content becomes ``""``.
    """
    tool_name = "Tool Result"
    content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", "")
    tool_calls = msg.get("tool_calls") if isinstance(msg, dict) else getattr(msg, "tool_calls", None)
    if tool_calls:
        try:
            # tool_calls may arrive serialized as a JSON string; decode first.
            if isinstance(tool_calls, str):
                tool_calls = json.loads(tool_calls)
            tool_name = tool_calls[0].get("name", "Tool Result")
            logging.info(f"[extract_tool_name] Parsed tool name: {tool_name}")
        except Exception as e:
            # Best-effort: keep the generic title on any parse failure.
            logging.warning(f"[extract_tool_name] Failed parsing tool_calls: {e}")
    if isinstance(content, (dict, list)):
        content = json.dumps(content, indent=2)
    elif content is None:
        # Fix: assistant messages that carry only tool_calls have content=None;
        # previously this leaked the literal string "None" into the HTML box.
        content = ""
    return f"Tool: {tool_name}", content
# === Format answer in collapsible box
def format_collapsible(content, title="Answer"):
    """Wrap *content* in a styled HTML ``<details>`` disclosure box."""
    summary = (
        "<summary style='font-size: 16px; font-weight: bold; "
        f"color: #3B82F6;'>{title}</summary>"
    )
    body = (
        "<div style='margin-top: 8px; font-size: 15px; line-height: 1.6; "
        f"white-space: pre-wrap;'>{content}</div>"
    )
    return (
        "<details style='border: 1px solid #ccc; border-radius: 8px; "
        f"padding: 10px; margin-top: 10px;'>{summary}{body}</details>"
    )
# === Build UI
def create_ui(agent):
    """Assemble the Gradio Blocks chat interface around *agent*.

    *agent* must expose ``run_gradio_chat`` yielding lists of chat messages.
    Returns the ``gr.Blocks`` demo; the caller is responsible for launching it.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>💊 TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by tool-augmented reasoning.")
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        conversation_state = gr.State([])

        def _render(msg):
            # Normalize one streamed message (dict or object) to a
            # {"role", "content"} dict; assistant turns get the collapsible box.
            if isinstance(msg, dict):
                role = msg.get("role")
            else:
                role = getattr(msg, "role", "assistant")
            if role == "assistant":
                title, cleaned = extract_tool_name_and_clean_content(msg)
                text = format_collapsible(cleaned, title)
            elif isinstance(msg, dict):
                text = msg.get("content")
            else:
                text = getattr(msg, "content", "")
            return {"role": role, "content": text}

        def handle_chat(message, history, conversation):
            # Stream partial transcripts from the agent and re-render each one.
            stream = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=conversation,
                max_round=30
            )
            for update in stream:
                yield [_render(m) for m in update]

        chat_inputs = [message_input, chatbot, conversation_state]
        send_button.click(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)
        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("<small style='color: gray;'>DISCLAIMER: This demo is for research purposes only and does not provide medical advice.</small>")
    return demo
# === Main
if __name__ == "__main__":
    # Required on Windows/frozen builds before using multiprocessing.
    freeze_support()
    try:
        # Build the agent with the module-level model/tool configuration.
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]
        )
        # Loads model weights; must happen before the UI starts serving.
        agent.init_model()
        # Fail fast if a stale/incompatible TxAgent was imported.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")
        demo = create_ui(agent)
        # queue() enables streaming generator responses; bind on all interfaces.
        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    except Exception as e:
        # Log the failure for the console, then re-raise so the process
        # exits non-zero instead of silently swallowing startup errors.
        print(f"❌ App failed to start: {e}")
        raise