"""Gradio front-end for TxAgent, a therapeutic-reasoning chat model with tools.

The UI lets the user initialize the model on demand ("Initialize Model"
button) and then chat with it. Tool definitions come from ToolUniverse.
"""

import json
import logging
import os
import warnings
from importlib.resources import files
from typing import Any, Dict, List

import gradio as gr
import torch
from tooluniverse import ToolUniverse
from transformers import AutoModelForCausalLM, AutoTokenizer

# Broad filter: silence every UserWarning raised after this point.
warnings.filterwarnings("ignore", category=UserWarning)

# Configuration
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Expected in the working directory; checked at startup in __main__.
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "tool_files": {
        # Tool definition files shipped inside the tooluniverse package.
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        # Generated locally by prepare_tool_files() if missing.
        "new_tool": "./data/new_tool.json",
    },
}

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def prepare_tool_files():
    """Ensure the locally generated tool list file exists.

    If ``CONFIG["tool_files"]["new_tool"]`` does not exist, a default
    ``ToolUniverse`` is asked for its tools (only if it exposes
    ``get_all_tools``; otherwise an empty list is used) and the result is
    written there as JSON. An existing file is left untouched.
    """
    new_tool_path = CONFIG["tool_files"]["new_tool"]
    # Create the directory the configured path actually lives in, so the
    # directory and the file path cannot drift apart.
    os.makedirs(os.path.dirname(new_tool_path) or ".", exist_ok=True)
    if os.path.exists(new_tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
    with open(new_tool_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    logger.info(f"Saved {len(tools)} tools to {new_tool_path}")


def safe_load_embeddings(filepath: str, map_location: Any = None) -> Any:
    """Load a torch-serialized embedding file, preferring the secure loader.

    First tries ``torch.load(..., weights_only=True)``, which refuses to
    unpickle arbitrary objects. If that fails, falls back to a full
    ``weights_only=False`` load.

    SECURITY: the fallback unpickles arbitrary Python objects and can execute
    code embedded in the file. Only call this on files you trust (here, a
    locally provisioned embedding file).

    Args:
        filepath: Path to the ``.pt`` file.
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
            tensors saved on a GPU onto a CPU-only host). ``None`` keeps
            torch's default placement.

    Returns:
        The deserialized object, or ``None`` if both attempts fail.
    """
    try:
        return torch.load(filepath, map_location=map_location, weights_only=True)
    except Exception as secure_err:
        logger.warning(f"Secure load failed, trying with weights_only=False: {secure_err}")

    try:
        # A plain full load: torch.serialization.safe_globals only extends the
        # allow-list used by weights_only=True loads, so it has no effect here.
        return torch.load(filepath, map_location=map_location, weights_only=False)
    except Exception as full_err:
        logger.error(f"Failed to load embeddings even with weights_only=False: {full_err}")
        return None


class TxAgentWrapper:
    """Lazily-initialized wrapper around the TxAgent model, tokenizer and tools."""

    def __init__(self):
        self.model = None          # causal LM, set by initialize()
        self.tokenizer = None      # tokenizer for self.model, set by initialize()
        self.rag_model = None      # object loaded from the embedding file, if present
        self.tooluniverse = None   # ToolUniverse instance, set by initialize()
        self.is_initialized = False
        # Tool names excluded from the generic tool listing in the prompt.
        self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']

    def initialize(self) -> str:
        """Load tools, the TxAgent model/tokenizer and (if present) embeddings.

        Returns:
            A status string for the UI: "✅ ..." on success, "❌ ..." on
            failure. Exceptions are caught, logged with a traceback, and
            reported in the returned string.
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            logger.info("Loading models from Hugging Face Hub...")

            # Packaged tool files must exist. "new_tool" is generated at
            # startup by prepare_tool_files(), so it is not checked here.
            for tool_name, tool_path in CONFIG["tool_files"].items():
                if tool_name != "new_tool" and not os.path.exists(tool_path):
                    raise FileNotFoundError(f"Tool file not found: {tool_path}")

            # Initialize ToolUniverse with verified paths
            self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
            if not hasattr(self.tooluniverse, 'load_tools'):
                logger.error("ToolUniverse doesn't have load_tools method")
                return "❌ Failed to load tools"
            self.tooluniverse.load_tools()
            # `tools` is treated as optional (as in _prepare_tools_prompt), so
            # its absence must not abort initialization.
            logger.info(f"Loaded {len(getattr(self.tooluniverse, 'tools', None) or [])} tools")

            # Load main model
            self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
            self.model = AutoModelForCausalLM.from_pretrained(
                CONFIG["model_name"],
                device_map="auto",
                torch_dtype=torch.float16
            )

            # Embeddings are optional: only loaded when the file is present.
            if os.path.exists(CONFIG["embedding_filename"]):
                self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
                if self.rag_model is None:
                    return "❌ Failed to load embeddings"

            self.is_initialized = True
            return "✅ Model initialized successfully"

        except Exception as e:
            logger.exception(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Generate one assistant reply and append it to the chat history.

        Args:
            message: The user's new question.
            history: Gradio "tuples"-style history of ``[user, bot]`` pairs.
                ``None`` is treated as an empty history.

        Returns:
            A new history list (the input is not mutated) with one row
            appended: ``[message, reply]`` on success, or ``["", notice]``
            for warnings and errors.
        """
        history = list(history or [])
        if not self.is_initialized:
            return history + [["", "⚠️ Please initialize the model first"]]

        try:
            # Whitespace padding does not count, and a question of exactly
            # 10 characters is accepted, matching the notice text.
            if len((message or "").strip()) < 10:
                return history + [["", "Please provide a more detailed question (at least 10 characters)"]]

            # Prepare tools prompt
            tools_prompt = self._prepare_tools_prompt(message)

            # Format conversation
            conversation = [
                {"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning." + tools_prompt},
                *self._format_history(history),
                {"role": "user", "content": message},
            ]

            # Generate response
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            outputs = self.model.generate(
                inputs,
                # Single unpadded prompt, so every position is real. The mask
                # is passed explicitly because it cannot be inferred when
                # pad_token_id == eos_token_id.
                attention_mask=torch.ones_like(inputs),
                max_new_tokens=1024,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens (skip the echoed prompt).
            response = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
            # Keep only the text before any "[TOOL_CALLS]" marker.
            response = response.split("[TOOL_CALLS]")[0].strip()

            return history + [[message, response]]

        except Exception as e:
            logger.exception(f"Chat error: {str(e)}")
            return history + [["", f"Error: {str(e)}"]]

    def _prepare_tools_prompt(self, message: str) -> str:
        """Build the system-prompt section that lists the available tools.

        Args:
            message: The current user message (currently unused).

        Returns:
            The tools section, or ``""`` when ToolUniverse has no ``tools``.
        """
        tools = getattr(self.tooluniverse, 'tools', None)
        if tools is None:
            return ""

        parts = ["\n\nYou have access to the following tools:\n"]
        for tool in tools:
            if tool['name'] not in self.special_tools:
                parts.append(f"- {tool['name']}: {tool.get('description', '')}\n")

        # Add special tools
        parts.append("\nSpecial tools:\n")
        parts.append("- Finish: Use when you have the final answer\n")
        parts.append("- Tool_RAG: Search for additional tools when needed\n")
        return "".join(parts)

    def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
        """Convert Gradio ``[user, bot]`` pairs into chat-template messages.

        Rows with an empty user message are UI notices/errors produced by
        ``chat`` (e.g. "Please initialize the model first"). They are skipped
        so they are not replayed to the model as real conversation turns.
        """
        formatted = []
        for user_msg, bot_msg in history:
            if not user_msg:
                continue
            formatted.append({"role": "user", "content": user_msg})
            if bot_msg:
                formatted.append({"role": "assistant", "content": bot_msg})
        return formatted


def create_interface() -> gr.Blocks:
    """Create the Gradio interface wired to a fresh ``TxAgentWrapper``."""
    agent = TxAgentWrapper()

    with gr.Blocks(
        title="TxAgent",
        css="""
        .gradio-container {max-width: 900px !important}
        """
    ) as demo:
        # Explicit newlines so the '#' / '###' lines render as headings.
        gr.Markdown(
            "# 🧠 TxAgent: Therapeutic Reasoning AI\n"
            "### (Loading from Hugging Face Hub)"
        )

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(height=500, label="Conversation")
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg
        )

        def wrapper_initialize():
            """Run initialization and update the status box and button.

            The button is disabled only after a successful initialization,
            so a failed attempt can be retried.
            """
            status = agent.initialize()
            return status, gr.update(interactive=not agent.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn]
        )

        msg.submit(
            fn=agent.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        ).then(
            lambda: "",  # Clear message box
            outputs=msg
        )

        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo


if __name__ == "__main__":
    try:
        logger.info("Starting application...")

        # Verify embedding file exists (missing embeddings are not fatal).
        if not os.path.exists(CONFIG["embedding_filename"]):
            logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
            logger.info("Please ensure the file is in the root directory")
        else:
            logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")

        # Prepare tool files
        prepare_tool_files()

        # Launch interface
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise