"""Gradio front-end for TxAgent.

When run as a script this module:

1. generates the ToolUniverse tool list file if it is missing,
2. downloads the TxAgent and ToolRAG model snapshots into ``CONFIG["local_dir"]``,
3. launches a Gradio UI whose "Initialize Model" button builds the agent.
"""

import os
import json
import logging
import torch
from txagent import TxAgent
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from tooluniverse import ToolUniverse
from tqdm import tqdm
import time

# Configuration
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    "local_dir": "./models",
    "tool_files": {
        "new_tool": "./data/new_tool.json"
    },
    # Seconds; forwarded to snapshot_download as its per-request
    # metadata (etag) timeout.
    "download_timeout": 300,
    # Total download attempts per repository.
    "max_retries": 3
}

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def prepare_tool_files():
    """Create the ToolUniverse tool-list JSON file if it does not exist yet.

    The file at ``CONFIG["tool_files"]["new_tool"]`` receives every tool
    returned by ``ToolUniverse().get_all_tools()``. An existing file is left
    untouched.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]
    # Create whatever directory the configured path lives in (./data by default).
    os.makedirs(os.path.dirname(tool_path) or ".", exist_ok=True)
    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tools = ToolUniverse().get_all_tools()

    # Write to a temp file and atomically move it into place.
    # A crash mid-dump therefore cannot leave a truncated file that later
    # runs would silently reuse (the existence check above skips regeneration).
    tmp_path = tool_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    os.replace(tmp_path, tool_path)
    logger.info(f"Saved {len(tools)} tools to {tool_path}")


def download_with_retry(repo_id, local_dir):
    """Download a Hugging Face Hub snapshot, retrying with linear back-off.

    Args:
        repo_id: Hub repository id, e.g. ``"mims-harvard/TxAgent-T1-Llama-3.1-8B"``.
        local_dir: Directory the snapshot files are written into.

    Returns:
        ``True`` as soon as one attempt succeeds, ``False`` after
        ``CONFIG["max_retries"]`` failed attempts. Between attempts it sleeps
        10s, 20s, ... (10 seconds times the attempt number).
    """
    max_retries = CONFIG["max_retries"]
    for attempt in range(1, max_retries + 1):
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                resume_download=True,
                local_dir_use_symlinks=False,
                # snapshot_download has no ``timeout`` parameter. Passing one
                # raised TypeError on every attempt, so downloads could never
                # succeed. ``etag_timeout`` is the supported timeout knob.
                etag_timeout=CONFIG["download_timeout"],
            )
            return True
        except Exception as e:
            logger.error(f"Attempt {attempt} failed for {repo_id}: {str(e)}")
            if attempt < max_retries:
                wait_time = 10 * attempt
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
    return False


def download_model_files():
    """Download the TxAgent model and the ToolRAG model into ``CONFIG["local_dir"]``.

    Each repository is stored under ``<local_dir>/<repo_id>``.

    Raises:
        RuntimeError: if a repository still fails after
            ``CONFIG["max_retries"]`` attempts.
    """
    # NOTE(review): TxAgentApp.initialize constructs TxAgent with the Hub repo
    # ids, not these local directories. Confirm that TxAgent actually picks up
    # these copies; otherwise the models may be fetched a second time.
    os.makedirs(CONFIG["local_dir"], exist_ok=True)
    logger.info("Downloading model files...")

    for repo_id in (CONFIG["model_name"], CONFIG["rag_model_name"]):
        logger.info(f"Downloading {repo_id}...")
        target_dir = os.path.join(CONFIG["local_dir"], repo_id)
        if not download_with_retry(repo_id, target_dir):
            raise RuntimeError(
                f"Failed to download {repo_id} after {CONFIG['max_retries']} attempts"
            )

    logger.info("All model files downloaded successfully")


def load_embeddings(agent):
    """Attach tool-description embeddings to ``agent.rag_model``.

    Loads ``CONFIG["embedding_filename"]`` when it exists and is readable.
    Otherwise it embeds the ``"description"`` of every tool reported by
    ``agent.tooluniverse`` and caches the result to that file.

    Raises:
        Exception: whatever embedding generation or saving raised (re-raised
            after logging).
    """
    embedding_path = CONFIG["embedding_filename"]
    if os.path.exists(embedding_path):
        logger.info("✅ Loading pre-generated embeddings file")
        try:
            agent.rag_model.tool_desc_embedding = torch.load(embedding_path)
            return
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            # Fall through and regenerate the embeddings instead.

    logger.info("Generating tool embeddings...")
    try:
        tools = agent.tooluniverse.get_all_tools()
        descriptions = [tool["description"] for tool in tools]
        # NOTE(review): assumes the installed txagent's RAG model exposes
        # generate_embeddings(list_of_str); confirm against that version.
        embeddings = agent.rag_model.generate_embeddings(descriptions)
        torch.save(embeddings, embedding_path)
        agent.rag_model.tool_desc_embedding = embeddings
        logger.info(f"Embeddings saved to {embedding_path}")
    except Exception as e:
        logger.exception(f"Failed to generate embeddings: {e}")
        raise


class TxAgentApp:
    """Holds a lazily-initialized TxAgent and serves the Gradio callbacks."""

    def __init__(self):
        self.agent = None  # TxAgent instance once initialize() succeeds
        self.is_initialized = False

    def initialize(self):
        """Build the agent, load its models and embeddings.

        Returns:
            A status string for the UI. Failures are reported in the string
            rather than raised.
        """
        if self.is_initialized:
            return "Already initialized"

        try:
            logger.info("Initializing TxAgent...")
            self.agent = TxAgent(
                CONFIG["model_name"],
                CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )

            logger.info("Initializing models...")
            self.agent.init_model()

            logger.info("Loading embeddings...")
            load_embeddings(self.agent)

            self.is_initialized = True
            logger.info("✅ TxAgent initialized successfully")
            return "✅ TxAgent initialized successfully"
        except Exception as e:
            # Drop the half-built agent so a retry starts from a clean state.
            self.agent = None
            logger.exception(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message, history):
        """Run one chat turn and return the updated Chatbot history.

        Args:
            message: The user's question from the textbox.
            history: Current Chatbot value, a list of ``(user, bot)`` pairs
                (``None`` is treated as an empty history).

        Returns:
            The history with a ``(message, reply)`` pair appended. Errors
            become the reply text instead of being raised. An empty message
            returns the history unchanged.
        """
        history = history or []
        if not self.is_initialized:
            return history + [(message, "⚠️ Error: Model not initialized. Please click 'Initialize Model' first.")]
        if not message or not message.strip():
            # Nothing was asked; leave the conversation as it is.
            return history

        try:
            # NOTE(review): assumes run_gradio_chat is a generator of text
            # chunks and accepts these keyword names (max_tokens, multi_agent).
            # Confirm against the installed txagent version.
            chunks = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            )
            response = "".join(chunks)
            return history + [(message, response)]
        except Exception as e:
            logger.exception(f"Chat error: {str(e)}")
            return history + [(message, f"Error: {str(e)}")]


def create_interface():
    """Build and return the Gradio Blocks UI wired to a fresh TxAgentApp."""
    app = TxAgentApp()

    with gr.Blocks(title="TxAgent", css=".gradio-container {max-width: 900px !important}") as demo:
        gr.Markdown("""
        # 🧠 TxAgent: Therapeutic Reasoning AI
        ### A specialized AI for clinical decision support and therapeutic reasoning
        """)

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", interactive=False)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600, label="Conversation", bubble_full_width=False)
                msg = gr.Textbox(label="Your Question", placeholder="Enter your clinical question here...")
                submit_btn = gr.Button("Submit", variant="primary")

            with gr.Column(scale=1):
                gr.Markdown("### Example Questions:")
                gr.Examples(
                    examples=[
                        "How to adjust Journavx dosage for hepatic impairment?",
                        "Is Xolremdi safe with Prozac for WHIM syndrome?",
                        "Warfarin-Amiodarone contraindications?",
                        "Alternative treatments for EGFR-positive NSCLC?"
                    ],
                    inputs=msg,
                    label="Click to try"
                )

        init_btn.click(
            fn=app.initialize,
            outputs=init_status,
            api_name="initialize"
        )

        # Enter in the textbox and the Submit button run the same handler;
        # only the former is exposed as the "chat" API endpoint.
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot,
            api_name="chat"
        )

        submit_btn.click(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        )

    return demo


if __name__ == "__main__":
    try:
        logger.info("Preparing tool files...")
        prepare_tool_files()

        logger.info("Downloading model files (if needed)...")
        download_model_files()

        logger.info("Launching interface...")
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True
        )
    except Exception as e:
        logger.exception(f"Application failed to start: {str(e)}")
        raise