# test/app.py (Hugging Face file view; uploaded by Ali2206)
# Commit aa0dcbf (verified): "Update app.py" - file size 9.44 kB
import os
import json
import logging
import torch
import gradio as gr
from tooluniverse import ToolUniverse
from txagent import TxAgent # Updated import statement
import warnings
from typing import List, Dict, Any
# Ignore every warning whose category is UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

# Application settings. The embedding file name is hardcoded here and is used
# as a path as-is by the rest of this file.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "tool_files": {"new_tool": "./data/new_tool.json"},
}

# Configure root logging at INFO level, then create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def prepare_tool_files():
    """Ensure the configured tool file exists and is non-empty.

    The path is read from ``CONFIG["tool_files"]["new_tool"]``. When the file
    is missing or has zero bytes, the value returned by
    ``ToolUniverse().get_all_tools()`` is written to it as indented JSON.
    An existing non-empty file is left untouched.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]

    # Create the directory the configured path actually lives in, rather than
    # a hardcoded "./data", so changing the config keeps working.
    parent_dir = os.path.dirname(tool_path)
    if parent_dir:  # os.makedirs("") raises, so skip for bare file names
        os.makedirs(parent_dir, exist_ok=True)

    # A zero-byte file (e.g. left by an interrupted earlier run) is treated
    # like a missing one so it gets regenerated instead of being kept.
    if os.path.exists(tool_path) and os.path.getsize(tool_path) > 0:
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools()
    with open(tool_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    logger.info(f"Saved {len(tools)} tools to {tool_path}")
def safe_load_embeddings(filepath: str) -> Any:
    """Load a ``torch.save``-d object, preferring the restricted loader.

    The file is first read with ``torch.load(..., weights_only=True)``. If
    that fails for a reason other than an OS-level I/O error, a warning is
    logged and the load is retried with ``weights_only=False``.

    Args:
        filepath: Path of the file to load.

    Returns:
        Whatever object ``torch.load`` produces for the file.

    Raises:
        OSError: The file is missing or unreadable (for example
            ``FileNotFoundError``). Raised immediately, without attempting
            the unrestricted fallback.
        Exception: Any error raised by the ``weights_only=False`` retry.
    """
    try:
        # First try with weights_only=True (secure mode)
        return torch.load(filepath, weights_only=True)
    except OSError:
        # A missing/unreadable file cannot be fixed by retrying, and it must
        # not be reported as a "secure load" failure or reach the unsafe path.
        raise
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        # SECURITY: weights_only=False uses the full pickle machinery, which
        # can execute arbitrary code embedded in the file. Only use this with
        # embedding files from a trusted source.
        return torch.load(filepath, weights_only=False)
def _match_embedding_count(embeddings, tool_count):
    """Return *embeddings* truncated or padded along dim 0 to *tool_count* rows.

    Assumes *embeddings* is a torch tensor with one entry per tool along its
    first dimension (TODO confirm against the saved embedding file). Padding
    repeats the last entry, so padded rows are placeholders, not real
    embeddings of the extra tools.

    Raises:
        ValueError: Padding is required but *embeddings* is empty.
    """
    embedding_count = len(embeddings)
    if tool_count == embedding_count:
        return embeddings

    logger.warning(f"Tool count mismatch (tools: {tool_count}, embeddings: {embedding_count})")

    if tool_count < embedding_count:
        truncated = embeddings[:tool_count]
        logger.info(f"Truncated embeddings to match {tool_count} tools")
        return truncated

    if embedding_count == 0:
        raise ValueError("Cannot pad an empty embedding tensor")

    # Slice with [-1:] instead of indexing with [-1]: the slice keeps the
    # leading dimension, so every piece handed to torch.cat has the same rank.
    # Indexing drops that dimension and makes torch.cat raise.
    last_row = embeddings[-1:]
    padded = torch.cat([embeddings] + [last_row] * (tool_count - embedding_count))
    logger.info(f"Padded embeddings to match {tool_count} tools")
    return padded


def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding``.

    The replacement loads the file named by ``CONFIG["embedding_filename"]``
    via ``safe_load_embeddings``, stores it on ``self.tool_desc_embedding``
    and then truncates or pads it so its length equals the number of tools
    returned by ``tooluniverse.get_all_tools()``. It returns ``True`` on
    success and ``False`` (after logging) when the file is missing or any
    step fails.

    Raises:
        Exception: Any error while importing ``ToolRAGModel`` or applying
            the patch is logged and re-raised.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        # Fail loudly if the method being replaced does not exist, instead of
        # silently adding a new attribute nobody calls.
        if not hasattr(ToolRAGModel, "load_tool_desc_embedding"):
            raise AttributeError(
                "ToolRAGModel has no load_tool_desc_embedding method to patch"
            )

        def patched_load(self, tooluniverse: ToolUniverse) -> bool:
            try:
                if not os.path.exists(CONFIG["embedding_filename"]):
                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
                    return False

                # Load embeddings safely
                self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])

                # Handle tool count mismatch
                tools = tooluniverse.get_all_tools()
                self.tool_desc_embedding = _match_embedding_count(
                    self.tool_desc_embedding, len(tools)
                )
                return True
            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        # Apply the patch
        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")
    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
class TxAgentApp:
    """Hold a lazily created TxAgent and expose it to the Gradio UI."""

    def __init__(self):
        # The agent is created by initialize(); until then chat() refuses work.
        self.agent = None
        self.is_initialized = False

    def initialize(self) -> str:
        """Create the TxAgent and load its models.

        Returns:
            A human-readable status string: success, already-initialized, or
            a failure message containing the exception text. Exceptions are
            logged and reported through the return value, never raised.
        """
        if self.is_initialized:
            return "✅ Already initialized"
        try:
            # Apply our patch before initialization
            patch_embedding_loading()
            logger.info("Initializing TxAgent...")
            agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )
            logger.info("Loading models...")
            agent.init_model()
            # Publish the agent only once it is fully loaded, so a failed
            # init_model() does not leave a half-initialized agent behind.
            self.agent = agent
            self.is_initialized = True
            return "✅ TxAgent initialized successfully"
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Handle one chat turn with the TxAgent.

        Args:
            message: User input message.
            history: Chat history in format [[user_msg, bot_msg], ...];
                ``None`` is treated as an empty history.

        Returns:
            A new history list with one [message, reply] pair appended. The
            reply is the last chunk yielded by ``agent.run_gradio_chat``, a
            warning when the agent is not initialized, or an error text when
            the call fails. The input list is not modified.
        """
        # Copy so the caller's list is never mutated; tolerate None.
        history = list(history) if history else []

        if not self.is_initialized:
            # Keep the user's message in the transcript instead of an empty
            # user turn, since the UI clears the input box after submit.
            return history + [[message, "⚠️ Please initialize the model first"]]
        try:
            # Convert history to the format TxAgent expects
            tx_history = []
            for user_msg, bot_msg in history:
                tx_history.append({"role": "user", "content": user_msg})
                if bot_msg:  # Only add bot response if it exists
                    tx_history.append({"role": "assistant", "content": bot_msg})

            # Drain the generator; only the last yielded chunk is kept.
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message,
                history=tx_history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=None,
                max_round=30
            ):
                response = chunk

            # Format response for Gradio Chatbot
            return history + [[message, response]]
        except Exception as e:
            logger.exception(f"Chat error: {str(e)}")
            return history + [[message, f"Error: {str(e)}"]]
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI around a fresh ``TxAgentApp``.

    The layout contains an initialize button with a status box, a chatbot
    panel, a question textbox, a clear button and example questions.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    app = TxAgentApp()
    with gr.Blocks(
        title="TxAgent",
        css="""
.gradio-container {max-width: 900px !important}
"""
    ) as demo:
        gr.Markdown("""
# 🧠 TxAgent: Therapeutic Reasoning AI
### (Using pre-loaded embeddings)
""")
        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)
        chatbot = gr.Chatbot(
            height=500,
            label="Conversation"
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")
        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg
        )

        def wrapper_initialize() -> tuple:
            """Run initialization and return (status text, button update)."""
            status = app.initialize()
            # Disable the button only once initialization has succeeded; after
            # a failure it stays clickable so the user can retry.
            return status, gr.update(interactive=not app.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn]
        )
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        ).then(
            lambda: "",  # Clear message box
            outputs=msg
        )
        # Reset both the conversation and the input box.
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: report on the embedding file, make sure the tool
    # file exists, then build and serve the UI. Any startup failure is logged
    # and re-raised.
    try:
        logger.info("Starting application...")

        # The check is informational only; startup continues either way.
        embedding_found = os.path.exists(CONFIG["embedding_filename"])
        if embedding_found:
            logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")
        else:
            logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
            logger.info("Please ensure the file is in the root directory")

        # Prepare tool files
        prepare_tool_files()

        # Launch interface on all network interfaces, port 7860, no share link
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False,
        }
        interface = create_interface()
        interface.launch(**launch_options)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise