"""Gradio front-end for TxAgent.

Builds a TxAgent instance (trying the ToolRAG-enabled configuration first and
falling back to a RAG-less configuration if that constructor raises), loads
the model, and serves a chat UI on 0.0.0.0:7860. A CUDA GPU is required.
"""

import os
import logging

import torch
import gradio as gr
from txagent import TxAgent

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
TOOL_FILE = "data/new_tool.json"
# Forwarded to run_gradio_chat as max_total_tokens on every chat call.
MAX_TOTAL_TOKENS = 16384
# Fixed seed used both when constructing the agent and on every chat call.
SEED = 100

# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class TxAgentSystem:
    """Own a TxAgent instance and the Gradio chat UI that talks to it.

    Construction checks for a CUDA device and loads the model eagerly. Any
    failure while loading is logged and re-raised, so an instance that was
    constructed successfully always has ``is_initialized == True``.
    """

    def __init__(self):
        """Check for a GPU, log its name/VRAM, and initialize the agent.

        Raises:
            RuntimeError: if ``torch.cuda.is_available()`` is False.
            Exception: anything raised by ``_initialize_system`` (re-raised).
        """
        self.agent = None
        self.is_initialized = False
        # Example queries shown under the input box; each inner list holds
        # the value for the single `msg` input.
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"]
        ]

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")

        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

        self._initialize_system()

    def _initialize_system(self):
        """Ensure the tool file exists, construct the agent and load the model.

        The agent is first constructed with RAG enabled. If that constructor
        raises, it is constructed again with RAG disabled. On success
        ``is_initialized`` is set to True.

        Raises:
            Exception: any error from file setup, the fallback constructor or
                ``init_model``; it is logged, ``is_initialized`` is left False,
                and the exception is re-raised.
        """
        try:
            # Create the tool file's directory (derived from TOOL_FILE so the
            # two cannot drift apart).
            tool_dir = os.path.dirname(TOOL_FILE)
            if tool_dir:
                os.makedirs(tool_dir, exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                # Seed the file with an empty JSON array so the agent has a
                # syntactically valid tool file to read.
                with open(TOOL_FILE, "w", encoding="utf-8") as f:
                    f.write("[]")

            logger.info("Initializing TxAgent...")

            # Try with RAG enabled first; fall back to a RAG-less agent if
            # constructing the RAG variant raises.
            # NOTE(review): this fallback only covers the constructor. If the
            # RAG model is actually loaded inside init_model(), a RAG load
            # failure will NOT trigger it — confirm against txagent.
            try:
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=RAG_MODEL_NAME,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=10,
                    seed=SEED,
                    enable_rag=True
                )
            except Exception as e:
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=None,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=0,
                    seed=SEED,
                    enable_rag=False
                )

            logger.info("Loading main model...")
            self.agent.init_model()

            self.is_initialized = True
            logger.info("System initialization completed successfully")

        except Exception as e:
            logger.error(f"System initialization failed: {str(e)}")
            self.is_initialized = False
            raise

    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        """Handle one chat submission from the UI.

        Args:
            message: Text typed by the user.
            history: Chatbot transcript as a list of (user, assistant) pairs.
                May be None, e.g. after the chat has been cleared.
            temperature: Sampling temperature from the slider.
            max_tokens: Slider value forwarded as ``max_new_tokens``
                (coerced to int because Gradio sliders can yield floats).
            rag_depth: Slider value forwarded as ``max_steps``
                (coerced to int for the same reason).

        Returns:
            ``("", new_history)``. The empty string clears the input box, and
            ``new_history`` is the updated transcript. Errors are reported
            as an assistant message instead of being raised.
        """
        # Normalize so `history + [...]` works even when Gradio passes None.
        history = list(history) if history else []

        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]

        # Ignore blank submissions instead of sending an empty query to the model.
        if not message or not message.strip():
            return "", history

        try:
            # NOTE(review): these keyword arguments (including passing the same
            # history as both `history` and `conv_history`) are not verified in
            # this file. Confirm run_gradio_chat's signature, and that it
            # returns a string rather than a generator, for the installed
            # txagent version.
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=int(max_tokens),
                max_total_tokens=MAX_TOTAL_TOKENS,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=int(rag_depth),
                seed=SEED
            )
            return "", history + [(message, response)]

        except torch.cuda.OutOfMemoryError:
            # Release cached allocator blocks so the next request has a chance.
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]
        except Exception as e:
            # logger.exception keeps the traceback, which would otherwise be
            # lost because the error is turned into a chat message.
            logger.exception(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]

    def launch_ui(self):
        """Build the Gradio Blocks UI and serve it on 0.0.0.0:7860."""
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")

            # __init__ re-raises on initialization failure, so when this UI
            # is reachable the "ready" text is what is normally shown.
            gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    # step=1: these values are token/step counts and must be
                    # whole numbers (Gradio's default step can be fractional).
                    max_tokens = gr.Slider(128, 8192, value=2048, step=1, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, step=1, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")

            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries"
            )

            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot]
            )
            # Reset to an empty list rather than None so chat_fn always
            # receives a list-like history.
            clear_btn.click(lambda: [], None, chatbot, queue=False)

        demo.launch(
            server_name="0.0.0.0",
            server_port=7860
        )


if __name__ == "__main__":
    try:
        system = TxAgentSystem()
        system.launch_ui()
    except Exception as e:
        logger.critical(f"Fatal error: {str(e)}")
        raise