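"""Gradio app for TxAgent: loads the TxAgent and ToolRAG models from the
Hugging Face Hub and serves a therapeutic-reasoning chat interface."""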
import os
import json
import logging
import torch
import gradio as gr
from tooluniverse import ToolUniverse
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
from typing import List, Dict, Any
from importlib.resources import files
# Suppress all UserWarnings to keep startup logs readable
warnings.filterwarnings("ignore", category=UserWarning)
# Configuration
CONFIG = {
"model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
"rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
"embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
"tool_files": {
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
"special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
"monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
"new_tool": "./data/new_tool.json"
}
}
# Logging setup
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def prepare_tool_files():
"""Ensure tool files exist and are populated"""
os.makedirs("./data", exist_ok=True)
if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
logger.info("Generating tool list using ToolUniverse...")
tu = ToolUniverse()
tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
with open(CONFIG["tool_files"]["new_tool"], "w") as f:
json.dump(tools, f, indent=2)
logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
def safe_load_embeddings(filepath: str) -> Any:
    """Load embeddings, preferring torch's secure weights_only mode."""
    try:
        # weights_only=True refuses to unpickle arbitrary objects (secure mode)
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, retrying with weights_only=False: {str(e)}")
        try:
            # Full unpickling can execute arbitrary code; acceptable here only
            # because the embedding file ships with this Space and is trusted
            return torch.load(filepath, weights_only=False)
        except Exception as e:
            logger.error(f"Failed to load embeddings: {str(e)}")
            return None
class TxAgentWrapper:
def __init__(self):
self.model = None
self.tokenizer = None
self.rag_model = None
self.tooluniverse = None
self.is_initialized = False
self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']
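        # Meta-tools: excluded from the generated tool listing and described
        # separately in _prepare_tools_prompt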
def initialize(self) -> str:
"""Initialize the model from Hugging Face"""
if self.is_initialized:
return "βœ… Already initialized"
try:
logger.info("Loading models from Hugging Face Hub...")
# Verify tool files exist
for tool_name, tool_path in CONFIG["tool_files"].items():
if tool_name != "new_tool" and not os.path.exists(tool_path):
raise FileNotFoundError(f"Tool file not found: {tool_path}")
# Initialize ToolUniverse with verified paths
self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
if hasattr(self.tooluniverse, 'load_tools'):
self.tooluniverse.load_tools()
logger.info(f"Loaded {len(self.tooluniverse.tools)} tools")
else:
logger.error("ToolUniverse doesn't have load_tools method")
return "❌ Failed to load tools"
# Load main model
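            # fp16 weights halve memory versus fp32; device_map="auto" lets
            # accelerate place layers across the available GPU/CPU devices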
self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
self.model = AutoModelForCausalLM.from_pretrained(
CONFIG["model_name"],
device_map="auto",
torch_dtype=torch.float16
)
# Load embeddings if file exists
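            # A missing embedding file is tolerated: rag_model stays None and
            # initialization proceeds without it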
if os.path.exists(CONFIG["embedding_filename"]):
self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
if self.rag_model is None:
return "❌ Failed to load embeddings"
self.is_initialized = True
return "βœ… Model initialized successfully"
except Exception as e:
logger.error(f"Initialization failed: {str(e)}")
return f"❌ Initialization failed: {str(e)}"
def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
"""Handle chat interactions with the model"""
if not self.is_initialized:
return history + [["", "⚠️ Please initialize the model first"]]
try:
            if len(message.strip()) < 10:
                return history + [["", "Please provide a more detailed question (at least 10 characters)"]]
# Prepare tools prompt
tools_prompt = self._prepare_tools_prompt(message)
# Format conversation
conversation = [
{"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning." + tools_prompt},
*self._format_history(history),
{"role": "user", "content": message}
]
# Generate response
inputs = self.tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors="pt"
).to(self.model.device)
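            # Sampled decoding; pad_token_id falls back to EOS because the
            # Llama tokenizer defines no dedicated pad token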
outputs = self.model.generate(
inputs,
max_new_tokens=1024,
temperature=0.7,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# Decode and clean response
response = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
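            # Keep only the prose answer; drop any trailing tool-call markup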
response = response.split("[TOOL_CALLS]")[0].strip()
return history + [[message, response]]
except Exception as e:
logger.error(f"Chat error: {str(e)}")
return history + [["", f"Error: {str(e)}"]]
def _prepare_tools_prompt(self, message: str) -> str:
"""Prepare the tools prompt section"""
if not hasattr(self.tooluniverse, 'tools'):
return ""
tools_prompt = "\n\nYou have access to the following tools:\n"
for tool in self.tooluniverse.tools:
if tool['name'] not in self.special_tools:
tools_prompt += f"- {tool['name']}: {tool['description']}\n"
# Add special tools
tools_prompt += "\nSpecial tools:\n"
tools_prompt += "- Finish: Use when you have the final answer\n"
tools_prompt += "- Tool_RAG: Search for additional tools when needed\n"
return tools_prompt
def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
"""Format chat history for the model"""
formatted = []
for user_msg, bot_msg in history:
formatted.append({"role": "user", "content": user_msg})
if bot_msg:
formatted.append({"role": "assistant", "content": bot_msg})
return formatted
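# Hypothetical programmatic usage (the Gradio UI below is the normal entry point):
#   agent = TxAgentWrapper()
#   print(agent.initialize())
#   history = agent.chat("How to adjust Journavx for renal impairment?", [])
#   print(history[-1][1])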
def create_interface() -> gr.Blocks:
"""Create the Gradio interface"""
agent = TxAgentWrapper()
with gr.Blocks(
title="TxAgent",
css="""
.gradio-container {max-width: 900px !important}
"""
) as demo:
gr.Markdown("""
# 🧠 TxAgent: Therapeutic Reasoning AI
### (Loading from Hugging Face Hub)
""")
with gr.Row():
init_btn = gr.Button("Initialize Model", variant="primary")
init_status = gr.Textbox(label="Status", interactive=False)
chatbot = gr.Chatbot(
height=500,
label="Conversation"
)
msg = gr.Textbox(label="Your clinical question")
clear_btn = gr.Button("Clear Chat")
gr.Examples(
examples=[
"How to adjust Journavx for renal impairment?",
"Xolremdi and Prozac interaction in WHIM syndrome?",
"Alternative to Warfarin for patient with amiodarone?"
],
inputs=msg
)
        def wrapper_initialize():
            status = agent.initialize()
            # Only lock the button after a successful init so failures can be retried
            return status, gr.update(interactive=not status.startswith("✅"))
init_btn.click(
fn=wrapper_initialize,
outputs=[init_status, init_btn]
)
msg.submit(
fn=agent.chat,
inputs=[msg, chatbot],
outputs=chatbot
).then(
lambda: "", # Clear message box
outputs=msg
)
clear_btn.click(
fn=lambda: ([], ""),
outputs=[chatbot, msg]
)
return demo
if __name__ == "__main__":
try:
logger.info("Starting application...")
# Verify embedding file exists
if not os.path.exists(CONFIG["embedding_filename"]):
logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
logger.info("Please ensure the file is in the root directory")
else:
logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")
# Prepare tool files
prepare_tool_files()
# Launch interface
interface = create_interface()
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)
except Exception as e:
logger.error(f"Application failed to start: {str(e)}")
raise