import os
import sys
import random
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
# === Fix import path BEFORE loading TxAgent
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
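# Prepending ./src makes Python resolve `txagent` from the local checkout
# rather than from any pip-installed copy of the package.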
# === Reload to avoid stale cache
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent
# === Debug confirmation
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# === Env vars
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# === UI Text
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
</div>
'''
INTRO = "Precision therapeutics require multimodal adaptive models..."
LICENSE = "DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE..."
css = """
h1 { text-align: center; }
.gradio-accordion { margin-top: 0 !important; margin-bottom: 0 !important; }
"""
chat_css = """
.gr-button { font-size: 18px !important; }
.gr-button svg { width: 26px !important; height: 26px !important; }
"""
# === Model setup
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
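# TxAgent-T1 is the main reasoning LLM; ToolRAG-T1 (a GTE-Qwen2-1.5B embedding model)
# is presumably used to retrieve the most relevant tools at each reasoning step.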
new_tool_files = {
"new_tool": os.path.join(current_dir, "data", "new_tool.json")
}
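# Assumption: data/new_tool.json holds extra tool definitions that TxAgent merges
# with its built-in tool set via `tool_files_dict` (see the constructor call below).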
# === Sample prompts
question_examples = [
    ["Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted given moderate hepatic impairment?"],
    ["A 30-year-old patient is on Prozac for depression and is now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
]
# === UI
def create_ui(agent):
    with gr.Blocks(css=css) as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)
        temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        conversation_state = gr.State([])
        chatbot = gr.Chatbot(label="TxAgent", height=700, type="messages")
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)
        send_btn = gr.Button("Send", variant="primary")
        # === Streaming handler
        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            # run_gradio_chat is expected to yield incremental chat histories;
            # yielding from it makes this a generator so Gradio can stream partial answers.
            yield from agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round
            )
        # === Submit handlers
        send_btn.click(
            fn=handle_chat,
            inputs=[message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round],
            outputs=chatbot,
        )
        message_input.submit(
            fn=handle_chat,
            inputs=[message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round],
            outputs=chatbot,
        )
        # === Example buttons
        gr.Examples(
            examples=question_examples,
            inputs=message_input
        )
        gr.Markdown(LICENSE)
    return demo
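# NOTE: streaming from the generator handler above relies on Gradio's event queue,
# which should be enabled by default in Gradio 4.x; older 3.x versions would need demo.queue().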
# === App start
if __name__ == "__main__":
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=["DirectResponse", "RequireClarification"]
        )
        agent.init_model()
        if not hasattr(agent, 'run_gradio_chat'):
            raise AttributeError("TxAgent is missing `run_gradio_chat`. Make sure the correct txagent.py is used.")
        demo = create_ui(agent)
        demo.launch(show_error=True)
    except Exception as e:
        print(f"🚨 App failed to start: {e}")
        raise