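"""Gradio Space entry point for TxAgent.

Downloads the TxAgent-T1 and ToolRAG models from the Hugging Face Hub with
retry handling, prepares the ToolUniverse tool list, caches tool-description
embeddings, and serves a chat interface for therapeutic-reasoning questions.
"""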
import json
import logging
import os
import time

import gradio as gr
import torch
from huggingface_hub import configure_http_backend, snapshot_download
from requests import Session
from requests.adapters import HTTPAdapter
from tooluniverse import ToolUniverse
from tqdm import tqdm
from txagent import TxAgent
from urllib3.util.retry import Retry
# Configuration
CONFIG = {
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
"local_dir": "./models",
"tool_files": {
"new_tool": "./data/new_tool.json"
},
"download_settings": {
"timeout": 600, # 10 minutes per request
"max_retries": 5,
"retry_delay": 30, # seconds between retries
"chunk_size": 1024 * 1024 * 10, # 10MB chunks
"max_concurrent": 2 # concurrent downloads
}
}
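# Of these settings, "max_retries", "retry_delay", and "max_concurrent" are
# consumed below; "timeout" and "chunk_size" are kept for reference only,
# since huggingface_hub manages its own request timeouts and chunking.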
# Logging setup
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def create_optimized_session():
"""Create a session optimized for large file downloads"""
session = Session()
retry_strategy = Retry(
total=CONFIG["download_settings"]["max_retries"],
backoff_factor=1,
status_forcelist=[408, 429, 500, 502, 503, 504]
)
adapter = HTTPAdapter(
max_retries=retry_strategy,
pool_connections=10,
pool_maxsize=10,
pool_block=True
)
session.mount("https://", adapter)
session.mount("http://", adapter)
return session
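# snapshot_download() does not accept a requests.Session argument, so the
# retrying session is registered globally instead. configure_http_backend()
# is available in huggingface_hub >= 0.15 (an assumption about the pinned
# version); on older releases, drop this call to fall back to the default
# session.
configure_http_backend(backend_factory=create_optimized_session)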
def prepare_tool_files():
    """Generate the new_tool.json tool list via ToolUniverse if it is missing."""
    os.makedirs("./data", exist_ok=True)
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
logger.info("Generating tool list using ToolUniverse...")
tu = ToolUniverse()
tools = tu.get_all_tools()
with open(CONFIG["tool_files"]["new_tool"], "w") as f:
json.dump(tools, f, indent=2)
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
def download_model_with_progress(repo_id, local_dir):
    """Download a full model repository, retrying with a linear backoff."""
    settings = CONFIG["download_settings"]
    for attempt in range(settings["max_retries"] + 1):
        try:
            logger.info(f"Download attempt {attempt + 1} for {repo_id}")
            # snapshot_download() resumes interrupted downloads automatically
            # in current huggingface_hub releases and renders its own tqdm
            # progress bars, so no manual progress callback is needed.
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                token=True,  # use the locally cached Hub token (gated repos)
                max_workers=settings["max_concurrent"],
            )
            return True
        except Exception as e:
            logger.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt < settings["max_retries"]:
                wait_time = settings["retry_delay"] * (attempt + 1)
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
    return False
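# Two retry layers cooperate here: urllib3's Retry in the shared HTTP session
# retries individual requests on 408/429/5xx responses, while the attempt
# loop above restarts the whole snapshot with a linearly growing delay
# (retry_delay * attempt number) if a download fails outright.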
def download_model_files():
os.makedirs(CONFIG["local_dir"], exist_ok=True)
logger.info("Starting model downloads...")
# Download main model
if not download_model_with_progress(
CONFIG["model_name"],
os.path.join(CONFIG["local_dir"], CONFIG["model_name"])
):
raise RuntimeError(f"Failed to download {CONFIG['model_name']}")
# Download RAG model
if not download_model_with_progress(
CONFIG["rag_model_name"],
os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"])
):
raise RuntimeError(f"Failed to download {CONFIG['rag_model_name']}")
logger.info("All model files downloaded successfully")
def load_embeddings(agent):
    """Attach tool-description embeddings to the RAG model, from cache if possible."""
    embedding_path = CONFIG["embedding_filename"]
if os.path.exists(embedding_path):
logger.info("✅ Loading pre-generated embeddings file")
try:
embeddings = torch.load(embedding_path)
agent.rag_model.tool_desc_embedding = embeddings
return
except Exception as e:
logger.error(f"Failed to load embeddings: {e}")
logger.info("Generating tool embeddings...")
try:
tools = agent.tooluniverse.get_all_tools()
descriptions = [tool["description"] for tool in tools]
embeddings = agent.rag_model.generate_embeddings(descriptions)
torch.save(embeddings, embedding_path)
agent.rag_model.tool_desc_embedding = embeddings
logger.info(f"Embeddings saved to {embedding_path}")
except Exception as e:
logger.error(f"Failed to generate embeddings: {e}")
raise
class TxAgentApp:
    """Lazily initialized wrapper around a TxAgent instance for the Gradio UI."""
    def __init__(self):
self.agent = None
self.is_initialized = False
def initialize(self):
if self.is_initialized:
return "✅ Already initialized"
try:
# Initialize with progress tracking
with tqdm(total=4, desc="Initializing TxAgent") as pbar:
logger.info("Creating TxAgent instance...")
self.agent = TxAgent(
CONFIG["model_name"],
CONFIG["rag_model_name"],
tool_files_dict=CONFIG["tool_files"],
force_finish=True,
enable_checker=True,
step_rag_num=10,
seed=100,
additional_default_tools=["DirectResponse", "RequireClarification"]
)
pbar.update(1)
logger.info("Initializing models...")
self.agent.init_model()
pbar.update(1)
logger.info("Loading embeddings...")
load_embeddings(self.agent)
pbar.update(1)
self.is_initialized = True
pbar.update(1)
return "✅ TxAgent initialized successfully"
except Exception as e:
logger.error(f"Initialization failed: {str(e)}")
return f"❌ Initialization failed: {str(e)}"
    def chat(self, message, history):
        if not self.is_initialized:
            # chat() is a generator (it yields below), so this early warning
            # must be yielded rather than returned for Gradio to render it.
            yield history + [(message, "⚠️ Please initialize the model first")]
            return
try:
response = ""
for chunk in self.agent.run_gradio_chat(
message=message,
history=history,
temperature=0.3,
max_new_tokens=1024,
max_tokens=8192,
multi_agent=False,
conversation=[],
max_round=30
):
response += chunk
yield history + [(message, response)]
except Exception as e:
logger.error(f"Chat error: {str(e)}")
yield history + [(message, f"Error: {str(e)}")]
def create_interface():
app = TxAgentApp()
with gr.Blocks(
title="TxAgent",
css="""
.gradio-container {max-width: 900px !important}
.progress-bar {height: 20px !important}
"""
) as demo:
gr.Markdown("""
        # TxAgent: Therapeutic Reasoning AI
### Specialized for clinical decision support
""")
# Initialization section
with gr.Row():
init_btn = gr.Button("Initialize Model", variant="primary")
init_status = gr.Textbox(label="Status", interactive=False)
        download_progress = gr.Textbox(visible=False)  # reserved for download status; currently unused
# Chat interface
chatbot = gr.Chatbot(height=500, label="Conversation")
msg = gr.Textbox(label="Your clinical question", placeholder="Ask about drug interactions, dosing, etc...")
clear_btn = gr.Button("Clear Chat")
# Examples
gr.Examples(
examples=[
"How to adjust Journavx for renal impairment?",
"Xolremdi and Prozac interaction in WHIM syndrome?",
"Alternative to Warfarin for patient with amiodarone?"
],
inputs=msg,
label="Example Questions"
)
# Event handlers
init_btn.click(
fn=app.initialize,
outputs=init_status
)
msg.submit(
fn=app.chat,
inputs=[msg, chatbot],
outputs=chatbot
)
clear_btn.click(
fn=lambda: ([], ""),
outputs=[chatbot, msg]
)
return demo
if __name__ == "__main__":
try:
logger.info("Starting application setup...")
# Prepare files
prepare_tool_files()
# Download models with progress tracking
download_model_files()
        # Launch interface; queue() is required for streaming generator
        # handlers in Gradio 3.x and is harmless on Gradio 4+.
        interface = create_interface()
        interface.queue()
        interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)
except Exception as e:
logger.error(f"Fatal error: {str(e)}")
raise