# test / app.py
# Author: Ali2206 — commit 70839bb ("Create app.py", verified)
# Size: 5.59 kB (source page links: raw, history, blame)
import os
import logging
import torch
import gradio as gr
from txagent import TxAgent
# --- Logging ---------------------------------------------------------------
# Root handler at INFO level; each record carries time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Model / tool configuration --------------------------------------------
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"        # main model
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"  # passed as rag_model_name
TOOL_FILE = "data/new_tool.json"                           # tool definitions file

# --- Process environment ----------------------------------------------------
# Set at import time, before any model is constructed; values are applied in
# this order and overwrite anything inherited from the parent environment.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_MODULE_LOADING": "LAZY",
    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
})
class TxAgentSystem:
    """Own a TxAgent instance and expose it through a Gradio chat UI.

    Construction requires a CUDA device and eagerly builds and loads the
    agent via ``_initialize_system``; any failure there propagates to the
    caller.
    """

    def __init__(self):
        """Check for a GPU, log its properties, and initialize the agent.

        Raises:
            RuntimeError: if ``torch.cuda.is_available()`` is False.
            Exception: anything raised while building or loading the agent is
                logged and re-raised by ``_initialize_system``.
        """
        self.agent = None
        self.is_initialized = False
        # One row per example for gr.Examples; each row fills the single
        # query textbox.
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"],
        ]
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
        self._initialize_system()

    def _build_agent(self, enable_rag):
        """Construct a TxAgent with the ToolRAG retriever enabled or disabled."""
        return TxAgent(
            model_name=MODEL_NAME,
            rag_model_name=RAG_MODEL_NAME if enable_rag else None,
            tool_files_dict={"new_tool": TOOL_FILE},
            force_finish=True,
            enable_checker=True,
            step_rag_num=10 if enable_rag else 0,
            seed=100,
            enable_rag=enable_rag,
        )

    def _initialize_system(self):
        """Ensure the tool file exists, build the agent, and load its model.

        First tries to construct a RAG-enabled agent. If that constructor
        raises, it falls back to an agent with RAG disabled. Sets
        ``self.is_initialized`` to True on success.

        Raises:
            Exception: any error from tool-file setup, from the fallback
                construction, or from ``init_model()`` is logged (with
                traceback) and re-raised after clearing ``is_initialized``.
        """
        try:
            # Create TOOL_FILE's own parent directory instead of a hard-coded
            # "data" so the two can never drift apart.
            os.makedirs(os.path.dirname(TOOL_FILE) or ".", exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                # Seed an empty JSON list; this path is handed to TxAgent via
                # tool_files_dict.
                with open(TOOL_FILE, "w", encoding="utf-8") as f:
                    f.write("[]")
            logger.info("Initializing TxAgent...")
            # Try with RAG enabled first; fall back to RAG disabled only if
            # construction fails.
            # NOTE(review): init_model() is outside this fallback, so a RAG
            # failure raised during model loading is not retried — confirm
            # where TxAgent actually loads the RAG model.
            try:
                self.agent = self._build_agent(enable_rag=True)
            except Exception as e:
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = self._build_agent(enable_rag=False)
            logger.info("Loading main model...")
            self.agent.init_model()
            self.is_initialized = True
            logger.info("System initialization completed successfully")
        except Exception as e:
            # logger.exception keeps the traceback in the log before re-raising.
            logger.exception(f"System initialization failed: {str(e)}")
            self.is_initialized = False
            raise

    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        """Run one chat turn and return ``("", updated_history)`` for Gradio.

        Args:
            message: the user's query text.
            history: list of ``(user, assistant)`` pairs from the Chatbot;
                ``None`` (the value after the Clear button) is treated as empty.
            temperature: sampling temperature from the UI slider.
            max_tokens: cap on new tokens for the reply; coerced to ``int``.
            rag_depth: forwarded as ``max_steps``; coerced to ``int``.

        Returns:
            An empty string (clears the input box) and the history with this
            turn appended. Failures are reported as the assistant's reply
            rather than raised.
        """
        # The Clear button sets the Chatbot value to None; normalize so the
        # list concatenations below cannot raise TypeError.
        if history is None:
            history = []
        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]
        try:
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                # Slider values may arrive as floats; token/step counts must
                # be integers.
                max_new_tokens=int(max_tokens),
                max_total_tokens=16384,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=int(rag_depth),
                seed=100,
            )
            return "", history + [(message, response)]
        except torch.cuda.OutOfMemoryError:
            # Release cached allocator blocks so a later request may fit.
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]
        except Exception as e:
            logger.exception(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]

    def launch_ui(self):
        """Build the Gradio Blocks UI and launch it on 0.0.0.0:7860."""
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")
            gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False,
            )
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    # step=1: without it Gradio derives a (possibly fractional)
                    # step from the range, but these are integer counts.
                    max_tokens = gr.Slider(128, 8192, value=2048, step=1, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, step=1, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")
            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries",
            )
            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot],
            )
            # Clearing sets the Chatbot to None; chat_fn treats that as empty.
            clear_btn.click(lambda: None, None, chatbot, queue=False)
        # Launch after the Blocks context has closed so the layout is finalized.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
        )
if __name__ == "__main__":
    # Script entry point: build the system (which loads the model), then
    # serve the UI.
    try:
        system = TxAgentSystem()
        system.launch_ui()
    except Exception as exc:
        # Record the failure at CRITICAL, then re-raise so the traceback is
        # still printed and the process exits with an error.
        logger.critical(f"Fatal error: {exc}")
        raise