test / app.py
Ali2206's picture
Update app.py
59ced24 verified
raw
history blame
6.43 kB
import os
import torch
import requests
from huggingface_hub import hf_hub_download, snapshot_download
from txagent import TxAgent
import gradio as gr
# Configuration
# Hugging Face repo ids for the agent model and the tool-retrieval (RAG)
# model, the file name of the precomputed tool-embedding tensor, the local
# directory models are downloaded into, and the tool-definition JSON files
# passed to TxAgent as ``tool_files_dict``.
CONFIG = dict(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    embedding_filename="ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
    local_dir="./models",
    tool_files=dict(
        new_tool="./data/new_tool.json",
        opentarget="./data/opentarget_tools.json",
        fda_drug_label="./data/fda_drug_labeling_tools.json",
        special_tools="./data/special_tools.json",
        monarch="./data/monarch_tools.json",
    ),
)
def download_model_files():
    """Download all required model files from Hugging Face Hub.

    Creates ``CONFIG["local_dir"]`` and ``./data``. Each model repository
    (agent model first, then the RAG model) is mirrored to
    ``CONFIG["local_dir"]/<repo_id>``. The tool-embedding file is then fetched
    from the RAG repo into ``CONFIG["local_dir"]``.

    The embedding download is best-effort. On failure the error is printed
    and the function returns normally, so ``generate_embeddings`` can create
    the file at that same path later.

    Raises:
        Any exception raised by ``snapshot_download`` if a model repository
        cannot be fetched.
    """
    os.makedirs(CONFIG["local_dir"], exist_ok=True)
    os.makedirs("./data", exist_ok=True)
    print("Downloading model files...")

    # ``resume_download`` is deprecated in huggingface_hub (>= 0.23):
    # interrupted downloads resume automatically, so it is no longer passed.
    # NOTE(review): TxAgentApp builds TxAgent from the repo ids, not from
    # these local mirrors -- confirm the mirrors are actually needed.
    for repo_id in (CONFIG["model_name"], CONFIG["rag_model_name"]):
        snapshot_download(
            repo_id=repo_id,
            local_dir=os.path.join(CONFIG["local_dir"], repo_id),
        )

    # Optional file: report a failure instead of aborting start-up.
    try:
        hf_hub_download(
            repo_id=CONFIG["rag_model_name"],
            filename=CONFIG["embedding_filename"],
            local_dir=CONFIG["local_dir"],
        )
        print("Embeddings file downloaded successfully")
    except Exception as e:
        print(f"Could not download embeddings file: {e}")
        print("Will attempt to generate it instead")
def generate_embeddings(agent):
    """Generate and save tool embeddings if missing.

    The target path is ``CONFIG["local_dir"]/CONFIG["embedding_filename"]``.
    If that file already exists, nothing is done. Otherwise:

    1. The ``'description'`` of every tool from
       ``agent.tooluniverse.get_all_tools()`` is embedded with
       ``agent.rag_model.generate_embeddings``.
    2. The result is saved with ``torch.save``.
    3. The result is assigned to ``agent.rag_model.tool_desc_embedding``.

    The tensor is first written to a temporary file and then renamed into
    place. An interrupted save therefore cannot leave a truncated file at the
    final path, which a later call would otherwise accept as valid because of
    the existence check.

    Args:
        agent: the agent whose ``tooluniverse`` and ``rag_model`` are used.

    Raises:
        Re-raises any exception from listing tools, embedding, or saving,
        after printing it.
    """
    embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
    if os.path.exists(embedding_path):
        print("Embeddings file already exists")
        return
    print("Generating missing tool embeddings...")
    tmp_path = embedding_path + ".tmp"
    try:
        # Get all tools from the tool universe
        tools = agent.tooluniverse.get_all_tools()
        tool_descriptions = [tool['description'] for tool in tools]
        # Generate embeddings using the RAG model
        embeddings = agent.rag_model.generate_embeddings(tool_descriptions)
        # The directory may not exist if download_model_files() was not run.
        os.makedirs(os.path.dirname(embedding_path) or ".", exist_ok=True)
        # Write atomically: temp file first, then rename over the final path.
        torch.save(embeddings, tmp_path)
        os.replace(tmp_path, embedding_path)
        print(f"Embeddings saved to {embedding_path}")
        # Update the RAG model to use the new embeddings
        agent.rag_model.tool_desc_embedding = embeddings
    except Exception as e:
        print(f"Failed to generate embeddings: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
class TxAgentApp:
    """Holds a lazily created TxAgent and exposes the Gradio callbacks.

    The agent is not built in ``__init__``. It is created by ``initialize``,
    and ``chat`` refuses to answer until that has fully succeeded.
    """

    def __init__(self):
        self.agent = None  # TxAgent instance; set by initialize()
        self.is_initialized = False  # True only after initialize() succeeded

    def initialize(self):
        """Build the TxAgent, load its model and ensure tool embeddings exist.

        Returns:
            A status string for the UI: "Already initialized", a success
            message, or "Initialization failed: <error>". Exceptions are
            reported through this string, never raised.
        """
        if self.is_initialized:
            return "Already initialized"
        try:
            self.agent = TxAgent(
                CONFIG["model_name"],
                CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=['DirectResponse', 'RequireClarification'],
            )
            self.agent.init_model()
            generate_embeddings(self.agent)
            self.is_initialized = True
            return "TxAgent initialized successfully"
        except Exception as e:
            # Drop the half-built agent so a retry does not keep the previous
            # (possibly model-loaded) instance alive alongside the new one.
            self.agent = None
            return f"Initialization failed: {str(e)}"

    @staticmethod
    def _to_messages(history, message):
        """Convert (user, assistant) pairs plus *message* to role/content dicts."""
        messages = []
        for user_msg, bot_msg in history:
            # Gradio pairs may hold None for a side with no text; skip those.
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg is not None:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
        return messages

    def chat(self, message, history):
        """Answer *message* and return the updated chat history.

        Args:
            message: the user's new question.
            history: list of (user, assistant) pairs from the Chatbot;
                None is treated as an empty history.

        Returns:
            *history* plus a ``(message, reply)`` pair, where the reply is
            the agent's concatenated output or an "Error: ..." text. A blank
            message (after initialization) returns *history* unchanged
            without calling the agent.
        """
        history = history or []
        if not self.is_initialized:
            return history + [(message, "Error: Please initialize the model first")]
        if not (message or "").strip():
            return history
        try:
            messages = self._to_messages(history, message)
            # NOTE(review): assumes run_gradio_chat accepts these arguments and
            # yields str fragments -- confirm against the installed txagent.
            response = "".join(
                self.agent.run_gradio_chat(
                    messages,
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_tokens=8192,
                    multi_agent=False,
                    conversation=[],
                    max_round=30,
                )
            )
            return history + [(message, response)]
        except Exception as e:
            return history + [(message, f"Error: {str(e)}")]
def create_interface():
    """Build the Gradio Blocks UI around a fresh TxAgentApp and return it.

    The page contains:

    - an "Initialize Model" button with a status box;
    - a chatbot;
    - a question textbox with example prompts;
    - a Submit button.

    Pressing Enter in the textbox or clicking Submit sends the question to
    ``TxAgentApp.chat``. The textbox is then cleared, so the previous question
    is not re-sent by accident.

    Returns:
        The (not yet launched) ``gr.Blocks`` interface.
    """
    app = TxAgentApp()
    with gr.Blocks(title="TxAgent") as demo:
        gr.Markdown("# TxAgent: Therapeutic Reasoning AI")

        # Initialization
        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Initialization Status")

        # Chat interface
        chatbot = gr.Chatbot(height=600)
        msg = gr.Textbox(label="Your Question")
        submit_btn = gr.Button("Submit")

        # Examples
        gr.Examples(
            examples=[
                "How to adjust Journavx dosage for hepatic impairment?",
                "Is Xolremdi safe with Prozac for WHIM syndrome?",
                "Warfarin-Amiodarone contraindications?",
            ],
            inputs=msg,
        )

        # Event handlers
        init_btn.click(
            app.initialize,
            outputs=init_status,
        )

        def respond(message, chat_history):
            # Return "" for the textbox so it is cleared once the question is sent.
            return "", app.chat(message, chat_history)

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
    return demo
if __name__ == "__main__":
    # Fetch the model repositories and the embedding file before the UI exists.
    download_model_files()

    # Listen on all interfaces at port 7860; share=True asks Gradio for a
    # public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)