|
import os |
|
import json |
|
import logging |
|
import torch |
|
import gradio as gr |
|
from tooluniverse import ToolUniverse |
|
from txagent import TxAgent |
|
import warnings |
|
from typing import List, Dict, Any |
|
|
|
|
|
# Suppress every warning of category UserWarning for the whole process.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Central configuration read by the rest of this module:
#   model_name         - identifier passed to TxAgent as ``model_name``
#   rag_model_name     - identifier passed to TxAgent as ``rag_model_name``
#   embedding_filename - path of the pre-computed tool-embedding file
#   tool_files         - mapping of tool-set name -> JSON file path
CONFIG = dict(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    embedding_filename=(
        "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_"
        "47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt"
    ),
    tool_files=dict(new_tool="./data/new_tool.json"),
)

# Root logging configuration: INFO level, timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
def prepare_tool_files():
    """Ensure the configured tool file exists, generating it if necessary.

    The target path is ``CONFIG["tool_files"]["new_tool"]``. Its parent
    directory is created when missing. If the file is absent, the full tool
    list is obtained from ``ToolUniverse().get_all_tools()`` and written as
    indented JSON.

    The JSON is first written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``. Because later runs only check for the file's
    existence, a failed or interrupted dump must never leave a truncated file
    at the final path.

    Raises:
        Any exception raised while building the tool list or writing the
        file; the temporary file is removed before the exception propagates.
    """
    tool_path = CONFIG["tool_files"]["new_tool"]

    # Create the directory the configured path actually lives in (rather than
    # a hard-coded "./data"); skip when the path has no directory component.
    tool_dir = os.path.dirname(tool_path)
    if tool_dir:
        os.makedirs(tool_dir, exist_ok=True)

    if os.path.exists(tool_path):
        return

    logger.info("Generating tool list using ToolUniverse...")
    tools = ToolUniverse().get_all_tools()

    tmp_path = f"{tool_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(tools, f, indent=2)
        # Publish the finished file in one step.
        os.replace(tmp_path, tool_path)
    except Exception:
        # Do not leave a partial temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    logger.info(f"Saved {len(tools)} tools to {tool_path}")
|
|
|
def safe_load_embeddings(filepath: str) -> Any:
    """Load a serialized embedding object, preferring the restricted loader.

    Args:
        filepath: Path handed to ``torch.load``.

    Returns:
        Whatever ``torch.load`` deserializes from ``filepath``.

    Raises:
        Whatever the second ``torch.load`` call raises when both attempts
        fail.
    """
    try:
        embeddings = torch.load(filepath, weights_only=True)
    except Exception as exc:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(exc)}")
        # NOTE(security): weights_only=False allows arbitrary objects to be
        # unpickled; this fallback is only appropriate for trusted files.
        embeddings = torch.load(filepath, weights_only=False)
    return embeddings
|
|
|
def _align_embeddings_to_tool_count(embeddings, tool_count: int):
    """Return ``embeddings`` truncated or padded to ``tool_count`` rows.

    Args:
        embeddings: Sliceable object supporting ``len``; padding additionally
            requires something ``torch.cat`` accepts (a tensor).
        tool_count: Desired number of rows along the first dimension.

    Returns:
        ``embeddings`` itself when the counts already match, its first
        ``tool_count`` rows when it is too long, or a new tensor with the
        last row repeated at the end when it is too short.

    Raises:
        ValueError: If padding is required but ``embeddings`` is empty, since
            there is no last row to repeat.
    """
    embedding_count = len(embeddings)
    if tool_count == embedding_count:
        return embeddings
    if tool_count < embedding_count:
        return embeddings[:tool_count]
    if embedding_count == 0:
        raise ValueError("Cannot pad an empty embedding table")

    missing = tool_count - embedding_count
    # Slice with [-1:] (not [-1]) so the filler keeps the same number of
    # dimensions as ``embeddings``; torch.cat rejects mismatched ranks.
    last_row = embeddings[-1:]
    return torch.cat([embeddings] + [last_row] * missing)


def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding``.

    The replacement loads embeddings from ``CONFIG["embedding_filename"]``
    via ``safe_load_embeddings`` and then truncates or pads them so their
    length equals the number of tools reported by
    ``tooluniverse.get_all_tools()``.

    Raises:
        Any exception raised while importing ``ToolRAGModel`` or installing
        the patch; it is logged and re-raised.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse: ToolUniverse) -> bool:
            """Load and size-align tool embeddings; return True on success.

            Returns False (after logging) when the embedding file is missing
            or when loading/alignment raises.
            """
            try:
                if not os.path.exists(CONFIG["embedding_filename"]):
                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
                    return False

                self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])

                current_count = len(tooluniverse.get_all_tools())
                embedding_count = len(self.tool_desc_embedding)

                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")

                    self.tool_desc_embedding = _align_embeddings_to_tool_count(
                        self.tool_desc_embedding, current_count
                    )
                    if current_count < embedding_count:
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        logger.info(f"Padded embeddings to match {current_count} tools")

                return True

            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
|
|
|
class TxAgentApp:
    """Stateful wrapper that builds a TxAgent on demand and serves chat turns."""

    # Notice appended to the chat when a message arrives before initialization.
    _NOT_READY_NOTICE = "⚠️ Please initialize the model first"
    # Prefix of the notice appended to the chat when a turn raises.
    _ERROR_PREFIX = "Error: "

    def __init__(self):
        # The TxAgent instance; stays None until initialize() succeeds.
        self.agent = None
        # True once the agent has been constructed and init_model() returned.
        self.is_initialized = False

    def initialize(self) -> str:
        """Initialize the TxAgent with all required components.

        Patches embedding loading, constructs the agent from ``CONFIG`` and
        calls its ``init_model()``. Calling again after success is a no-op.

        Returns:
            A human-readable status string (success, already-initialized, or
            the failure reason). Exceptions are logged, not raised.
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            patch_embedding_loading()

            logger.info("Initializing TxAgent...")
            agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )

            logger.info("Loading models...")
            agent.init_model()

        except Exception as e:
            # Never keep a half-built agent around after a failed attempt.
            self.agent = None
            logger.error(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

        # Publish the agent only once it is fully loaded.
        self.agent = agent
        self.is_initialized = True
        return "✅ TxAgent initialized successfully"

    @classmethod
    def _to_agent_history(cls, history) -> List[dict]:
        """Convert ``[[user_msg, bot_msg], ...]`` rows into role/content dicts.

        Rows with an empty user message, and rows whose bot message is one of
        this app's own notices (the not-initialized warning or an
        ``"Error: "`` message), are skipped: they are UI status rows, not
        conversation turns. A row with a falsy bot message contributes only
        its user entry.
        """
        messages = []
        for user_msg, bot_msg in history or []:
            if not user_msg:
                continue
            if isinstance(bot_msg, str) and (
                bot_msg == cls._NOT_READY_NOTICE or bot_msg.startswith(cls._ERROR_PREFIX)
            ):
                continue
            messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        return messages

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Handle chat interactions with the TxAgent.

        Args:
            message: User input message.
            history: Chat history in format [[user_msg, bot_msg], ...];
                ``None`` is treated as an empty history.

        Returns:
            Updated chat history. A blank message returns the history
            unchanged. When the agent is not initialized, or the turn raises,
            the new row pairs the user's message with a notice so the
            question is not lost from the conversation view.
        """
        history = history or []

        # Nothing to send: do not add a row or call the agent.
        if not message or not message.strip():
            return history

        if not self.is_initialized:
            return history + [[message, self._NOT_READY_NOTICE]]

        try:
            tx_history = self._to_agent_history(history)

            response = ""
            # Only the last yielded chunk is kept.
            # NOTE(review): assumes each chunk is the full response so far in
            # a form the Chatbot can display - confirm against TxAgent.
            for chunk in self.agent.run_gradio_chat(
                message=message,
                history=tx_history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=None,
                max_round=30,
            ):
                response = chunk

            return history + [[message, response]]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return history + [[message, f"{self._ERROR_PREFIX}{str(e)}"]]
|
|
|
def create_interface() -> gr.Blocks:
    """Create the Gradio interface.

    Builds a Blocks layout with an initialize button and status box, a
    chatbot, a question textbox with example prompts, and a clear button,
    all wired to a single ``TxAgentApp`` instance created here.

    Returns:
        The assembled ``gr.Blocks`` object; launching it is left to the
        caller.
    """
    app = TxAgentApp()

    with gr.Blocks(
        title="TxAgent",
        css="""
        .gradio-container {max-width: 900px !important}
        """
    ) as demo:
        gr.Markdown("""
        # 🧠 TxAgent: Therapeutic Reasoning AI
        ### (Using pre-loaded embeddings)
        """)

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(
            height=500,
            label="Conversation"
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg
        )

        def wrapper_initialize() -> tuple:
            """Run initialization and return (status text, button update)."""
            status = app.initialize()
            # Disable the button only once initialization has succeeded; after
            # a failure it must stay clickable so the user can retry.
            return status, gr.update(interactive=not app.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn]
        )

        # Send the question to the agent, then empty the textbox.
        msg.submit(
            fn=app.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        ).then(
            lambda: "",
            outputs=msg
        )

        # Reset both the conversation and the textbox.
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: report whether the embedding file is present, make
    # sure the tool file exists, then build and serve the UI. Any failure is
    # logged and re-raised.
    try:
        logger.info("Starting application...")

        embedding_path = CONFIG["embedding_filename"]
        if os.path.exists(embedding_path):
            logger.info(f"Found embedding file: {embedding_path}")
        else:
            # A missing file is reported but does not stop startup.
            logger.error(f"Embedding file not found: {embedding_path}")
            logger.info("Please ensure the file is in the root directory")

        prepare_tool_files()

        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False,
        }
        interface = create_interface()
        interface.launch(**launch_options)
    except Exception as exc:
        logger.error(f"Application failed to start: {str(exc)}")
        raise