# File-viewer residue from the hosting page (kept as comments so the module parses):
# test / app.py
# Ali2206's picture
# Update app.py
# dc06321 verified
# raw
# history blame
# 6.93 kB
import os
import json
import torch
import logging
import numpy
import gradio as gr
from importlib.resources import files
from txagent import TxAgent
from tooluniverse import ToolUniverse
# Configure root logging once: records at INFO and above, each prefixed with
# a timestamp, the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Environment setup.
# NOTE(review): MKL_THREADING_LAYER is normally read when the MKL runtime is
# first loaded; torch and numpy are imported above this point, so this
# assignment may come too late to take effect. Consider setting it before
# those imports.
os.environ.update({
    "MKL_THREADING_LAYER": "GNU",
    "TOKENIZERS_PARALLELISM": "false",
})
# Absolute directory of this script; the local data/ folder lives beneath it.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Configuration: model identifiers plus the tool-definition JSON files handed
# to TxAgent as ``tool_files_dict``.
_TOOLUNIVERSE_DATA = files('tooluniverse.data')
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "tool_files": {
        # Tool catalogues shipped inside the tooluniverse package.
        "opentarget": str(_TOOLUNIVERSE_DATA.joinpath('opentarget_tools.json')),
        "fda_drug_label": str(_TOOLUNIVERSE_DATA.joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(_TOOLUNIVERSE_DATA.joinpath('special_tools.json')),
        "monarch": str(_TOOLUNIVERSE_DATA.joinpath('monarch_tools.json')),
        # Project-local catalogue, generated by TxAgentApp.prepare_tool_files
        # when it does not exist yet.
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}
class TxAgentApp:
    """Gradio front-end that owns one TxAgent, built eagerly in ``__init__``."""

    def __init__(self):
        # Set by initialize_agent(); construction raises if initialization fails.
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Initialize the TxAgent with proper error handling.

        Ensures the tool files exist, builds the agent from ``CONFIG`` and
        loads its model.

        Raises:
            Exception: whatever tool-file preparation, ``TxAgent`` construction
                or ``init_model`` raised, after logging it with its traceback.
        """
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")
            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )
            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")
        except Exception as e:
            # logger.exception keeps the traceback in the log before re-raising.
            logger.exception("Failed to initialize agent: %s", e)
            raise

    def prepare_tool_files(self):
        """Ensure the local ``new_tool.json`` tool catalogue exists.

        If it is missing, it is generated from a fresh ``ToolUniverse``. The
        JSON is written to a sibling temporary file and then renamed into place
        with ``os.replace``. A failed dump therefore never leaves a truncated
        catalogue behind; that matters because later runs skip generation
        whenever the file exists.

        Raises:
            Exception: any failure while creating the directory, querying
                ``ToolUniverse`` or writing the file, after logging it.
        """
        target = CONFIG["tool_files"]["new_tool"]
        try:
            # Create the directory that holds the local catalogue.
            os.makedirs(os.path.dirname(target), exist_ok=True)
            if os.path.exists(target):
                return
            logger.info("Creating new_tool.json...")
            tu = ToolUniverse()
            # Use get_all_tools() when available, else fall back to a `tools` attribute.
            tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
            tmp_path = target + ".tmp"
            try:
                with open(tmp_path, "w", encoding="utf-8") as f:
                    json.dump(tools, f, indent=2)
                # Same directory, so the rename replaces the target atomically.
                os.replace(tmp_path, target)
            except Exception:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the original error is re-raised below
                raise
        except Exception as e:
            logger.exception("Failed to prepare tool files: %s", e)
            raise

    @staticmethod
    def _chunk_text(chunk):
        """Return the displayable text carried by one streamed chunk or message."""
        if chunk is None:
            return ""
        if isinstance(chunk, str):
            return chunk
        if isinstance(chunk, dict):
            return str(chunk["content"]) if "content" in chunk else str(chunk)
        # Message-like objects (e.g. a chat message with a `.content` attribute).
        content = getattr(chunk, "content", None)
        if content is not None:
            return str(content)
        return str(chunk)

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
        """Handle user message and generate response.

        Args:
            msg: raw text from the question box.
            chat_history: current chatbot value as a list of ``{"role", "content"}``
                dicts, or ``None`` when the chatbot is empty.
            temperature: sampling temperature, forwarded as-is.
            max_new_tokens, max_tokens, max_round: slider values, cast to ``int``
                before being forwarded (Gradio sliders may deliver floats).
            multi_agent: forwarded to the agent as ``call_agent``.
            conversation: per-session state forwarded unchanged.

        Returns:
            A new history list. It has the user turn (when the message is
            accepted) and one assistant turn appended. Failures are reported as
            an assistant message rather than raised, so the UI always gets a
            history back.
        """
        # Work on a copy so the caller's list is never mutated; tolerate None.
        history = list(chat_history or [])
        if not isinstance(msg, str) or len(msg.strip()) <= 10:
            history.append({"role": "assistant", "content": "Please provide a valid message longer than 10 characters."})
            return history
        message = msg.strip()
        # Earlier turns only; the new message is passed separately as `message`.
        # NOTE(review): assumes run_gradio_chat appends `message` itself and
        # reads history entries as role/content dicts, as TxAgent's own Gradio
        # app does. Confirm against the installed txagent version.
        prior_turns = [
            {"role": m["role"], "content": m["content"]}
            for m in history
            if isinstance(m, dict) and "role" in m and "content" in m
        ]
        history.append({"role": "user", "content": message})
        try:
            logger.info("Processing message: %s...", message[:100])
            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=prior_turns,
                temperature=temperature,
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation,
                max_round=int(max_round),
                seed=42,
            )
            collected = ""
            for chunk in response_generator:
                if isinstance(chunk, (list, tuple)):
                    # A list chunk is treated as a cumulative snapshot of
                    # messages: keep its latest message instead of appending
                    # the whole list's repr on every yield.
                    if chunk:
                        collected = self._chunk_text(chunk[-1])
                else:
                    collected += self._chunk_text(chunk)
            history.append({"role": "assistant", "content": collected or "No response generated."})
            return history
        except Exception as e:
            logger.exception("Error in respond function: %s", e)
            history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return history

    def create_demo(self):
        """Create and return the Gradio interface.

        The layout is a chat column (conversation, question box, Ask button)
        and a settings column (generation sliders, multi-agent toggle, Clear
        button). "Ask" routes through ``respond``; "Clear Chat" empties the
        conversation.
        """
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")
            with gr.Row():
                with gr.Column(scale=3):
                    # type="messages": respond() reads and returns role/content dicts.
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=600,
                        type="messages",
                    )
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3,
                    )
                    submit = gr.Button("Ask", variant="primary")
                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")
            # Per-session agent conversation state; read by respond() but never
            # written back, since it is not among the outputs.
            conversation = gr.State([])
            submit.click(
                self.respond,
                inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation, max_rounds],
                outputs=[chatbot],
            )
            clear_btn.click(lambda: [], None, chatbot, queue=False)
            # Add a dummy event to ensure the app stays alive
            demo.load(lambda: None, None, None)
        return demo
def main():
    """Build the TxAgent app and serve its Gradio UI on port 7860.

    Any startup failure is logged and then re-raised to the caller.
    """
    try:
        logger.info("Starting TxAgent application...")
        demo = TxAgentApp().create_demo()
        logger.info("Launching Gradio interface...")
        demo.launch(
            show_error=True,
            share=True,
            server_port=7860,
            server_name="0.0.0.0",
        )
    except Exception as err:
        logger.error(f"Application failed to start: {err}")
        raise


if __name__ == "__main__":
    main()