# test / app.py
# Last commit: 929325a "Update app.py" (verified), by Ali2206
# File size: 9.52 kB (raw / history / blame views)
import os
import json
import logging
import torch
import gradio as gr
from tooluniverse import ToolUniverse
import warnings
from typing import List, Dict, Any
# Silence every UserWarning process-wide. This is deliberately broad: it hides
# UserWarnings from all modules, not only from this file.
warnings.simplefilter("ignore", category=UserWarning)
# Runtime configuration. All paths are relative, so they resolve against the
# current working directory (the embedding file is expected there).
CONFIG = dict(
    # Hugging Face model ids passed to TxAgent.
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Hard-coded, pre-computed tool-description embedding file.
    embedding_filename=(
        "ToolRAG-T1-GTE-Qwen2-1.5B"
        "tool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt"
    ),
    # Tool definition files handed to TxAgent as tool_files_dict.
    tool_files=dict(new_tool="./data/new_tool.json"),
)
# Root logging: INFO and above, each record stamped with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)
def prepare_tool_files():
    """Ensure the tool-definition JSON file exists and is non-empty.

    If the file at CONFIG["tool_files"]["new_tool"] is missing or zero bytes,
    the tool list is fetched from ToolUniverse().get_all_tools() and written
    there as indented JSON. An existing non-empty file is left untouched.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]
    # Create the configured file's own parent directory (not a hard-coded one)
    # so changing the path in CONFIG keeps working.
    os.makedirs(os.path.dirname(tool_path) or ".", exist_ok=True)

    # A zero-byte file (e.g. left by a crash mid-write) counts as missing.
    if os.path.exists(tool_path) and os.path.getsize(tool_path) > 0:
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools()

    # Write to a sibling temp file and atomically move it into place, so an
    # interrupted run cannot leave a truncated JSON file that a later run
    # would accept as valid.
    tmp_path = f"{tool_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    os.replace(tmp_path, tool_path)
    logger.info(f"Saved {len(tools)} tools to {tool_path}")
def safe_load_embeddings(filepath: str) -> Any:
    """Load a ``torch.save`` file, preferring PyTorch's restricted unpickler.

    First tries ``torch.load(filepath, weights_only=True)``. If that fails for
    any reason other than an OS-level error, it retries with
    ``weights_only=False``.

    SECURITY: the ``weights_only=False`` fallback performs unrestricted
    unpickling, which can execute arbitrary code embedded in the file. Only
    call this on files you trust.

    Args:
        filepath: Path to the serialized file.

    Returns:
        The object stored in the file.

    Raises:
        OSError: if the file cannot be opened or read (e.g. FileNotFoundError).
            Retrying would fail identically, so no fallback is attempted.
        Exception: whatever the fallback ``torch.load`` raises.
    """
    try:
        # Secure mode: restricted unpickling.
        return torch.load(filepath, weights_only=True)
    except OSError:
        # Missing/unreadable file: not a deserialization problem, don't retry.
        raise
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        # Less secure: full pickle (see SECURITY note above).
        return torch.load(filepath, weights_only=False)
def _align_embedding_count(embeddings: torch.Tensor, target_count: int) -> torch.Tensor:
    """Return *embeddings* resized along dim 0 to exactly *target_count* rows.

    Surplus rows are dropped from the end. Missing rows are filled by repeating
    the last existing row.

    Raises:
        ValueError: if padding is required but *embeddings* has no rows.
    """
    current_count = len(embeddings)
    if current_count >= target_count:
        return embeddings[:target_count]
    if current_count == 0:
        raise ValueError("Cannot pad an empty embedding tensor")
    # Slice (rather than index) the last row so it keeps dim 0: torch.cat
    # requires all inputs to have the same rank as `embeddings`.
    last_row = embeddings[-1:]
    padding = last_row.expand(target_count - current_count, *embeddings.shape[1:])
    return torch.cat([embeddings, padding], dim=0)


def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding`` to use CONFIG's file.

    The replacement loads ``CONFIG["embedding_filename"]`` via
    ``safe_load_embeddings``. It then truncates or pads the result so the
    number of embeddings equals the number of tools returned by
    ``tooluniverse.get_all_tools()``.

    Raises:
        Any exception raised while importing ``txagent.toolrag`` or applying the
        patch is logged and re-raised.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse: ToolUniverse) -> bool:
            """Load embeddings into ``self.tool_desc_embedding``; return True on success."""
            embedding_path = CONFIG["embedding_filename"]
            try:
                if not os.path.exists(embedding_path):
                    logger.error(f"Embedding file not found: {embedding_path}")
                    return False
                # Load embeddings safely
                self.tool_desc_embedding = safe_load_embeddings(embedding_path)
                # Reconcile embedding count with the current tool count.
                tools = tooluniverse.get_all_tools()
                current_count = len(tools)
                embedding_count = len(self.tool_desc_embedding)
                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
                    self.tool_desc_embedding = _align_embedding_count(
                        self.tool_desc_embedding, current_count
                    )
                    if current_count < embedding_count:
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        logger.info(f"Padded embeddings to match {current_count} tools")
                return True
            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        # Apply the patch
        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")
    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
class TxAgentApp:
    """Lazily builds a TxAgent and adapts it to a Gradio chat history format."""

    def __init__(self):
        # Created by initialize(); chat() refuses to run until that succeeds.
        self.agent = None
        self.is_initialized = False

    def initialize(self) -> str:
        """Initialize the TxAgent with all required components.

        Returns:
            A status string for the UI. Failures are logged and reported in the
            returned string rather than raised, so the caller can show them and
            the user can retry.
        """
        if self.is_initialized:
            return "✅ Already initialized"
        try:
            # Function-scope import: a missing or broken txagent install is then
            # reported as a status message instead of crashing app startup.
            from txagent import TxAgent

            # Apply our patch before the agent is constructed.
            patch_embedding_loading()
            logger.info("Initializing TxAgent...")
            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )
            logger.info("Loading models...")
            self.agent.init_model()
            self.is_initialized = True
            return "✅ TxAgent initialized successfully"
        except Exception as e:
            # Drop any half-built agent so a retry starts from a clean state.
            self.agent = None
            logger.exception("Initialization failed: %s", e)
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Run one chat turn through the TxAgent and append it to the history.

        Args:
            message: The user's new message. Blank or whitespace-only input is
                ignored and the history is returned unchanged.
            history: Chat history as [[user_msg, bot_msg], ...]. None is
                treated as an empty history.

        Returns:
            A new history list with [message, reply] appended. If the agent is
            not initialized or raises, the reply is a warning/error string. The
            user's message is still kept in that row.
        """
        history = history or []
        if not message or not message.strip():
            return history
        if not self.is_initialized:
            return history + [[message, "⚠️ Please initialize the model first"]]
        try:
            # Convert [[user, bot], ...] pairs to role/content dicts.
            tx_history = []
            for user_msg, bot_msg in history:
                tx_history.append({"role": "user", "content": user_msg})
                if bot_msg:  # Only add bot response if it exists
                    tx_history.append({"role": "assistant", "content": bot_msg})

            # Only the last chunk yielded is kept as the reply.
            # NOTE(review): assumes each chunk is the reply text so far; verify
            # against the installed txagent version.
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message,
                history=tx_history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,  # Note: Using max_token instead of max_length
                call_agent=False,
                conversation=None,
                max_round=30,
            ):
                response = chunk
            return history + [[message, response]]
        except Exception as e:
            logger.exception("Chat error: %s", e)
            return history + [[message, f"Error: {str(e)}"]]
def create_interface() -> gr.Blocks:
    """Build the TxAgent Gradio UI.

    A single TxAgentApp instance is created here and shared by every event
    handler (and therefore by every browser session using this app).

    Returns:
        The configured, not-yet-launched gr.Blocks app.
    """
    app = TxAgentApp()
    with gr.Blocks(
        title="TxAgent",
        css=".gradio-container {max-width: 900px !important}",
    ) as demo:
        gr.Markdown(
            "# 🧠 TxAgent: Therapeutic Reasoning AI\n"
            "### (Using pre-loaded embeddings)"
        )
        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)
        chatbot = gr.Chatbot(
            height=500,
            label="Conversation",
            bubble_full_width=False,
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")
        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?",
            ],
            inputs=msg,
        )

        def wrapper_initialize() -> tuple:
            """Run initialization and push the status to the UI.

            The button is disabled only after initialization succeeds, so a
            failed attempt can be retried.
            """
            status = app.initialize()
            return status, gr.update(interactive=not app.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn],
        )
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot,
        ).then(
            lambda: "",  # Clear message box
            outputs=msg,
        )
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg],
        )
    return demo
if __name__ == "__main__":
    try:
        logger.info("Starting application...")
        embedding_path = CONFIG["embedding_filename"]
        # A missing embedding file is only reported here; startup continues either way.
        if os.path.exists(embedding_path):
            logger.info(f"Found embedding file: {embedding_path}")
        else:
            logger.error(f"Embedding file not found: {embedding_path}")
            logger.info("Please ensure the file is in the root directory")

        prepare_tool_files()
        # Serve on all interfaces, port 7860, with no public share link.
        create_interface().launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as err:
        logger.error(f"Application failed to start: {str(err)}")
        raise