import os

import gradio as gr

from txagent import TxAgent

# ========== Configuration ==========
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model configuration
MODEL_CONFIG = {
    'model_name': 'mims-harvard/TxAgent-T1-Llama-3.1-8B',
    'rag_model_name': 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B',
    'tool_files': {'new_tool': os.path.join(current_dir, 'data', 'new_tool.json')},
    'additional_tools': ['DirectResponse', 'RequireClarification'],
    'default_params': {
        'force_finish': True,
        'enable_checker': True,
        'step_rag_num': 10,
        'seed': 100
    }
}

# UI Configuration
UI_CONFIG = {
    'description': '''

# TxAgent: Therapeutic Reasoning AI

**Precision therapeutics with multi-step reasoning**

''',
    'disclaimer': '''
**Disclaimer:** For informational purposes only, not medical advice.
'''
}

# Example questions
EXAMPLE_QUESTIONS = [
    "How should dosage be adjusted for hepatic impairment with Journavx?",
    "Is Xolremdi suitable with Prozac for WHIM syndrome?",
    "What are Warfarin-Amiodarone contraindications?"
]


# ========== Application Class ==========
class TxAgentApplication:
    def __init__(self):
        self.agent = None
        self.is_initialized = False
        self.initialization_error = None

    def initialize_agent(self):
        if self.is_initialized:
            return "Model already initialized"

        try:
            # Initialize the agent
            self.agent = TxAgent(
                MODEL_CONFIG['model_name'],
                MODEL_CONFIG['rag_model_name'],
                tool_files_dict=MODEL_CONFIG['tool_files'],
                **MODEL_CONFIG['default_params']
            )

            # Initialize model with error handling
            try:
                self.agent.init_model()
            except Exception as e:
                # Handle specific tool embedding error
                if "No such file or directory" in str(e) and "tool_embedding" in str(e):
                    return ("Error: Missing tool embedding file. "
                            "Please ensure the RAG model files are properly downloaded.")
                raise

            self.is_initialized = True
            self.initialization_error = None
            return "TxAgent initialized successfully"
        except Exception as e:
            self.initialization_error = str(e)
            return f"Initialization failed: {str(e)}"

    def chat(self, message, chat_history):
        if not self.is_initialized:
            if self.initialization_error:
                return chat_history + [(message, f"System Error: {self.initialization_error}")]
            return chat_history + [(message, "Error: Please initialize the model first")]

        try:
            # Convert to messages format
            messages = []
            for user, assistant in chat_history:
                messages.append({"role": "user", "content": user})
                messages.append({"role": "assistant", "content": assistant})
            messages.append({"role": "user", "content": message})

            # Get response
            response = ""
            for chunk in self.agent.run_gradio_chat(
                messages,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            ):
                response += chunk

            return chat_history + [(message, response)]
        except Exception as e:
            return chat_history + [(message, f"Error during processing: {str(e)}")]


# ========== Gradio Interface ==========
def create_interface():
    app = TxAgentApplication()

    with gr.Blocks(title="TxAgent", theme=gr.themes.Soft()) as demo:
        gr.Markdown(UI_CONFIG['description'])

        # Initialization
        with gr.Row():
            init_btn = gr.Button("Initialize TxAgent", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        # Chat Interface
        chatbot = gr.Chatbot(
            height=600,
            label="Conversation",
            show_label=True,
            show_copy_button=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about drug interactions or treatments...",
                scale=4,
                container=False
            )
            submit_btn = gr.Button("Submit", variant="primary", scale=1)

        # Examples
        gr.Examples(
            examples=EXAMPLE_QUESTIONS,
            inputs=msg,
            label="Try these examples:",
            examples_per_page=3
        )

        gr.Markdown(UI_CONFIG['disclaimer'])

        # Event Handlers
        init_btn.click(
            app.initialize_agent,
            outputs=init_status
        )

        msg.submit(
            app.chat,
            [msg, chatbot],
            chatbot
        ).then(
            lambda: "", None, msg  # Clear the textbox after Enter submit, matching the button behavior
        )

        submit_btn.click(
            app.chat,
            [msg, chatbot],
            chatbot
        ).then(
            lambda: "", None, msg
        )

    return demo


# ========== Main Execution ==========
if __name__ == "__main__":
    # Create and configure the interface
    interface = create_interface()

    # Launch configuration
    launch_params = {
        'server_name': '0.0.0.0',
        'server_port': 7860,
        'share': True
    }

    # Enable queue if needed (for production)
    try:
        interface.queue().launch(**launch_params)
    except Exception as e:
        print(f"Error launching interface: {e}")
        interface.launch(**launch_params)  # Fallback without queue