|
import os
import json
import logging
import warnings
from typing import List, Dict, Any
from importlib.resources import files

import torch
import gradio as gr
from tooluniverse import ToolUniverse
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore", category=UserWarning)
|
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    # Precomputed tool-embedding cache; the long hex suffix is part of the
    # published filename (it looks like a SHA-256 content hash).
    "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
    # Tool definition files bundled with tooluniverse, plus one generated file.
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        "new_tool": "./data/new_tool.json",
    },
}
|
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
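
# Quick sanity check for the configured paths — a minimal sketch, assuming the
# tooluniverse data files resolve on this install (locations can differ across
# tooluniverse versions); uncomment to run at import time:
#
#   for name, path in CONFIG["tool_files"].items():
#       logger.info("%s -> %s (exists=%s)", name, path, os.path.exists(path))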
|
|
|
def prepare_tool_files():
    """Ensure the generated tool file exists and is populated."""
    os.makedirs("./data", exist_ok=True)
    if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
        logger.info("Generating tool list using ToolUniverse...")
        tu = ToolUniverse()
        # get_all_tools() is not guaranteed on every ToolUniverse version, so
        # fall back to an empty list instead of crashing.
        tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
        with open(CONFIG["tool_files"]["new_tool"], "w") as f:
            json.dump(tools, f, indent=2)
        logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
|
|
|
def safe_load_embeddings(filepath: str) -> Any:
    """Load embeddings, preferring the restricted ``weights_only=True`` path."""
    try:
        # Preferred path: refuses to unpickle arbitrary Python objects.
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, retrying with weights_only=False: {str(e)}")
        try:
            # Fallback for checkpoints that contain pickled non-tensor objects.
            # This executes arbitrary pickle code, so it is only acceptable for
            # files from a trusted source.
            return torch.load(filepath, weights_only=False)
        except Exception as e:
            logger.error(f"Failed to load embeddings: {str(e)}")
            return None
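
# Usage sketch (error handling shown is illustrative):
#
#   emb = safe_load_embeddings(CONFIG["embedding_filename"])
#   if emb is None:
#       raise RuntimeError("embedding cache missing or unreadable")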
|
|
|
class TxAgentWrapper:
    """Thin wrapper owning the model, tokenizer, and tool registry."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.rag_model = None
        self.tooluniverse = None
        self.is_initialized = False
        # Tools given dedicated prompt text instead of the generic tool list.
        self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']
|
    def initialize(self) -> str:
        """Load the tool registry and models from the Hugging Face Hub."""
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            logger.info("Loading models from Hugging Face Hub...")

            # Fail fast if any bundled tool file is missing.
            for tool_name, tool_path in CONFIG["tool_files"].items():
                if tool_name != "new_tool" and not os.path.exists(tool_path):
                    raise FileNotFoundError(f"Tool file not found: {tool_path}")

            self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
            if hasattr(self.tooluniverse, 'load_tools'):
                self.tooluniverse.load_tools()
                logger.info(f"Loaded {len(self.tooluniverse.tools)} tools")
            else:
                logger.error("ToolUniverse doesn't have a load_tools method")
                return "❌ Failed to load tools"

            self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
            self.model = AutoModelForCausalLM.from_pretrained(
                CONFIG["model_name"],
                device_map="auto",
                torch_dtype=torch.float16
            )

            # The embedding cache is optional at startup; initialization only
            # fails here if the file exists but cannot be loaded.
            if os.path.exists(CONFIG["embedding_filename"]):
                self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
                if self.rag_model is None:
                    return "❌ Failed to load embeddings"

            self.is_initialized = True
            return "✅ Model initialized successfully"

        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"
|
|
|
    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Run one chat turn and return the updated history."""
        if not self.is_initialized:
            return history + [["", "⚠️ Please initialize the model first"]]

        try:
            if len(message) < 10:
                return history + [["", "Please provide a more detailed question (at least 10 characters)"]]

            tools_prompt = self._prepare_tools_prompt(message)

            conversation = [
                {"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning." + tools_prompt},
                *self._format_history(history),
                {"role": "user", "content": message}
            ]

            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            outputs = self.model.generate(
                inputs,
                max_new_tokens=1024,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Decode only the newly generated tokens and drop anything after a
            # tool-call marker.
            response = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
            response = response.split("[TOOL_CALLS]")[0].strip()

            return history + [[message, response]]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return history + [["", f"Error: {str(e)}"]]
|
|
|
    def _prepare_tools_prompt(self, message: str) -> str:
        """Build the tools section appended to the system prompt.

        ``message`` is currently unused.
        """
        if not hasattr(self.tooluniverse, 'tools'):
            return ""

        tools_prompt = "\n\nYou have access to the following tools:\n"
        for tool in self.tooluniverse.tools:
            if tool['name'] not in self.special_tools:
                tools_prompt += f"- {tool['name']}: {tool['description']}\n"

        tools_prompt += "\nSpecial tools:\n"
        tools_prompt += "- Finish: Use when you have the final answer\n"
        tools_prompt += "- Tool_RAG: Search for additional tools when needed\n"

        return tools_prompt
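
    # The generated section looks like (tool name and description here are
    # illustrative):
    #
    #   You have access to the following tools:
    #   - OpenTargets_search: <description>
    #
    #   Special tools:
    #   - Finish: Use when you have the final answer
    #   - Tool_RAG: Search for additional tools when needed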
|
|
|
    def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
        """Convert Gradio ``[user, bot]`` pairs into chat-template messages."""
        formatted = []
        for user_msg, bot_msg in history:
            formatted.append({"role": "user", "content": user_msg})
            if bot_msg:
                formatted.append({"role": "assistant", "content": bot_msg})
        return formatted
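
    # Example: [["What is X?", "X is ..."]] becomes
    #   [{"role": "user", "content": "What is X?"},
    #    {"role": "assistant", "content": "X is ..."}]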
|
|
|
def create_interface() -> gr.Blocks:
    """Create the Gradio interface."""
    agent = TxAgentWrapper()

    with gr.Blocks(
        title="TxAgent",
        css=".gradio-container {max-width: 900px !important}"
    ) as demo:
        gr.Markdown("""
        # 🧠 TxAgent: Therapeutic Reasoning AI
        ### (Loading from Hugging Face Hub)
        """)

        with gr.Row():
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(
            height=500,
            label="Conversation"
        )
        msg = gr.Textbox(label="Your clinical question")
        clear_btn = gr.Button("Clear Chat")

        gr.Examples(
            examples=[
                "How to adjust Journavx for renal impairment?",
                "Xolremdi and Prozac interaction in WHIM syndrome?",
                "Alternative to Warfarin for patient with amiodarone?"
            ],
            inputs=msg
        )

        def wrapper_initialize():
            status = agent.initialize()
            # Disable the button only once initialization actually succeeded,
            # so a failed attempt can be retried.
            return status, gr.update(interactive=not agent.is_initialized)

        init_btn.click(
            fn=wrapper_initialize,
            outputs=[init_status, init_btn]
        )

        # Send the message, then clear the input box.
        msg.submit(
            fn=agent.chat,
            inputs=[msg, chatbot],
            outputs=chatbot
        ).then(
            lambda: "",
            outputs=msg
        )

        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )

    return demo
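
# create_interface() only builds the Blocks app; serving happens via
# .launch(), as in the __main__ block below.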
|
|
|
if __name__ == "__main__":
    try:
        logger.info("Starting application...")

        if not os.path.exists(CONFIG["embedding_filename"]):
            logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
            logger.info("Please ensure the file is in the root directory")
        else:
            logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")

        prepare_tool_files()

        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise