File size: 5,708 Bytes
1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 f9a6a36 67e48f6 f9a6a36 67e48f6 f9a6a36 4800765 f9a6a36 4800765 f9a6a36 eaac969 67e48f6 eaac969 ccb72b9 eaac969 ccb72b9 1bb8be7 eaac969 1bb8be7 eaac969 d2631b6 eaac969 896d369 1bb8be7 d2631b6 eaac969 1bb8be7 eaac969 d2631b6 1bb8be7 d2631b6 1bb8be7 d2631b6 1bb8be7 896d369 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 d2631b6 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 1bb8be7 eaac969 781bcb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
import sys

# === Environment
# Set these before importing the third-party stack below. Native libraries
# generally read such variables once, when they are first loaded (Intel MKL
# reads MKL_THREADING_LAYER at initialisation), so assigning them after
# `import gradio` / `import txagent...` (which presumably pull in numpy/torch)
# may be too late to take effect.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
import json

# Absolute directory of this script; anchors the src/ and data/ paths below
# so they do not depend on the current working directory.
current_dir = os.path.abspath(os.path.dirname(__file__))

# === Fix path to include src/txagent
sys.path.insert(0, os.path.join(current_dir, "src"))

# === Import and reload to ensure correct file
# NOTE(review): reload() right after the first import re-executes the module;
# kept from the original as a debugging aid.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# === Debug print
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# === Model config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Tool-set name -> JSON file path; passed to TxAgent as `tool_files_dict`.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json"),
}

# === Example prompts
# Each inner list is one example row for gr.Examples (one value per input).
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"],
]
# === Helper: extract tool name from content
def extract_tool_name_and_clean_content(message_obj):
    """Return a display title and printable text for one chat message.

    Args:
        message_obj: a dict with optional ``content`` / ``tool_calls`` keys,
            an object with optional ``content`` / ``tool_calls`` attributes,
            or a bare string, which is treated as the message content itself.

    Returns:
        tuple[str, str]: ``(title, text)``. ``title`` is ``"Tool: <name>"``,
        where ``<name>`` is the ``name`` of the first entry of ``tool_calls``
        (a list, or a JSON string encoding one). It falls back to
        ``"Tool Result"`` when that cannot be read. ``text`` is ``""`` for
        ``None`` content. Dict/list content is pretty-printed as JSON, or
        ``str()``-ed if it is not JSON-serialisable. Anything else goes
        through ``str()``.
    """
    import json
    import logging

    # Use a named logger instead of calling logging.basicConfig() here: a
    # helper invoked for every streamed message should not reconfigure the
    # application's root logger.
    logger = logging.getLogger(__name__)
    tool_name = "Tool Result"

    if isinstance(message_obj, str):
        # A bare string has no `content` attribute; getattr() would return ""
        # and silently drop the text, so treat the string as the content.
        content, tool_calls = message_obj, None
    elif isinstance(message_obj, dict):
        content = message_obj.get("content", "")
        tool_calls = message_obj.get("tool_calls")
    else:
        content = getattr(message_obj, "content", "")
        tool_calls = getattr(message_obj, "tool_calls", None)

    # Try to extract the tool name from `tool_calls`.
    if tool_calls:
        try:
            if isinstance(tool_calls, str):
                tool_calls = json.loads(tool_calls)
            tool_name = tool_calls[0].get("name", "Tool Result")
            logger.info("[extract_tool_name] Extracted from tool_calls: %s", tool_name)
        except (ValueError, TypeError, KeyError, IndexError, AttributeError) as exc:
            # Malformed JSON, wrong container type, empty list or entries
            # without .get(): keep the default title and carry on.
            logger.warning("[extract_tool_name] Failed tool_calls parsing: %s", exc)

    # Format clean output.
    if content is None:
        formatted = ""
    elif isinstance(content, (dict, list)):
        try:
            # ensure_ascii=False keeps non-ASCII text (e.g. Greek letters in
            # drug/gene names) readable instead of \uXXXX escapes.
            formatted = json.dumps(content, indent=2, ensure_ascii=False)
        except (TypeError, ValueError):
            # Non-serialisable or circular structures: fall back to repr-ish text.
            formatted = str(content)
    else:
        formatted = str(content)
    return f"Tool: {tool_name}", formatted
# === Helper: formatted collapsible output
def format_collapsible(content, title="Answer"):
    """Wrap *content* in a collapsible HTML ``<details>`` block headed by *title*.

    Both values are interpolated verbatim via ``str()``, with no HTML escaping.
    Callers must escape untrusted text themselves if it should render literally.
    """
    outer_css = "border: 1px solid #ccc; padding: 8px; margin-top: 8px;"
    heading_css = "font-weight: bold;"
    body_css = "margin-top: 8px; white-space: pre-wrap;"
    pieces = [
        f"<details style='{outer_css}'>",
        f"<summary style='{heading_css}'>{title}</summary>",
        f"<div style='{body_css}'>{content}</div></details>",
    ]
    return "".join(pieces)
# === UI creation
def create_ui(agent):
    """Build the TxAgent chat interface.

    Args:
        agent: object providing ``run_gradio_chat(...)``. Its return value is
            iterated, and each item is treated as a list of chat messages
            (dicts or objects with ``role`` / ``content``).

    Returns:
        The assembled ``gr.Blocks`` demo (not yet queued or launched).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")
        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        conversation_state = gr.State([])

        # === Core handler (streaming generator)
        def handle_chat(message, history, conversation):
            """Run the agent on *message* and stream re-formatted histories.

            Yields:
                list[dict]: ``{"role", "content"}`` messages for the Chatbot,
                with assistant content wrapped in a collapsible section.
            """
            if not (message or "").strip():
                # Blank submission: leave the chat unchanged instead of
                # invoking the agent with an empty question.
                yield history or []
                return

            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=conversation,
                max_round=30,
            )
            for update in generator:
                formatted_messages = []
                for m in update:
                    if isinstance(m, dict):
                        # .get() mirrors the getattr() defaults used for objects.
                        role = m.get("role", "assistant")
                        content = m.get("content", "")
                    else:
                        role = getattr(m, "role", "assistant")
                        content = getattr(m, "content", "")
                    # `history` comes from the Chatbot, whose assistant entries
                    # were already wrapped on an earlier turn. If the agent
                    # echoes them back, don't nest another <details> around them.
                    already_wrapped = isinstance(content, str) and content.startswith("<details")
                    if role == "assistant" and not already_wrapped:
                        # Pass the whole message, not just its content. The
                        # helper reads `content`/`tool_calls` off the message,
                        # and a bare string has neither, so the text was lost.
                        title, clean = extract_tool_name_and_clean_content(m)
                        content = format_collapsible(clean, title)
                    formatted_messages.append({"role": role, "content": content})
                yield formatted_messages

        # === Trigger handlers
        inputs = [message_input, chatbot, conversation_state]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
    return demo
# === Startup
if __name__ == "__main__":
    # Needed for multiprocessing in frozen executables; a no-op otherwise.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[],
        )
        # Fail fast: verify the UI entry point exists before init_model(),
        # which presumably loads the (large) model weights.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")
        agent.init_model()

        demo = create_ui(agent)
        # 0.0.0.0 listens on all network interfaces. share=True additionally
        # requests a public Gradio share link; review before exposing.
        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            share=True,
        )
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and the non-zero
        # exit status are preserved.
        print(f"❌ App failed to start: {e}", file=sys.stderr)
        raise
|