"""Gradio front-end for TxAgent, a tool-using biomedical LLM agent."""

import os

# Export threading/tokenizer settings before numpy/torch (and the packages that
# import them) are loaded, so the values are already in the process environment
# when those native libraries initialise.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import inspect  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
from importlib.resources import files  # noqa: E402

import torch  # noqa: E402,F401
import numpy  # noqa: E402,F401
import gradio as gr  # noqa: E402

from txagent import TxAgent  # noqa: E402
from tooluniverse import ToolUniverse  # noqa: E402

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

current_dir = os.path.dirname(os.path.abspath(__file__))

# Configuration: model identifiers and the tool-definition files handed to TxAgent.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        # Generated on first start by TxAgentApp.prepare_tool_files().
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json')
    }
}


def _chunk_to_text(chunk):
    """Convert one item yielded by the agent's response generator to text.

    Dicts contribute their ``"content"`` value, strings are used as-is,
    ``None`` contributes nothing, and anything else is stringified.
    """
    if isinstance(chunk, dict) and "content" in chunk:
        content = chunk["content"]
        # content may be None or a non-str object; never let it raise TypeError.
        return "" if content is None else str(content)
    if isinstance(chunk, str):
        return chunk
    if chunk is None:
        return ""
    # NOTE(review): if run_gradio_chat yields whole history snapshots rather than
    # incremental text, concatenating their str() is not meaningful -- confirm
    # against the TxAgent version in use.
    return str(chunk)


def _chatbot_kwargs():
    """Return ``gr.Chatbot`` keyword arguments for a dict-message chatbot.

    ``TxAgentApp.respond`` produces ``{"role": ..., "content": ...}`` messages,
    which Gradio renders only in its "messages" format. That format is selected
    via the ``type`` parameter on Gradio releases whose ``Chatbot`` accepts one;
    releases without the parameter are left at their default.
    """
    kwargs = {"label": "Conversation", "height": 600}
    if "type" in inspect.signature(gr.Chatbot.__init__).parameters:
        kwargs["type"] = "messages"
    return kwargs


class TxAgentApp:
    """Owns a TxAgent instance and builds the Gradio UI around it."""

    def __init__(self):
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Initialize the TxAgent with proper error handling.

        Raises:
            Exception: any failure from tool-file preparation, agent
                construction or model loading is logged with its traceback
                and re-raised.
        """
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")
            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )
            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")
        except Exception as e:
            logger.exception(f"Failed to initialize agent: {e}")
            raise

    def prepare_tool_files(self):
        """Ensure the data directory exists and that ``new_tool.json`` is populated.

        If the file is missing, it is generated from ``ToolUniverse``'s tool
        list. The JSON is first written to a sibling ``.tmp`` file and then
        moved into place, so a failed dump never leaves a truncated
        ``new_tool.json`` that later runs would accept as valid.

        Raises:
            Exception: any failure is logged with its traceback and re-raised.
        """
        new_tool_path = CONFIG["tool_files"]["new_tool"]
        try:
            os.makedirs(os.path.dirname(new_tool_path), exist_ok=True)
            if os.path.exists(new_tool_path):
                return
            logger.info("Creating new_tool.json...")
            tu = ToolUniverse()
            tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
            tmp_path = new_tool_path + ".tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(tools, f, indent=2)
            # Atomic on the same filesystem: readers see either no file or a complete one.
            os.replace(tmp_path, new_tool_path)
        except Exception as e:
            logger.exception(f"Failed to prepare tool files: {e}")
            raise

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens,
                multi_agent, conversation, max_round):
        """Handle a user message and return the updated chat history.

        Args:
            msg: Raw text from the input box. It must be a str longer than 10
                characters after stripping.
            chat_history: Prior messages as ``{"role", "content"}`` dicts, or
                ``None``. The caller's list is never mutated.
            temperature, max_new_tokens, max_tokens, multi_agent, conversation,
            max_round: Forwarded to ``TxAgent.run_gradio_chat``.

        Returns:
            A new list holding the prior messages plus the new user turn and
            assistant reply. Validation and runtime errors are reported as an
            assistant message instead of being raised.
        """
        chat_history = list(chat_history or [])
        try:
            if not isinstance(msg, str) or len(msg.strip()) <= 10:
                chat_history.append({"role": "assistant", "content": "Please provide a valid message longer than 10 characters."})
                return chat_history

            message = msg.strip()
            # Prior turns only: the new message is passed separately as ``message``,
            # so including it here would hand it to the agent twice.
            formatted_history = [
                (m["role"], m["content"])
                for m in chat_history
                if isinstance(m, dict) and "role" in m and "content" in m
            ]
            chat_history.append({"role": "user", "content": message})

            logger.info(f"Processing message: {message[:100]}...")
            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=formatted_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
                seed=42
            )

            collected = "".join(_chunk_to_text(chunk) for chunk in response_generator)
            chat_history.append({"role": "assistant", "content": collected or "No response generated."})
            return chat_history
        except Exception as e:
            logger.exception(f"Error in respond function: {e}")
            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return chat_history

    def create_demo(self):
        """Create and return the Gradio interface."""
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")
            # Per-session conversation state handed to the agent. It is named so
            # that "Clear Chat" can reset it together with the visible chat.
            conversation = gr.State([])

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(**_chatbot_kwargs())
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3
                    )
                    submit = gr.Button("Ask", variant="primary")

                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")

            submit.click(
                self.respond,
                inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation, max_rounds],
                outputs=[chatbot]
            )
            clear_btn.click(lambda: ([], []), None, [chatbot, conversation], queue=False)

            # Add a dummy event to ensure the app stays alive
            demo.load(lambda: None, None, None)

        return demo


def main():
    """Main entry point: build the app and serve it.

    The server binds to 0.0.0.0:7860 and, with ``share=True``, Gradio is also
    asked to create a public share link.
    """
    try:
        logger.info("Starting TxAgent application...")
        app = TxAgentApp()
        demo = app.create_demo()
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True
        )
    except Exception as e:
        logger.exception(f"Application failed to start: {e}")
        raise


if __name__ == "__main__":
    main()