"""Gradio web front-end for TxAgent (therapeutic-reasoning LLM).

Provides a two-step UI: an explicit "Initialize" button that loads the
model (slow, done once), then a streaming chat panel backed by
``TxAgent.run_gradio_chat``.

Review fixes applied:
* ``enable_queue`` was removed from ``Blocks.launch`` in Gradio 4 —
  replaced with an explicit ``.queue()`` call before ``launch()``.
* ``chat`` previously yielded ``[(message, response)]`` which *replaced*
  the Chatbot history each turn; it now appends to ``chat_history``.
* The "not initialized" guard yielded a bare string to a Chatbot output;
  it is now yielded as a proper (user, assistant) pair.
* ``msg.submit`` now clears the textbox, matching ``submit_btn.click``.
"""

import os

import gradio as gr
from txagent import TxAgent

# ========== Configuration ==========
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model configuration
MODEL_CONFIG = {
    'model_name': 'mims-harvard/TxAgent-T1-Llama-3.1-8B',
    'rag_model_name': 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B',
    'tool_files': {'new_tool': os.path.join(current_dir, 'data', 'new_tool.json')},
    # NOTE(review): 'additional_tools' is declared here but never passed to
    # TxAgent below — confirm whether it should be forwarded or removed.
    'additional_tools': ['DirectResponse', 'RequireClarification'],
    'default_params': {
        'force_finish': True,
        'enable_checker': True,
        'step_rag_num': 10,
        'seed': 100
    }
}

# UI Configuration
UI_CONFIG = {
    'description': '''

TxAgent: Therapeutic Reasoning AI

Precision therapeutics with multi-step reasoning

''',
    'disclaimer': '''
Disclaimer: For informational purposes only, not medical advice.
'''
}

# Example questions
EXAMPLE_QUESTIONS = [
    "How should dosage be adjusted for hepatic impairment with Journavx?",
    "Is Xolremdi suitable with Prozac for WHIM syndrome?",
    "What are Warfarin-Amiodarone contraindications?"
]


# ========== Application Class ==========
class TxAgentApplication:
    """Holds the (lazily initialized) TxAgent instance and chat logic."""

    def __init__(self):
        # agent stays None until initialize_agent() succeeds
        self.agent = None
        self.is_initialized = False

    def initialize_agent(self):
        """Load the TxAgent model once; return a human-readable status string.

        Idempotent: a second click reports the model is already loaded
        instead of re-initializing. Exceptions are caught and surfaced as
        a status message so the UI never sees a traceback.
        """
        if self.is_initialized:
            return "Model already initialized"
        try:
            self.agent = TxAgent(
                MODEL_CONFIG['model_name'],
                MODEL_CONFIG['rag_model_name'],
                tool_files_dict=MODEL_CONFIG['tool_files'],
                **MODEL_CONFIG['default_params']
            )
            self.agent.init_model()
            self.is_initialized = True
            return "TxAgent initialized successfully"
        except Exception as e:
            return f"Initialization failed: {str(e)}"

    def chat(self, message, chat_history):
        """Stream a response for ``message``, yielding updated Chatbot history.

        Args:
            message: the user's new question (str).
            chat_history: list of (user, assistant) tuples shown so far.

        Yields:
            The full history — prior turns plus the in-progress turn — in
            the tuple format the Chatbot component was configured with.
        """
        if not self.is_initialized:
            # FIX: yield a history entry, not a bare string — a gr.Chatbot
            # output expects a list of (user, assistant) pairs.
            yield chat_history + [(message, "Error: Please initialize the model first")]
            return
        try:
            # Convert tuple history to the messages format the agent expects
            messages = []
            for user, assistant in chat_history:
                messages.append({"role": "user", "content": user})
                messages.append({"role": "assistant", "content": assistant})
            messages.append({"role": "user", "content": message})

            # Stream response, accumulating chunks into one assistant turn
            full_response = ""
            for chunk in self.agent.run_gradio_chat(
                messages,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            ):
                full_response += chunk
                # FIX: append to the existing history instead of replacing
                # it — previously every new message wiped earlier turns.
                yield chat_history + [(message, full_response)]
        except Exception as e:
            yield chat_history + [(message, f"Error: {str(e)}")]


# ========== Gradio Interface ==========
def create_interface():
    """Build and return the Gradio Blocks app (does not launch it)."""
    app = TxAgentApplication()

    with gr.Blocks(title="TxAgent", theme=gr.themes.Soft()) as demo:
        gr.Markdown(UI_CONFIG['description'])

        # Initialization
        with gr.Row():
            init_btn = gr.Button("Initialize TxAgent", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        # Chat interface (legacy tuple format, matching chat()'s output)
        chatbot = gr.Chatbot(
            height=600,
            label="Conversation",
            avatar_images=(
                "https://example.com/user.png",  # User avatar
                "https://example.com/bot.png"    # Bot avatar
            )
        )
        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about drug interactions or treatments...",
                scale=4
            )
            submit_btn = gr.Button("Submit", variant="primary", scale=1)

        # Examples
        gr.Examples(
            examples=EXAMPLE_QUESTIONS,
            inputs=msg,
            label="Try these examples:"
        )
        gr.Markdown(UI_CONFIG['disclaimer'])

        # Event Handlers
        init_btn.click(
            app.initialize_agent,
            outputs=init_status
        )
        # FIX: clear the textbox after Enter as well, matching the button.
        msg.submit(
            app.chat,
            [msg, chatbot],
            [chatbot]
        ).then(
            lambda: "", None, msg
        )
        submit_btn.click(
            app.chat,
            [msg, chatbot],
            [chatbot]
        ).then(
            lambda: "", None, msg
        )

    return demo


# ========== Main Execution ==========
if __name__ == "__main__":
    interface = create_interface()
    # FIX: `enable_queue` was removed from launch() in Gradio 4; queueing is
    # enabled by calling .queue() on the Blocks before launching.
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )