import os
import json
import logging
import torch
import gradio as gr
from tooluniverse import ToolUniverse
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
from typing import List, Dict, Any

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Configuration
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "tool_files": {
        "opentarget": "opentarget_tools.json",
        "fda_drug_label": "fda_drug_labeling_tools.json",
        "special_tools": "special_tools.json",
        "monarch": "monarch_tools.json",
        "new_tool": "./data/new_tool.json",
    },
}

# Minimum length (after stripping surrounding whitespace) of a chat question.
MIN_QUESTION_LENGTH = 10

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def prepare_tool_files():
    """Generate the custom tool-list JSON (CONFIG["tool_files"]["new_tool"]) if missing.

    The parent directory is created if needed. An existing file is left untouched,
    so the list is generated at most once.
    """
    path = CONFIG["tool_files"]["new_tool"]
    # Derive the directory from the configured path instead of hard-coding "./data".
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if os.path.exists(path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    if hasattr(tu, 'get_all_tools'):
        tools = tu.get_all_tools()
    else:
        # An empty list is still written, and because the file then exists it will
        # not be regenerated on later runs; make that visible in the log.
        logger.warning("ToolUniverse has no get_all_tools(); writing an empty tool list")
        tools = []
    with open(path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    logger.info(f"Saved {len(tools)} tools to {path}")


def safe_load_embeddings(filepath: str) -> Any:
    """Load the tool-embedding file, preferring torch's restricted unpickler.

    Args:
        filepath: Path to a file written by ``torch.save``.

    Returns:
        The deserialized object, or ``None`` if both load attempts fail.
    """
    try:
        # Secure mode: only tensors/primitive containers may be unpickled.
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")

    try:
        # SECURITY: weights_only=False uses full pickle and can execute arbitrary
        # code embedded in the file. This is only acceptable because the path comes
        # from CONFIG (a local, trusted artifact); never pass user-supplied paths here.
        # (An allow-list via torch.serialization.safe_globals would have no effect
        # here: it only applies to weights_only=True loads.)
        return torch.load(filepath, weights_only=False)
    except Exception as e:
        logger.error(f"Failed to load embeddings with weights_only=False: {str(e)}")
        return None


class TxAgentWrapper:
    """Lazily-initialized wrapper around the TxAgent model and its ToolUniverse."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Holds whatever safe_load_embeddings() returned for the embedding file
        # (not a model object despite the name); None when not loaded.
        self.rag_model = None
        self.tooluniverse = None
        self.is_initialized = False
        self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']

    def initialize(self) -> str:
        """Load tools, tokenizer, model and (if present) the tool embeddings.

        Returns:
            A status string for the UI. ``is_initialized`` is only set on full
            success, so a failed call may be retried.
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            logger.info("Loading models from Hugging Face Hub...")

            # Initialize ToolUniverse first
            self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
            if not hasattr(self.tooluniverse, 'load_tools'):
                logger.error("ToolUniverse doesn't have load_tools method")
                return "❌ Failed to load tools"
            self.tooluniverse.load_tools()
            logger.info(f"Loaded {len(self.tooluniverse.tools)} tools")

            # Load main model
            self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
            self.model = AutoModelForCausalLM.from_pretrained(
                CONFIG["model_name"],
                device_map="auto",
                torch_dtype=torch.float16,
            )

            # Embeddings are optional: a missing file is tolerated, an unloadable one is not.
            if os.path.exists(CONFIG["embedding_filename"]):
                self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
                if self.rag_model is None:
                    return "❌ Failed to load embeddings"
            else:
                logger.warning(
                    f"Embedding file not found, continuing without it: {CONFIG['embedding_filename']}"
                )

            self.is_initialized = True
            return "✅ Model initialized successfully"

        except Exception as e:
            logger.exception("Initialization failed")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Generate one assistant reply and return the extended chat history.

        Args:
            message: The user's new question.
            history: Gradio "tuples"-style history, a list of [user, assistant]
                pairs; may be empty or None.

        Returns:
            A new list; the input history is not mutated. Warnings and errors are
            appended as ``["", notice]`` pairs, which ``_format_history`` skips so
            they never reach the model.
        """
        history = history or []
        if not self.is_initialized:
            return history + [["", "⚠️ Please initialize the model first"]]

        try:
            if len((message or "").strip()) < MIN_QUESTION_LENGTH:
                return history + [["", f"Please provide a more detailed question (at least {MIN_QUESTION_LENGTH} characters)"]]

            # NOTE: the tool list from _prepare_tools_prompt() is deliberately not
            # injected: this wrapper performs a single generate() call and never
            # executes tool calls (the reply is cut at [TOOL_CALLS] below).
            conversation = [
                {"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."},
                *self._format_history(history),
                {"role": "user", "content": message},
            ]

            # Generate response
            input_ids = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.model.device)
            outputs = self.model.generate(
                input_ids,
                # One unpadded sequence: every position is real input.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=1024,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
            # Drop any tool-call payload; it cannot be executed here.
            response = response.split("[TOOL_CALLS]")[0].strip()
            return history + [[message, response]]

        except Exception as e:
            logger.exception("Chat error")
            return history + [["", f"Error: {str(e)}"]]

    def _prepare_tools_prompt(self, message: str) -> str:
        """Build a plain-text listing of the loaded tools for use in a prompt.

        Args:
            message: Currently unused; kept for interface compatibility.

        Returns:
            The tools section (regular tools, then the special tools), or "" when
            no ToolUniverse tools are available.
        """
        if not hasattr(self.tooluniverse, 'tools'):
            return ""

        special = set(self.special_tools)
        parts = ["\n\nYou have access to the following tools:\n"]
        for tool in self.tooluniverse.tools:
            name = tool.get('name')
            if name is None or name in special:
                continue
            parts.append(f"- {name}: {tool.get('description', '')}\n")

        # Add special tools
        parts.append("\nSpecial tools:\n")
        parts.append("- Finish: Use when you have the final answer\n")
        parts.append("- Tool_RAG: Search for additional tools when needed\n")
        return "".join(parts)

    def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
        """Convert [user, assistant] pairs into chat-template messages.

        Pairs with an empty user side are UI notices produced by ``chat`` (warnings
        and errors), not real exchanges, so they are omitted from the prompt.
        """
        formatted = []
        for user_msg, bot_msg in history:
            if not user_msg:
                continue
            formatted.append({"role": "user", "content": user_msg})
            if bot_msg:
                formatted.append({"role": "assistant", "content": bot_msg})
        return formatted


def create_interface() -> gr.Blocks:
    """Build the Gradio UI, wired to a single shared TxAgentWrapper instance."""
    agent = TxAgentWrapper()

    with gr.Blocks(
        title="TxAgent",
        css="""
        .gradio-container {max-width: 900px !important}
        """,
    ) as demo:
        # Kept unindented so Markdown does not render it as a code block.
        gr.Markdown(
            "# 🧠 TxAgent: Therapeutic Reasoning AI\n"
            "### (Loading from Hugging Face Hub)"
        )

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(
            height=500,
            label="Conversation",
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?",
            ],
            inputs=msg,
        )

        def wrapper_initialize():
            status = agent.initialize()
            # Only lock the button once initialization succeeded, so failures can be retried.
            return status, gr.update(interactive=not agent.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn],
        )

        msg.submit(
            fn=agent.chat,
            inputs=[msg, chatbot],
            outputs=chatbot,
        ).then(
            lambda: "",  # Clear message box
            outputs=msg,
        )

        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg],
        )

    return demo


if __name__ == "__main__":
    try:
        logger.info("Starting application...")

        # Verify embedding file exists
        if not os.path.exists(CONFIG["embedding_filename"]):
            logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
            logger.info("Please ensure the file is in the root directory")
        else:
            logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")

        # Prepare tool files
        prepare_tool_files()

        # Launch interface
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise