File size: 6,434 Bytes
4b0f1a8 b8c0ae3 59ced24 12efdad 59ced24 4b0f1a8 12efdad 70839bb 59ced24 12efdad 59ced24 12efdad 59ced24 4b0f1a8 12efdad 59ced24 4b0f1a8 59ced24 8e533b3 4b0f1a8 59ced24 4b0f1a8 59ced24 4b0f1a8 a59a7be 59ced24 a59a7be 4b0f1a8 92abf33 a59a7be 4b0f1a8 92abf33 4b0f1a8 59ced24 4b0f1a8 59ced24 e014e82 4b0f1a8 59ced24 4b0f1a8 59ced24 4b0f1a8 59ced24 e014e82 4b0f1a8 92abf33 4b0f1a8 e014e82 59ced24 4b0f1a8 59ced24 4b0f1a8 59ced24 4b0f1a8 59ced24 8e533b3 92abf33 4b0f1a8 59ced24 8e533b3 59ced24 4b0f1a8 59ced24 4b0f1a8 59ced24 4b0f1a8 59ced24 4b0f1a8 8e533b3 59ced24 4b0f1a8 59ced24 8e533b3 70839bb 59ced24 a59a7be 59ced24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import os
import torch
import requests
from huggingface_hub import hf_hub_download, snapshot_download
from txagent import TxAgent
import gradio as gr
# Application-wide settings.
#   model_name / rag_model_name : Hugging Face repo ids used for downloads
#                                 and passed to the TxAgent constructor.
#   embedding_filename          : tool-embedding file fetched from the RAG repo.
#   local_dir                   : directory that receives downloaded files.
#   tool_files                  : tool name -> JSON path, passed to TxAgent
#                                 as ``tool_files_dict``.
CONFIG = dict(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    embedding_filename="ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
    local_dir="./models",
    tool_files=dict(
        new_tool="./data/new_tool.json",
        opentarget="./data/opentarget_tools.json",
        fda_drug_label="./data/fda_drug_labeling_tools.json",
        special_tools="./data/special_tools.json",
        monarch="./data/monarch_tools.json",
    ),
)
def download_model_files():
    """Download all required model files from Hugging Face Hub.

    Creates ``CONFIG["local_dir"]`` and ``./data`` if needed, snapshots the
    main model repo and then the RAG model repo into
    ``<local_dir>/<repo_id>``, and finally makes a best-effort attempt to
    fetch the tool-embedding file into ``local_dir``.

    Errors from the two repository snapshots propagate to the caller.  A
    failure to fetch the embedding file is only reported on stdout, because
    the embeddings can be generated later instead.
    """
    os.makedirs(CONFIG["local_dir"], exist_ok=True)
    os.makedirs("./data", exist_ok=True)

    print("Downloading model files...")

    # Main model first, then the RAG model.
    # ``resume_download`` is deliberately not passed: huggingface_hub has
    # deprecated the flag and resumes interrupted downloads on its own.
    for repo_key in ("model_name", "rag_model_name"):
        repo_id = CONFIG[repo_key]
        snapshot_download(
            repo_id=repo_id,
            local_dir=os.path.join(CONFIG["local_dir"], repo_id),
        )

    # Try to download the embeddings file; a miss is not fatal.
    try:
        hf_hub_download(
            repo_id=CONFIG["rag_model_name"],
            filename=CONFIG["embedding_filename"],
            local_dir=CONFIG["local_dir"],
        )
    except Exception as e:  # best-effort: any failure falls back to generation
        print(f"Could not download embeddings file: {e}")
        print("Will attempt to generate it instead")
    else:
        print("Embeddings file downloaded successfully")
def generate_embeddings(agent):
    """Generate and save tool embeddings if the embedding file is missing.

    Does nothing (beyond printing a notice) when
    ``<CONFIG["local_dir"]>/<CONFIG["embedding_filename"]>`` already exists.
    Otherwise the tool descriptions are collected from the agent, embedded
    with the agent's RAG model, written to that path with ``torch.save`` and
    assigned to ``agent.rag_model.tool_desc_embedding``.

    Args:
        agent: Object exposing ``tooluniverse.get_all_tools()`` and
            ``rag_model.generate_embeddings(...)``.  Each tool returned by
            ``get_all_tools()`` is indexed with ``['description']``.

    Raises:
        Exception: Any error raised while collecting, embedding or saving is
            printed and then re-raised unchanged.
    """
    embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])

    if os.path.exists(embedding_path):
        print("Embeddings file already exists")
        return

    print("Generating missing tool embeddings...")
    try:
        # Get all tools from the tool universe
        tools = agent.tooluniverse.get_all_tools()
        tool_descriptions = [tool['description'] for tool in tools]

        # Generate embeddings using the RAG model
        embeddings = agent.rag_model.generate_embeddings(tool_descriptions)

        # torch.save does not create missing parent directories, so make sure
        # the target directory exists even if the download step was skipped.
        os.makedirs(os.path.dirname(embedding_path) or ".", exist_ok=True)
        torch.save(embeddings, embedding_path)
        print(f"Embeddings saved to {embedding_path}")

        # Update the RAG model to use the new embeddings
        agent.rag_model.tool_desc_embedding = embeddings
    except Exception as e:
        print(f"Failed to generate embeddings: {e}")
        raise
class TxAgentApp:
    """Wrap a lazily created TxAgent behind the two Gradio callbacks.

    ``initialize`` builds the agent on demand; ``chat`` answers one user
    message given the chat history.  Both report failures as strings rather
    than raising, so they can be wired directly to UI components.
    """

    def __init__(self):
        # Set by initialize(); chat() refuses to run until is_initialized.
        self.agent = None
        self.is_initialized = False

    def initialize(self):
        """Create the TxAgent, load its model and ensure embeddings exist.

        Returns:
            A status string: ``"Already initialized"``,
            ``"TxAgent initialized successfully"`` or
            ``"Initialization failed: <error>"``.
        """
        if self.is_initialized:
            return "Already initialized"

        try:
            # Build into a local first so a failure in any later step does
            # not leave a half-initialized agent stored on the instance.
            agent = TxAgent(
                CONFIG["model_name"],
                CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=['DirectResponse', 'RequireClarification']
            )
            agent.init_model()
            generate_embeddings(agent)
        except Exception as e:
            return f"Initialization failed: {str(e)}"

        self.agent = agent
        self.is_initialized = True
        return "TxAgent initialized successfully"

    def chat(self, message, history):
        """Answer ``message`` and return the history with the new turn appended.

        Args:
            message: The user's new question.
            history: Sequence of ``(user_msg, bot_msg)`` pairs; ``None`` is
                treated as an empty history.

        Returns:
            ``history + [(message, reply)]`` where ``reply`` is the
            concatenated agent output, or an ``"Error: ..."`` string when the
            agent is not initialized or raises.
        """
        history = history or []

        if not self.is_initialized:
            return history + [(message, "Error: Please initialize the model first")]

        try:
            # Convert history to messages format.  Either side of a pair may
            # be None (e.g. a turn still awaiting its reply); such entries
            # carry no content and are not forwarded.
            messages = []
            for user_msg, bot_msg in history:
                if user_msg is not None:
                    messages.append({"role": "user", "content": user_msg})
                if bot_msg is not None:
                    messages.append({"role": "assistant", "content": bot_msg})
            messages.append({"role": "user", "content": message})

            # NOTE(review): the keyword names below are passed through as
            # written; confirm them against the installed TxAgent's
            # run_gradio_chat signature.
            response = "".join(
                self.agent.run_gradio_chat(
                    messages,
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_tokens=8192,
                    multi_agent=False,
                    conversation=[],
                    max_round=30
                )
            )
            return history + [(message, response)]
        except Exception as e:
            return history + [(message, f"Error: {str(e)}")]
def create_interface():
    """Assemble the Gradio Blocks UI and return it (not yet launched).

    The layout is a title, an initialization row (button + status box), a
    chatbot panel, a question box with a submit button, and example
    questions.  All callbacks are backed by a single ``TxAgentApp``.
    """
    tx_app = TxAgentApp()

    def respond(message, chat_history):
        # Thin adapter so both chat triggers share one callback.
        return tx_app.chat(message, chat_history)

    with gr.Blocks(title="TxAgent") as ui:
        gr.Markdown("# TxAgent: Therapeutic Reasoning AI")

        # Initialization controls
        with gr.Row():
            init_button = gr.Button("Initialize Model", variant="primary")
            status_box = gr.Textbox(label="Initialization Status")

        # Conversation widgets
        chat_panel = gr.Chatbot(height=600)
        question_box = gr.Textbox(label="Your Question")
        send_button = gr.Button("Submit")

        # Clickable sample questions that fill the question box
        sample_questions = [
            "How to adjust Journavx dosage for hepatic impairment?",
            "Is Xolremdi safe with Prozac for WHIM syndrome?",
            "Warfarin-Amiodarone contraindications?"
        ]
        gr.Examples(examples=sample_questions, inputs=question_box)

        # Wiring: init button first, then Enter-in-textbox, then Submit click
        init_button.click(tx_app.initialize, outputs=status_box)
        for register in (question_box.submit, send_button.click):
            register(respond, [question_box, chat_panel], chat_panel)

    return ui
if __name__ == "__main__":
    # Script entry point: fetch the model files first so they are in place
    # before the UI is built, then serve the interface.
    download_model_files()

    interface = create_interface()
    # Listen on all interfaces on port 7860 and request a public share link.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )