# test/app.py — Hugging Face Space source file (author: Ali2206)
# Commit d206f24 "Update app.py" (verified), 8.89 kB
import os
import json
import logging
import torch
from txagent import TxAgent
import gradio as gr
from huggingface_hub import snapshot_download
from tooluniverse import ToolUniverse
import time
from functools import partial
from requests.adapters import HTTPAdapter
from requests import Session
from urllib3.util.retry import Retry
# Configuration
CONFIG = dict(
    # Hugging Face Hub repo id of the main TxAgent model (downloaded and
    # passed to TxAgent as its first argument).
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    # Hub repo id of the ToolRAG retrieval model.
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Path (relative to the working directory) of the cached tool-description
    # embeddings tensor read/written by load_embeddings().
    embedding_filename="ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    # Root directory for snapshot downloads; each repo goes to <local_dir>/<repo_id>.
    local_dir="./models",
    # Tool-definition JSON files handed to TxAgent as tool_files_dict.
    tool_files=dict(new_tool="./data/new_tool.json"),
    # NOTE(review): 300 s (5 minutes); not read anywhere in the visible code.
    download_timeout=300,
    # Attempts per repo download (also the urllib3 Retry total).
    max_retries=3,
    # Base delay in seconds; the k-th retry waits retry_delay * k.
    retry_delay=10,
)

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_custom_session():
    """Create a ``requests`` session that retries transient server errors.

    The mounted ``HTTPAdapter`` (for both ``http://`` and ``https://``) uses a
    urllib3 ``Retry`` policy allowing up to ``CONFIG["max_retries"]`` retries
    with ``backoff_factor=1`` on HTTP status 500, 502, 503 and 504, and keeps
    connection pools of size 10.

    No request timeout is configured here: ``requests`` sessions have no
    default timeout, and ``CONFIG["download_timeout"]`` is not applied.

    Returns:
        The configured ``requests.Session``.
    """
    retry_policy = Retry(
        total=CONFIG["max_retries"],
        backoff_factor=1,
        status_forcelist=[500, 502, 503, 504],
    )
    adapter = HTTPAdapter(
        max_retries=retry_policy,
        pool_connections=10,
        pool_maxsize=10,
    )
    session = Session()
    for scheme in ("http://", "https://"):
        session.mount(scheme, adapter)
    return session
def prepare_tool_files():
    """Ensure the ToolUniverse tool-list JSON exists on disk.

    If the file at ``CONFIG["tool_files"]["new_tool"]`` is missing, it is
    generated from ``ToolUniverse().get_all_tools()``. An existing file is
    reused as-is and never regenerated.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]
    # Create the configured file's own parent directory instead of a
    # hard-coded "./data", so the path in CONFIG is the single source of truth.
    os.makedirs(os.path.dirname(tool_path) or ".", exist_ok=True)
    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools()
    # Write to a temporary sibling and atomically rename it into place.
    # Writing directly could leave a truncated JSON file after a crash, and the
    # existence check above would then keep reusing that broken file.
    tmp_path = f"{tool_path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(tools, f, indent=2)
    os.replace(tmp_path, tool_path)
    logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
def download_with_retry(repo_id, local_dir):
    """Download a Hugging Face Hub repository snapshot, retrying on failure.

    Makes up to ``CONFIG["max_retries"]`` attempts. After the k-th failed
    attempt (except the last) it sleeps ``CONFIG["retry_delay"] * k`` seconds,
    i.e. a linearly increasing back-off.

    Args:
        repo_id: Hub repository id, e.g. ``"org/model-name"``.
        local_dir: Directory the snapshot files are written into.

    Returns:
        True as soon as one attempt succeeds; False once all attempts failed.
        Errors are logged, never raised.
    """
    # NOTE: snapshot_download() has no ``session`` parameter. Passing the
    # custom requests session made every attempt fail with TypeError, so the
    # retries (and the whole download) could never succeed. Retrying is
    # handled by this loop instead.
    max_attempts = CONFIG["max_retries"]
    for attempt in range(1, max_attempts + 1):
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                resume_download=True,
                local_dir_use_symlinks=False,
                use_auth_token=True,
            )
            return True
        except Exception as e:
            logger.error(f"Attempt {attempt} failed for {repo_id}: {str(e)}")
            if attempt < max_attempts:
                wait_time = CONFIG["retry_delay"] * attempt
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
    return False
def download_model_files():
    """Fetch the TxAgent model and the ToolRAG model into ``CONFIG["local_dir"]``.

    Each repository is downloaded (via ``download_with_retry``) into
    ``<local_dir>/<repo_id>``, main model first.

    NOTE(review): TxAgentApp.initialize passes the repo ids, not these local
    paths, to TxAgent — confirm these local copies are actually used.

    Raises:
        RuntimeError: if a repository could not be downloaded within
            ``CONFIG["max_retries"]`` attempts.
    """
    os.makedirs(CONFIG["local_dir"], exist_ok=True)
    logger.info("Downloading model files...")
    for repo_id in (CONFIG["model_name"], CONFIG["rag_model_name"]):
        logger.info(f"Downloading {repo_id}...")
        target_dir = os.path.join(CONFIG["local_dir"], repo_id)
        if not download_with_retry(repo_id, target_dir):
            raise RuntimeError(
                f"Failed to download {repo_id} after {CONFIG['max_retries']} attempts"
            )
    logger.info("All model files downloaded successfully")
def _try_load_cached_embeddings(agent, path):
    """Load the embeddings tensor at ``path`` onto ``agent.rag_model``.

    Returns True if loading and attaching both succeeded. Otherwise it logs the
    error and returns False, so the caller can regenerate the embeddings.
    """
    logger.info("✅ Loading pre-generated embeddings file")
    try:
        agent.rag_model.tool_desc_embedding = torch.load(path)
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        return False
    return True


def load_embeddings(agent):
    """Attach tool-description embeddings to ``agent.rag_model``.

    Uses the cached file named by ``CONFIG["embedding_filename"]`` when it
    exists and loads cleanly. Otherwise it embeds the ``"description"`` of every
    tool from ``agent.tooluniverse.get_all_tools()`` via
    ``agent.rag_model.generate_embeddings`` and saves the result to that same
    file, overwriting any unreadable cache.

    Raises:
        Exception: whatever generating or saving the embeddings raised; it is
            logged first, then re-raised.
    """
    embedding_path = CONFIG["embedding_filename"]
    if os.path.exists(embedding_path) and _try_load_cached_embeddings(agent, embedding_path):
        return

    logger.info("Generating tool embeddings...")
    try:
        descriptions = [t["description"] for t in agent.tooluniverse.get_all_tools()]
        vectors = agent.rag_model.generate_embeddings(descriptions)
        torch.save(vectors, embedding_path)
        agent.rag_model.tool_desc_embedding = vectors
        logger.info(f"Embeddings saved to {embedding_path}")
    except Exception as e:
        logger.error(f"Failed to generate embeddings: {e}")
        raise
class TxAgentApp:
    """Hold a lazily created TxAgent and the Gradio callbacks that use it.

    The agent is not built in ``__init__``. The UI triggers :meth:`initialize`
    explicitly, and :meth:`chat` refuses to run until that has succeeded.
    """

    def __init__(self):
        # TxAgent instance; None until initialize() builds one.
        self.agent = None
        # True only after initialize() finished every step without error.
        self.is_initialized = False

    def initialize(self):
        """Build the TxAgent, initialise its models and attach tool embeddings.

        Returns:
            A status string for the UI: ``"Already initialized"`` on repeat
            calls, a success message, or ``"❌ Initialization failed: <error>"``.
            Errors are reported in the return value, never raised.
        """
        if self.is_initialized:
            return "Already initialized"
        try:
            logger.info("Initializing TxAgent...")
            # NOTE(review): these are Hub repo ids, not the ./models download
            # paths — confirm TxAgent is meant to resolve them itself.
            self.agent = TxAgent(
                CONFIG["model_name"],
                CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )
            logger.info("Initializing models...")
            self.agent.init_model()
            logger.info("Loading embeddings...")
            load_embeddings(self.agent)
            self.is_initialized = True
            logger.info("✅ TxAgent initialized successfully")
            return "✅ TxAgent initialized successfully"
        except Exception as e:
            # logger.exception keeps the traceback that a plain error log loses.
            logger.exception(f"Initialization failed: {str(e)}")
            # Drop the half-initialised agent so a retry does not keep it (and
            # anything it loaded) referenced while a new one is being built.
            self.agent = None
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message, history):
        """Send ``message`` to the agent and return the updated chat history.

        Args:
            message: The user's question from the textbox.
            history: Current Chatbot value as a list of (user, assistant)
                pairs; ``None`` is treated as an empty history.

        Returns:
            ``history`` plus one ``(message, reply)`` pair as a new list.
            A blank message returns ``history`` unchanged. Failures, including
            "not initialized", are reported as the reply text, never raised.
        """
        if history is None:
            history = []
        if not message or (isinstance(message, str) and not message.strip()):
            # Nothing to ask; don't send an empty prompt to the model.
            return history
        if not self.is_initialized:
            return history + [(message, "⚠️ Error: Model not initialized. Please click 'Initialize Model' first.")]
        try:
            # NOTE(review): assumes run_gradio_chat yields str fragments and
            # accepts these keyword names — confirm against installed txagent.
            response = "".join(
                self.agent.run_gradio_chat(
                    message=message,
                    history=history,
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_tokens=8192,
                    multi_agent=False,
                    conversation=[],
                    max_round=30,
                )
            )
            return history + [(message, response)]
        except Exception as e:
            logger.exception(f"Chat error: {str(e)}")
            return history + [(message, f"Error: {str(e)}")]
def create_interface():
    """Build the Gradio Blocks UI for TxAgent.

    The "Initialize Model" button is wired to ``TxAgentApp.initialize``. Both
    the textbox's Enter key and the Submit button are wired to
    ``TxAgentApp.chat``, after which the question textbox is cleared.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    app = TxAgentApp()

    def _clear_textbox():
        # Return value becomes the new content of the question textbox.
        return ""

    with gr.Blocks(title="TxAgent", css=".gradio-container {max-width: 900px !important}") as demo:
        gr.Markdown("""
# 🧠 TxAgent: Therapeutic Reasoning AI
### A specialized AI for clinical decision support and therapeutic reasoning
""")
        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", interactive=False)
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600, label="Conversation", bubble_full_width=False)
                msg = gr.Textbox(label="Your Question", placeholder="Enter your clinical question here...")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### Example Questions:")
                gr.Examples(
                    examples=[
                        "How to adjust Journavx dosage for hepatic impairment?",
                        "Is Xolremdi safe with Prozac for WHIM syndrome?",
                        "Warfarin-Amiodarone contraindications?",
                        "Alternative treatments for EGFR-positive NSCLC?",
                    ],
                    inputs=msg,
                    label="Click to try",
                )

        init_btn.click(
            fn=app.initialize,
            outputs=init_status,
            api_name="initialize",
        )
        # After each answer, clear the question box. Otherwise the submitted
        # text stays there and pressing Enter again re-sends the same question.
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot,
            api_name="chat",
        ).then(fn=_clear_textbox, inputs=None, outputs=msg)
        submit_btn.click(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot,
        ).then(fn=_clear_textbox, inputs=None, outputs=msg)
    return demo
if __name__ == "__main__":
    # Start-up sequence: fetch prerequisites, then serve the UI on port 7860.
    # Any failure is logged and re-raised so the process exits non-zero.
    try:
        for step_message, step in (
            ("Preparing tool files...", prepare_tool_files),
            ("Downloading model files (if needed)...", download_model_files),
        ):
            logger.info(step_message)
            step()
        logger.info("Launching interface...")
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise