import gradio as gr
import os
import logging

from txagent import TxAgent
from tooluniverse import ToolUniverse

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TxAgentApp:
    """Thin application wrapper that owns a single TxAgent instance and
    exposes a streaming chat handler for the Gradio UI."""

    def __init__(self):
        # Eagerly initialize so startup fails fast if model loading breaks.
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Initialize the TxAgent with proper parameters.

        Returns:
            A fully initialized TxAgent (model weights loaded).

        Raises:
            Exception: re-raised after logging if initialization fails.
        """
        try:
            logger.info("Initializing TxAgent...")

            # Default tool definition files consumed by ToolUniverse.
            tool_files = {
                "opentarget": "opentarget_tools.json",
                "fda_drug_label": "fda_drug_labeling_tools.json",
                "special_tools": "special_tools.json",
                "monarch": "monarch_tools.json",
            }

            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                tool_files_dict=tool_files,  # This is critical!
                enable_finish=True,
                enable_rag=True,
                enable_summary=False,
                init_rag_num=0,
                step_rag_num=10,
                summary_mode='step',
                summary_skip_last_k=0,
                summary_context_length=None,
                force_finish=True,
                avoid_repeat=True,
                seed=42,
                enable_checker=True,
                enable_chat=False,
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )

            # Explicitly initialize the model (loads weights; slow step).
            agent.init_model()
            logger.info("Model loading complete")
            return agent
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            raise

    def respond(self, message, chat_history, temperature, max_new_tokens,
                max_tokens, multi_agent, conversation_state, max_round):
        """Stream agent responses to the Gradio Chatbot.

        This is a generator (it yields); Gradio re-renders the chatbot on
        every yield. NOTE: because this function contains ``yield``, a bare
        ``return <value>`` would silently discard the value — every outgoing
        chat state must be *yielded*.

        Args:
            message: Raw user input text.
            chat_history: Current Chatbot value (list of (user, bot) tuples,
                or message dicts from an older format).
            temperature / max_new_tokens / max_tokens: Sampling controls.
            multi_agent: Whether to enable agent-as-tool calls.
            conversation_state: Mutable per-session conversation object.
            max_round: Cap on reasoning rounds.

        Yields:
            Updated chat history suitable for gr.Chatbot (tuple format).
        """
        try:
            # Reject trivially short input. Must *yield* (not return) so the
            # hint actually reaches the UI from this generator.
            if not isinstance(message, str) or len(message.strip()) <= 10:
                yield chat_history + [("", "Please provide a valid message longer than 10 characters.")]
                return

            # Convert message-dict history to tuples if needed.
            # NOTE(review): this maps each dict to a (role, content) tuple,
            # which is what run_gradio_chat appears to consume here — confirm
            # against the TxAgent API if history handling misbehaves.
            if chat_history and isinstance(chat_history[0], dict):
                chat_history = [
                    (h["role"], h["content"])
                    for h in chat_history
                    if "role" in h and "content" in h
                ]

            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message.strip(),
                history=chat_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=max_round,
                seed=42,
            ):
                # Chunks may be dicts ({"content": ...}), strings, or other
                # objects; accumulate them all as text.
                if isinstance(chunk, dict):
                    response += chunk.get("content", "")
                elif isinstance(chunk, str):
                    response += chunk
                else:
                    response += str(chunk)

                # Gradio tuple-format Chatbot expects (user_msg, bot_msg)
                # PAIRS — one tuple per exchange, not one per role.
                yield chat_history + [(message, response)]

        except Exception as e:
            logger.error(f"Error in respond function: {str(e)}")
            yield chat_history + [("", f"⚠️ Error: {str(e)}")]


def create_demo():
    """Create and configure the Gradio interface."""
    app = TxAgentApp()

    with gr.Blocks(title="TxAgent Medical AI") as demo:
        gr.Markdown("# TxAgent Biomedical Assistant")

        chatbot = gr.Chatbot(
            label="Conversation",
            height=600,
            bubble_full_width=False,
        )
        msg = gr.Textbox(
            label="Your medical query",
            placeholder="Enter your biomedical question...",
            lines=3,
        )

        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")

        submit = gr.Button("Submit")
        clear = gr.Button("Clear")
        conversation_state = gr.State([])

        inputs = [msg, chatbot, temp, max_new_tokens, max_tokens,
                  multi_agent, conversation_state, max_rounds]

        # Stream into the chatbot, then blank the textbox for the next query.
        submit.click(app.respond, inputs, chatbot).then(lambda: "", None, msg)
        msg.submit(app.respond, inputs, chatbot).then(lambda: "", None, msg)

        # Clear must also reset the per-session conversation state, otherwise
        # stale context leaks into the next exchange.
        clear.click(lambda: ([], []), None, [chatbot, conversation_state])

    return demo


def main():
    """Main entry point for the application."""
    try:
        logger.info("Starting TxAgent application...")
        demo = create_demo()
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise


if __name__ == "__main__":
    main()