|
import datetime
import json
import logging
import os
import random
import sys
from importlib.resources import files
from typing import Any

import gradio as gr
import torch
from tooluniverse import ToolUniverse
from txagent import TxAgent
|
|
|
|
|
# Configure logging at import time: INFO level, with timestamp, logger name
# and level in every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Absolute directory of this script; used to locate the local data/ folder.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Process-wide environment flags, set at import time before any agent is built.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
# Static configuration shared by the functions in this file.
CONFIG = {
    # Passed positionally to TxAgent in create_agent().
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Bare file name: the patched embedding loader checks and opens it
    # relative to the process working directory.
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    # Handed to TxAgent as tool_files_dict.
    "tool_files": {
        # JSON files resolved inside the tooluniverse.data package.
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        # Local file next to this script; prepare_tool_files() creates it
        # when it does not exist.
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json')
    }
}
|
|
|
# HTML title block rendered through gr.Markdown at the top of the page.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
</div>
'''
|
|
|
# Introductory paragraph rendered through gr.Markdown below the title.
INTRO = """
Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
"""
|
|
|
# Contact note and medical disclaimer rendered through gr.Markdown at the
# bottom of the page.
LICENSE = """
We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
in collaboration, please email Marinka Zitnik and Shanghua Gao.

### Medical Advice Disclaimer
DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
The information, including but not limited to, text, graphics, images and other material contained on this
website are for informational purposes only. No material on this site is intended to be a substitute for
professional medical advice, diagnosis or treatment.
"""
|
|
|
# HTML passed as the Chatbot `placeholder` argument in create_demo().
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
</div>
"""
|
|
|
# Page-level stylesheet passed to gr.Blocks(css=...) in create_demo().
css = """
h1 {
text-align: center;
display: block;
}

#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
.small-button button {
font-size: 12px !important;
padding: 4px 8px !important;
height: 6px !important;
width: 4px !important;
}
.gradio-accordion {
margin-top: 0px !important;
margin-bottom: 0px !important;
}
"""
|
|
|
# Extra button styling. NOTE(review): nothing in the visible part of this file
# references chat_css — confirm whether it is still needed.
chat_css = """
.gr-button { font-size: 20px !important; }
.gr-button svg { width: 32px !important; height: 32px !important; }
"""
|
|
|
def safe_load_embeddings(filepath: str) -> Any:
    """Load a serialized embedding object with ``torch.load``.

    The restricted ``weights_only=True`` mode is tried first. If that raises,
    the load is retried with ``weights_only=False``.

    Args:
        filepath: Path of the file handed to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns, or ``None`` when both attempts fail.
    """
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as secure_error:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(secure_error)}")

    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code embedded in the file. This fallback is kept for
    # backward compatibility; only point it at files from a trusted source.
    try:
        return torch.load(filepath, weights_only=False)
    except Exception as fallback_error:
        logger.error(f"Failed to load embeddings: {str(fallback_error)}")
        return None
|
|
|
def _resize_embeddings(embeddings, tool_count):
    """Return ``embeddings`` truncated or padded along dim 0 to ``tool_count`` rows.

    Padding repeats the last row. Assumes ``embeddings`` is a tensor whose
    first dimension indexes tools — TODO confirm against the saved file.

    Raises:
        ValueError: If padding is required but there is no row to repeat.
    """
    embedding_count = len(embeddings)
    if tool_count <= embedding_count:
        return embeddings[:tool_count]
    if embedding_count == 0:
        raise ValueError("cannot pad an empty embedding tensor")
    # Slice with [-1:] (not [-1]) so the filler keeps the leading dimension;
    # torch.cat refuses tensors whose number of dimensions differ.
    filler = embeddings[-1:].repeat_interleave(tool_count - embedding_count, dim=0)
    return torch.cat([embeddings, filler])


def patch_embedding_loading():
    """Replace ``ToolRAGModel.load_tool_desc_embedding`` with a local loader.

    The replacement reads ``CONFIG["embedding_filename"]`` through
    ``safe_load_embeddings`` and resizes the result so that it has one entry
    per tool exposed by the ToolUniverse instance it receives.

    Raises:
        Exception: Any error raised while importing or patching
            ``ToolRAGModel`` is logged and re-raised.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        # Not called afterwards; the attribute lookup doubles as a check that
        # the method being replaced still exists upstream.
        original_load = ToolRAGModel.load_tool_desc_embedding

        def patched_load(self, tooluniverse):
            """Load embeddings into ``self.tool_desc_embedding``.

            Returns:
                ``True`` on success, ``False`` when the file is missing, cannot
                be loaded, or the tools cannot be read from ``tooluniverse``.
            """
            try:
                embedding_path = CONFIG["embedding_filename"]
                if not os.path.exists(embedding_path):
                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
                    return False

                self.tool_desc_embedding = safe_load_embeddings(embedding_path)
                if self.tool_desc_embedding is None:
                    # safe_load_embeddings already logged the cause; stop here
                    # instead of failing later on len(None).
                    return False

                if hasattr(tooluniverse, 'get_all_tools'):
                    tools = tooluniverse.get_all_tools()
                elif hasattr(tooluniverse, 'tools'):
                    tools = tooluniverse.tools
                else:
                    logger.error("No method found to access tools from ToolUniverse")
                    return False

                current_count = len(tools)
                embedding_count = len(self.tool_desc_embedding)

                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
                    self.tool_desc_embedding = _resize_embeddings(self.tool_desc_embedding, current_count)
                    if current_count < embedding_count:
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        logger.info(f"Padded embeddings to match {current_count} tools")

                return True

            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
|
|
|
def prepare_tool_files():
    """Create the local ``new_tool`` JSON file when it does not exist yet.

    The ``data`` directory next to this script is created if needed. When the
    file is missing, the tool list is read from a fresh ``ToolUniverse`` and
    written as indented JSON. Failures are logged, not raised.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    try:
        tu = ToolUniverse()
        if hasattr(tu, 'get_all_tools'):
            tools = tu.get_all_tools()
        elif hasattr(tu, 'tools'):
            tools = tu.tools
        else:
            # An empty list is still written so the file exists afterwards.
            tools = []
            logger.error("Could not access tools from ToolUniverse")

        # Serialize before opening the file: if the tools are not JSON
        # serializable, no truncated file is left behind to satisfy the
        # exists-check above on the next run.
        payload = json.dumps(tools, indent=2)
        with open(tool_path, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
    except Exception as e:
        logger.error(f"Failed to prepare tool files: {str(e)}")
|
|
|
def create_agent():
    """Build a ``TxAgent`` from ``CONFIG`` and initialize its model.

    The embedding loader is patched and the local tool file is prepared
    before the agent is constructed.

    Returns:
        The initialized agent.

    Raises:
        Exception: Construction or initialization errors are logged and
            re-raised.
    """
    patch_embedding_loading()
    prepare_tool_files()

    try:
        agent_options = {
            "tool_files_dict": CONFIG["tool_files"],
            "force_finish": True,
            "enable_checker": True,
            "step_rag_num": 10,
            "seed": 100,
            "additional_default_tools": ['DirectResponse', 'RequireClarification'],
        }
        new_agent = TxAgent(CONFIG["model_name"], CONFIG["rag_model_name"], **agent_options)
        new_agent.init_model()
    except Exception as exc:
        logger.error(f"Failed to create agent: {str(exc)}")
        raise
    else:
        return new_agent
|
|
|
def handle_chat_response(history, message, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Collapse a stream of response chunks into one chat-history entry.

    Args:
        history: Chat history list; it is appended to in place.
        message: Iterable of chunks. Dict chunks contribute their
            ``"content"`` value, every other chunk contributes ``str(chunk)``.
        temperature, max_new_tokens, max_tokens, multi_agent, conversation,
        max_round: Accepted for signature compatibility; not used here.

    Returns:
        The same ``history`` list with ``(None, full_response)`` appended.
    """
    # NOTE(review): the entry is appended as a (user, bot) pair, while
    # handle_retry reads history entries as dicts with a 'content' key —
    # confirm which message format the Chatbot is meant to use.
    parts = []
    for chunk in message:
        if isinstance(chunk, dict):
            # A missing or None "content" contributes nothing instead of
            # raising TypeError on concatenation.
            parts.append(str(chunk.get("content") or ""))
        else:
            parts.append(str(chunk))
    history.append((None, "".join(parts)))
    return history
|
|
|
def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
                            init_rag_num, step_rag_num, skip_last_k,
                            summary_mode, summary_skip_last_k, summary_context_length,
                            force_finish, seed):
    """Forward every setting to ``agent.update_parameters`` as a keyword argument.

    Returns:
        Whatever ``agent.update_parameters`` returns.
    """
    settings = {
        "enable_finish": enable_finish,
        "enable_rag": enable_rag,
        "enable_summary": enable_summary,
        "init_rag_num": init_rag_num,
        "step_rag_num": step_rag_num,
        "skip_last_k": skip_last_k,
        "summary_mode": summary_mode,
        "summary_skip_last_k": summary_skip_last_k,
        "summary_context_length": summary_context_length,
        "force_finish": force_finish,
        "seed": seed,
    }
    return agent.update_parameters(**settings)
|
|
|
def update_seed(agent):
    """Draw a random seed in [0, 10000] and apply it to ``agent``.

    Returns:
        Whatever ``agent.update_parameters`` returns for the new seed.
    """
    return agent.update_parameters(seed=random.randint(0, 10000))
|
|
|
def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Re-run the prompt at ``retry_data.index`` with a fresh random seed.

    Args:
        agent: Object exposing ``update_parameters`` and ``run_gradio_chat``.
        history: Chat history; the entry at ``retry_data.index`` is read as a
            dict with a ``'content'`` key.
        retry_data: Gradio retry payload; only ``index`` is used.
        temperature, max_new_tokens, max_tokens, multi_agent, conversation,
        max_round: Passed through to ``agent.run_gradio_chat`` and
            ``handle_chat_response``.

    Yields:
        The history up to (not including) the retried prompt, with the new
        response appended by ``handle_chat_response``.
    """
    logger.info("Updated seed: %s", update_seed(agent))
    new_history = history[:retry_data.index]
    previous_prompt = history[retry_data.index]['content']
    logger.info("previous_prompt %s", previous_prompt)

    response = agent.run_gradio_chat(
        new_history + [{"role": "user", "content": previous_prompt}],
        temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)

    # handle_chat_response returns the complete updated history list. Yield it
    # as one value; `yield from` would iterate the list and emit each history
    # entry as if it were a whole chat update.
    yield handle_chat_response(new_history, response, temperature, max_new_tokens,
                               max_tokens, multi_agent, conversation, max_round)
|
|
|
# Shared secret compared against user input in check_password().
# NOTE(review): hard-coded in source — consider loading it from the
# environment or a secret store before exposing the app publicly.
PASSWORD = "mypassword"
|
|
|
def check_password(input_password):
    """Validate the password that unlocks the protected settings.

    Args:
        input_password: Text typed by the user; ``None`` is treated as empty.

    Returns:
        A pair of Gradio updates: the first toggles visibility of the
        protected section, the second sets the text and visibility of the
        status textbox.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    supplied = (input_password or "").encode("utf-8")
    # Constant-time comparison so response timing does not leak how much of
    # the password matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if hmac.compare_digest(supplied, PASSWORD.encode("utf-8")):
        return gr.update(visible=True), gr.update(value="", visible=False)

    # The status textbox is created hidden in create_demo, so it must be made
    # visible for the user to see the error at all.
    return (
        gr.update(visible=False),
        gr.update(value="Incorrect password, try again!", visible=True),
    )
|
|
|
def create_demo(agent):
    """Build the Gradio Blocks interface that drives ``agent``.

    Args:
        agent: Object exposing ``run_gradio_chat``, ``update_parameters`` and
            ``load_models``; those are the only attributes used here.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    default_temperature = 0.3
    default_max_new_tokens = 1024
    default_max_tokens = 81920
    default_max_round = 30

    question_examples = [
        ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
        ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
        ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        # Created inside the Blocks context so it is rendered and can be used
        # as an event input/output; a component built outside the context and
        # never rendered is not part of the app.
        # NOTE(review): respond() and handle_retry() exchange dict-style
        # messages while handle_chat_response() appends (user, bot) pairs —
        # confirm which Chatbot message format is intended.
        chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
                             label='TxAgent', show_copy_button=True)

        temperature_state = gr.State(value=default_temperature)
        max_new_tokens_state = gr.State(value=default_max_new_tokens)
        max_tokens_state = gr.State(value=default_max_tokens)
        max_round_state = gr.State(value=default_max_round)

        def generation_inputs():
            """Return a fresh list of the non-chat inputs shared by every chat event.

            ``multi_agent`` (False) and ``conversation`` ([]) are fixed values,
            so they travel through ``gr.State`` holders; an unrendered
            Checkbox is not part of the layout and cannot supply a value.
            """
            return [temperature_state, max_new_tokens_state, max_tokens_state,
                    gr.State(False), gr.State([]), max_round_state]

        def retry(history, retry_data: gr.RetryData, temperature, max_new_tokens,
                  max_tokens, multi_agent, conversation, max_round):
            # A named generator with an annotated parameter: Gradio injects
            # the retry payload based on the type hint (it is not listed in
            # `inputs`), which a `lambda *args` wrapper cannot offer.
            yield from handle_retry(agent, history, retry_data, temperature,
                                    max_new_tokens, max_tokens, multi_agent,
                                    conversation, max_round)

        chatbot.retry(
            retry,
            inputs=[chatbot] + generation_inputs(),
            outputs=[chatbot]
        )

        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(label="Input", placeholder="Type your question here...")
            with gr.Column(scale=1):
                submit_btn = gr.Button("Submit", variant="primary")

        # Clicking an example fills the input box with the question text.
        gr.Examples(examples=question_examples, inputs=msg)

        with gr.Row():
            clear_btn = gr.ClearButton([msg, chatbot])

        def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Run the agent on the new message and return the updated history."""
            response = agent.run_gradio_chat(
                chat_history + [{"role": "user", "content": message}],
                temperature,
                max_new_tokens,
                max_tokens,
                multi_agent,
                conversation,
                max_round
            )
            return handle_chat_response(chat_history, response, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)

        submit_btn.click(
            respond,
            inputs=[msg, chatbot] + generation_inputs(),
            outputs=[chatbot]
        )
        msg.submit(
            respond,
            inputs=[msg, chatbot] + generation_inputs(),
            outputs=[chatbot]
        )

        with gr.Accordion("Settings", open=False):
            temperature_slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=default_temperature,
                label="Temperature"
            )
            max_new_tokens_slider = gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=default_max_new_tokens,
                label="Max new tokens"
            )
            # NOTE(review): the initial value (81920) is above this slider's
            # maximum (32000); the State keeps 81920 until the slider moves.
            max_tokens_slider = gr.Slider(
                minimum=128,
                maximum=32000,
                step=1,
                value=default_max_tokens,
                label="Max tokens"
            )
            max_round_slider = gr.Slider(
                minimum=0,
                maximum=50,
                step=1,
                value=default_max_round,
                label="Max round")

            # Mirror each slider into the State read by the chat handlers.
            temperature_slider.change(
                lambda x: x, inputs=temperature_slider, outputs=temperature_state)
            max_new_tokens_slider.change(
                lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
            max_tokens_slider.change(
                lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
            max_round_slider.change(
                lambda x: x, inputs=max_round_slider, outputs=max_round_state)

            password_input = gr.Textbox(
                label="Enter Password for More Settings", type="password")
            incorrect_message = gr.Textbox(visible=False, interactive=False)

            # Hidden until check_password() returns an update making it visible.
            with gr.Accordion("⚙️ Advanced Settings", open=False, visible=False) as protected_accordion:
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Accordion("Model Settings", open=False):
                            model_name_input = gr.Textbox(
                                label="Enter model path", value=CONFIG["model_name"])
                            load_model_btn = gr.Button(value="Load Model")
                            load_model_btn.click(
                                agent.load_models,
                                inputs=model_name_input,
                                outputs=gr.Textbox(label="Status"))
                    with gr.Column(scale=1):
                        with gr.Accordion("Functional Parameters", open=False):
                            enable_finish = gr.Checkbox(label="Enable Finish", value=True)
                            enable_rag = gr.Checkbox(label="Enable RAG", value=True)
                            enable_summary = gr.Checkbox(label="Enable Summary", value=False)
                            init_rag_num = gr.Number(label="Initial RAG Num", value=0)
                            step_rag_num = gr.Number(label="Step RAG Num", value=10)
                            skip_last_k = gr.Number(label="Skip Last K", value=0)
                            summary_mode = gr.Textbox(label="Summary Mode", value='step')
                            summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
                            summary_context_length = gr.Number(label="Summary Context Length", value=None)
                            force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
                            seed = gr.Number(label="Seed", value=100)
                            # Own name so the chat Submit button above is not rebound.
                            update_params_btn = gr.Button("Update Parameters")
                            updated_parameters_output = gr.JSON()
                            update_params_btn.click(
                                lambda *args: update_model_parameters(agent, *args),
                                inputs=[enable_finish, enable_rag, enable_summary,
                                        init_rag_num, step_rag_num, skip_last_k,
                                        summary_mode, summary_skip_last_k,
                                        summary_context_length, force_finish, seed],
                                outputs=updated_parameters_output
                            )

            unlock_btn = gr.Button("Submit")
            unlock_btn.click(
                check_password,
                inputs=password_input,
                outputs=[protected_accordion, incorrect_message]
            )

        gr.Markdown(LICENSE)

    return demo
|
|
|
def main():
    """Create the agent, build the UI and launch it with a public share link.

    Raises:
        Exception: Any startup failure is logged and re-raised.
    """
    try:
        create_demo(create_agent()).launch(share=True)
    except Exception as exc:
        logger.error(f"Application failed to start: {str(exc)}")
        raise


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()