|
import os |
|
import json |
|
import logging |
|
import torch |
|
from txagent import TxAgent |
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
from tooluniverse import ToolUniverse |
|
import time |
|
from requests.adapters import HTTPAdapter |
|
from requests import Session |
|
from urllib3.util.retry import Retry |
|
from tqdm import tqdm |
|
|
|
|
|
# Application-wide configuration: model repos, cached-embedding filename,
# local storage paths, and HTTP download tuning knobs.
CONFIG = {
    # Main TxAgent therapeutic-reasoning model (Hugging Face Hub repo id).
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    # Retrieval model used by TxAgent's tool-RAG step.
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Pre-computed tool-description embeddings; loaded if present,
    # regenerated and saved otherwise (see load_embeddings).
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    # Directory both model snapshots are downloaded into.
    "local_dir": "./models",
    "tool_files": {
        # JSON tool list generated by ToolUniverse on first run.
        "new_tool": "./data/new_tool.json"
    },
    "download_settings": {
        "timeout": 600,                  # seconds (NOTE(review): not referenced by the download code — confirm intent)
        "max_retries": 5,                # extra attempts after the first failure
        "retry_delay": 30,               # base back-off seconds, scaled linearly per attempt
        "chunk_size": 1024 * 1024 * 10,  # 10 MiB (NOTE(review): not referenced by the download code)
        "max_concurrent": 2              # max_workers passed to snapshot_download
    }
}
|
|
|
|
|
# Module-wide logging: timestamped, INFO-level messages for every logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
def create_optimized_session():
    """Build a requests ``Session`` tuned for large file downloads.

    Transient HTTP failures (408/429/5xx) are retried with exponential
    back-off, and a small blocking connection pool is shared across both
    HTTP and HTTPS.
    """
    retry_policy = Retry(
        total=CONFIG["download_settings"]["max_retries"],
        backoff_factor=1,
        status_forcelist=[408, 429, 500, 502, 503, 504],
    )

    pooled_adapter = HTTPAdapter(
        max_retries=retry_policy,
        pool_connections=10,
        pool_maxsize=10,
        pool_block=True,
    )

    session = Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, pooled_adapter)

    return session
|
|
|
def prepare_tool_files():
    """Ensure ./data exists and the ToolUniverse tool-list JSON is present.

    On first run, queries ToolUniverse for all tools and writes them to
    CONFIG["tool_files"]["new_tool"]; later runs are a no-op if the file
    already exists.
    """
    os.makedirs("./data", exist_ok=True)
    tool_path = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tu = ToolUniverse()
    tools = tu.get_all_tools()
    # Explicit encoding: the default `open` encoding is platform-dependent,
    # which could corrupt non-ASCII tool descriptions on some systems.
    with open(tool_path, "w", encoding="utf-8") as f:
        json.dump(tools, f, indent=2)
    logger.info(f"Saved {len(tools)} tools to {tool_path}")
|
|
|
def download_model_with_progress(repo_id, local_dir):
    """Download a Hugging Face repo snapshot into ``local_dir`` with retries.

    Makes up to ``max_retries`` + 1 attempts, sleeping
    ``retry_delay * attempt_number`` seconds between failures.

    Returns:
        True on success, False if every attempt failed.
    """
    # BUG FIX: the original passed a custom `session=` kwarg to
    # snapshot_download, which the huggingface_hub API does not accept
    # (TypeError at runtime); it manages its own HTTP session internally.
    # The hand-built tqdm bar was also dead code — its update callback was
    # never wired to anything, so the bar never moved.
    settings = CONFIG["download_settings"]

    for attempt in range(settings["max_retries"] + 1):
        try:
            logger.info(f"Download attempt {attempt + 1} for {repo_id}")

            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                resume_download=True,
                local_dir_use_symlinks=False,
                token=True,  # `use_auth_token` is deprecated in huggingface_hub
                max_workers=settings["max_concurrent"],
            )
            return True

        except Exception as e:
            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < settings["max_retries"]:
                wait_time = settings["retry_delay"] * (attempt + 1)
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)

    return False
|
|
|
def download_model_files():
    """Download both the main TxAgent model and the RAG model.

    Raises:
        RuntimeError: if either repository fails to download.
    """
    os.makedirs(CONFIG["local_dir"], exist_ok=True)
    logger.info("Starting model downloads...")

    # Both repos follow the same download recipe; iterate instead of
    # duplicating the call.
    for config_key in ("model_name", "rag_model_name"):
        repo_id = CONFIG[config_key]
        target_dir = os.path.join(CONFIG["local_dir"], repo_id)
        if not download_model_with_progress(repo_id, target_dir):
            raise RuntimeError(f"Failed to download {repo_id}")

    logger.info("All model files downloaded successfully")
|
|
|
def load_embeddings(agent):
    """Attach tool-description embeddings to the agent's RAG model.

    Prefers the pre-generated file named in CONFIG; if that file is missing
    or unreadable, regenerates embeddings from the agent's tool descriptions
    and saves them for future runs.

    Raises:
        Exception: re-raised if embedding generation itself fails.
    """
    embedding_path = CONFIG["embedding_filename"]

    if os.path.exists(embedding_path):
        logger.info("✅ Loading pre-generated embeddings file")
        try:
            agent.rag_model.tool_desc_embedding = torch.load(embedding_path)
            return
        except Exception as e:
            # Corrupt/incompatible file — fall through and regenerate.
            logger.error(f"Failed to load embeddings: {e}")

    logger.info("Generating tool embeddings...")
    try:
        all_tools = agent.tooluniverse.get_all_tools()
        descriptions = [tool["description"] for tool in all_tools]
        embeddings = agent.rag_model.generate_embeddings(descriptions)
        torch.save(embeddings, embedding_path)
        agent.rag_model.tool_desc_embedding = embeddings
        logger.info(f"Embeddings saved to {embedding_path}")
    except Exception as e:
        logger.error(f"Failed to generate embeddings: {e}")
        raise
|
|
|
class TxAgentApp:
    """Lazily-initialized TxAgent wrapper driving the Gradio UI."""

    def __init__(self):
        # The agent is built on demand by initialize(); chat() refuses to
        # run until initialization has succeeded.
        self.agent = None
        self.is_initialized = False

    def initialize(self):
        """Create the TxAgent, load its models and tool embeddings.

        Returns a human-readable status string for the UI; never raises
        (failures are reported as the returned string).
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            with tqdm(total=4, desc="Initializing TxAgent") as pbar:
                logger.info("Creating TxAgent instance...")
                self.agent = TxAgent(
                    CONFIG["model_name"],
                    CONFIG["rag_model_name"],
                    tool_files_dict=CONFIG["tool_files"],
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=10,
                    seed=100,
                    additional_default_tools=["DirectResponse", "RequireClarification"]
                )
                pbar.update(1)

                logger.info("Initializing models...")
                self.agent.init_model()
                pbar.update(1)

                logger.info("Loading embeddings...")
                load_embeddings(self.agent)
                pbar.update(1)

                self.is_initialized = True
                pbar.update(1)

            return "✅ TxAgent initialized successfully"
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message, history):
        """Stream chat responses as growing (user, assistant) history lists.

        Generator: each yield is the full updated history for Gradio to
        render, with the assistant text accumulating chunk by chunk.
        """
        if not self.is_initialized:
            # BUG FIX: the original did `return history + [...]` here, but
            # this function is a generator (it contains `yield`), so a
            # `return <value>` silently discards the value and the user
            # never saw the warning. It must be yielded.
            yield history + [(message, "⚠️ Please initialize the model first")]
            return

        try:
            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            ):
                response += chunk
                yield history + [(message, response)]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            yield history + [(message, f"Error: {str(e)}")]
|
|
|
def create_interface():
    """Build the Gradio Blocks UI wired to a fresh TxAgentApp.

    Returns the (unlaunched) gr.Blocks demo object.
    """
    app = TxAgentApp()

    with gr.Blocks(
        title="TxAgent",
        css="""
        .gradio-container {max-width: 900px !important}
        .progress-bar {height: 20px !important}
        """
    ) as demo:
        # FIX: the heading emoji was mojibake ("�") — replaced with a
        # readable therapeutic emoji.
        gr.Markdown("""
        # 💊 TxAgent: Therapeutic Reasoning AI
        ### Specialized for clinical decision support
        """)

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)
            # Placeholder for future download feedback; nothing writes to it yet.
            download_progress = gr.Textbox(visible=False)

        chatbot = gr.Chatbot(height=500, label="Conversation")
        msg = gr.Textbox(
            label="Your clinical question",
            placeholder="Ask about drug interactions, dosing, etc..."
        )
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg,
            label="Example Questions"
        )

        init_btn.click(
            fn=app.initialize,
            outputs=init_status
        )

        # app.chat is a generator, so Gradio streams each yielded history.
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        )

        # Reset both the conversation and the input box.
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo
|
|
|
if __name__ == "__main__":
    try:
        logger.info("Starting application setup...")

        # 1. Make sure the ToolUniverse tool JSON exists.
        prepare_tool_files()

        # 2. Fetch both model snapshots into the local model directory.
        download_model_files()

        # 3. Build the UI and serve it on all interfaces, port 7860.
        demo = create_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise