"""Gradio front end for TxAgent that loads tool embeddings from a fixed local file."""

import json
import logging
import os
import warnings
from typing import Any, Dict, List

import gradio as gr
import torch
from tooluniverse import ToolUniverse
from txagent import TxAgent

# Suppress UserWarning noise from the imported libraries.
warnings.filterwarnings("ignore", category=UserWarning)

# Configuration with hardcoded embedding file.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "tool_files": {
        "new_tool": "./data/new_tool.json",
    },
}

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


def prepare_tool_files():
    """Create ``./data`` and write the tool list JSON if it does not exist yet.

    The tool list is obtained from ``ToolUniverse().get_all_tools()`` and
    dumped to ``CONFIG["tool_files"]["new_tool"]``. An existing file is left
    untouched.
    """
    os.makedirs("./data", exist_ok=True)
    tool_path = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools()
    # Explicit encoding so the output does not depend on the platform default.
    with open(tool_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    logger.info("Saved %s tools to %s", len(tools), tool_path)


def safe_load_embeddings(filepath: str) -> Any:
    """Load an embedding file, preferring ``torch.load(..., weights_only=True)``.

    Args:
        filepath: Path of the file to hand to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserialises from ``filepath``.

    Raises:
        Exception: Whatever the fallback ``torch.load`` call raises when both
            attempts fail.
    """
    try:
        # First try with weights_only=True (secure mode).
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(
            "Secure load failed, trying with weights_only=False: %s", e
        )
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only acceptable because the
        # path is a fixed, locally supplied file; never point this at
        # untrusted input.
        return torch.load(filepath, weights_only=False)


def _align_embeddings(embeddings: Any, tool_count: int) -> Any:
    """Truncate or pad ``embeddings`` so its first dimension equals ``tool_count``.

    Assumes ``embeddings`` is a ``torch.Tensor`` with one entry per tool along
    dim 0 (the original code concatenated it with ``torch.cat``).

    Args:
        embeddings: Loaded tool-description embeddings.
        tool_count: Number of tools currently known to the tool universe.

    Returns:
        ``embeddings`` unchanged when the counts already match, a prefix slice
        when there are too many rows, or the tensor extended with copies of
        its last row when there are too few.

    Raises:
        ValueError: If padding is needed but ``embeddings`` is empty, so there
            is no row to repeat.
    """
    embedding_count = len(embeddings)
    if tool_count == embedding_count:
        return embeddings

    logger.warning(
        "Tool count mismatch (tools: %s, embeddings: %s)",
        tool_count,
        embedding_count,
    )

    if tool_count < embedding_count:
        aligned = embeddings[:tool_count]
        logger.info("Truncated embeddings to match %s tools", tool_count)
        return aligned

    if embedding_count == 0:
        raise ValueError("Cannot pad an empty embedding tensor")

    missing = tool_count - embedding_count
    # Slice with [-1:] (not [-1]) so the last row keeps its leading dimension;
    # torch.cat requires every operand to have the same number of dimensions.
    padding = embeddings[-1:].repeat_interleave(missing, dim=0)
    aligned = torch.cat([embeddings, padding], dim=0)
    logger.info("Padded embeddings to match %s tools", tool_count)
    return aligned


def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding``.

    The replacement loads embeddings from ``CONFIG["embedding_filename"]`` and
    aligns their count with the tools reported by the tool universe.

    Raises:
        Exception: Re-raised after logging if ``txagent.toolrag`` cannot be
            imported or the patch cannot be installed.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse: ToolUniverse) -> bool:
            """Load the configured embedding file onto ``self``.

            Returns:
                True when ``self.tool_desc_embedding`` was set and aligned,
                False when the file is missing or any step failed.
            """
            try:
                embedding_path = CONFIG["embedding_filename"]
                if not os.path.exists(embedding_path):
                    logger.error(f"Embedding file not found: {embedding_path}")
                    return False

                # Load embeddings safely
                self.tool_desc_embedding = safe_load_embeddings(embedding_path)

                # Handle tool count mismatch
                tools = tooluniverse.get_all_tools()
                self.tool_desc_embedding = _align_embeddings(
                    self.tool_desc_embedding, len(tools)
                )
                return True

            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        # Apply the patch
        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise


class TxAgentApp:
    """Holds a lazily created ``TxAgent`` and adapts it to the Gradio UI."""

    def __init__(self):
        # The agent is created on demand by initialize().
        self.agent = None
        self.is_initialized = False

    def initialize(self) -> str:
        """Initialize the TxAgent with all required components.

        Returns:
            A status string for the UI: success, already-initialized, or a
            failure message containing the exception text. Exceptions are
            logged and reported through the string rather than propagated.
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            # Apply our patch before initialization
            patch_embedding_loading()

            logger.info("Initializing TxAgent...")
            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )

            logger.info("Loading models...")
            self.agent.init_model()

            self.is_initialized = True
            return "✅ TxAgent initialized successfully"

        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Handle chat interactions with the TxAgent.

        Args:
            message: User input message
            history: Chat history in format [[user_msg, bot_msg], ...]

        Returns:
            Updated chat history. On failure (model not initialized or an
            exception while generating) the new entry pairs the user's
            message with a warning/error text instead of a model answer.
        """
        if not self.is_initialized:
            # Keep the user's message visible: the textbox is cleared after
            # submit, so recording "" here would lose what was typed.
            return history + [[message, "⚠️ Please initialize the model first"]]

        try:
            # Convert history to the role/content dict format passed to the agent.
            tx_history = []
            for user_msg, bot_msg in history:
                tx_history.append({"role": "user", "content": user_msg})
                if bot_msg:  # Only add bot response if it exists
                    tx_history.append({"role": "assistant", "content": bot_msg})

            # Generate response; only the last yielded value is kept.
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message,
                history=tx_history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=None,
                max_round=30,
            ):
                response = chunk

            # Format response for Gradio Chatbot
            return history + [[message, response]]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return history + [[message, f"Error: {str(e)}"]]


def create_interface() -> gr.Blocks:
    """Create the Gradio interface.

    Returns:
        The ``gr.Blocks`` app wired to a fresh ``TxAgentApp``: an initialize
        button with a status box, a chatbot, a question textbox with
        examples, and a clear button.
    """
    app = TxAgentApp()

    with gr.Blocks(
        title="TxAgent",
        css="""
        .gradio-container {max-width: 900px !important}
        """
    ) as demo:
        gr.Markdown("""
        # 🧠 TxAgent: Therapeutic Reasoning AI
        ### (Using pre-loaded embeddings)
        """)

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(
            height=500,
            label="Conversation"
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg
        )

        def wrapper_initialize() -> tuple:
            """Run initialization and return (status text, button update)."""
            status = app.initialize()
            # Leave the button enabled after a failed attempt so the user can
            # retry; disable it only once the agent is actually ready.
            return status, gr.update(interactive=not app.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn]
        )

        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        ).then(
            lambda: "",  # Clear message box
            outputs=msg
        )

        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo


def main():
    """Check for the embedding file, prepare tool files, and launch the UI.

    A missing embedding file is only logged; startup continues regardless.

    Raises:
        Exception: Any startup failure is logged and re-raised.
    """
    try:
        logger.info("Starting application...")

        # Verify embedding file exists
        if not os.path.exists(CONFIG["embedding_filename"]):
            logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
            logger.info("Please ensure the file is in the root directory")
        else:
            logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")

        # Prepare tool files
        prepare_tool_files()

        # Launch interface
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise


if __name__ == "__main__":
    main()