import os
import gradio as gr
from txagent import TxAgent
# ========== Configuration ==========
# Absolute directory of this script; bundled data files are resolved from here.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Native-library threading knobs, applied in this order.
# NOTE(review): these run after `from txagent import TxAgent` above; if that
# import already loads MKL (e.g. via torch/numpy), MKL_THREADING_LAYER may be
# read too late to have any effect — consider setting it before the import.
os.environ.update({
    "MKL_THREADING_LAYER": "GNU",
    "TOKENIZERS_PARALLELISM": "false",
})
# Model configuration
# Settings read by TxAgentApplication.initialize_agent():
#   model_name / rag_model_name -- model identifiers handed to the TxAgent constructor
#   tool_files                  -- passed to TxAgent as tool_files_dict
#   additional_tools            -- presumably extra built-in tool names for the agent
#   default_params              -- keyword arguments forwarded to the TxAgent constructor
MODEL_CONFIG = dict(
    model_name='mims-harvard/TxAgent-T1-Llama-3.1-8B',
    rag_model_name='mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B',
    tool_files=dict(new_tool=os.path.join(current_dir, 'data', 'new_tool.json')),
    additional_tools=['DirectResponse', 'RequireClarification'],
    default_params=dict(
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
    ),
)
# UI Configuration
# Markdown snippets shown via gr.Markdown at the top (description) and bottom
# (disclaimer) of the page.
UI_CONFIG = {
    # The blank line keeps the title and tagline as separate paragraphs;
    # Markdown would otherwise join two adjacent lines into one run-on line.
    'description': '''
TxAgent: Therapeutic Reasoning AI

Precision therapeutics with multi-step reasoning
''',
    'disclaimer': '''
Disclaimer: For informational purposes only, not medical advice.
''',
}
# Example questions
# Sample prompts listed by gr.Examples; clicking one fills the question textbox.
EXAMPLE_QUESTIONS = [
    # Dose adjustment in organ impairment
    "How should dosage be adjusted for hepatic impairment with Journavx?",
    # Co-medication suitability for a specific indication
    "Is Xolremdi suitable with Prozac for WHIM syndrome?",
    # Drug-drug interaction contraindications
    "What are Warfarin-Amiodarone contraindications?",
]
# ========== Application Class ==========
class TxAgentApplication:
    """Own a lazily-initialized TxAgent and expose it as Gradio callbacks.

    The UI calls ``initialize_agent`` (from a button) to build the agent, then
    ``chat`` once per user message. Neither method raises; errors are turned
    into user-visible text.
    """

    def __init__(self):
        # The TxAgent instance; stays None until initialize_agent() succeeds.
        self.agent = None
        # True only after TxAgent.init_model() completed without error.
        self.is_initialized = False
        # Text of the most recent initialization failure, or None.
        self.initialization_error = None

    def initialize_agent(self):
        """Create the TxAgent from MODEL_CONFIG and load its models.

        Returns immediately if a previous call already succeeded.

        Returns:
            str: A status line for the UI. On failure the reason is returned
            here and also stored in ``initialization_error`` so that ``chat``
            can report it.
        """
        if self.is_initialized:
            return "Model already initialized"
        try:
            # Build into a local so a failed init_model() never leaves a
            # half-initialized agent on self.
            agent = TxAgent(
                MODEL_CONFIG['model_name'],
                MODEL_CONFIG['rag_model_name'],
                tool_files_dict=MODEL_CONFIG['tool_files'],
                # Enable the extra built-in tools listed in MODEL_CONFIG.
                # NOTE(review): keyword name follows upstream TxAgent's
                # constructor; confirm against the installed txagent version.
                additional_default_tools=MODEL_CONFIG['additional_tools'],
                **MODEL_CONFIG['default_params']
            )
            try:
                agent.init_model()
            except Exception as e:
                # A missing precomputed tool-embedding file gets a friendlier hint.
                if "No such file or directory" in str(e) and "tool_embedding" in str(e):
                    message = ("Error: Missing tool embedding file. "
                               "Please ensure the RAG model files are properly downloaded.")
                    # Remember the cause so chat() reports it instead of the
                    # generic "Please initialize the model first".
                    self.initialization_error = message
                    return message
                raise
        except Exception as e:
            self.initialization_error = str(e)
            return f"Initialization failed: {str(e)}"
        self.agent = agent
        self.is_initialized = True
        self.initialization_error = None
        return "TxAgent initialized successfully"

    def chat(self, message, chat_history):
        """Answer one user message and return the updated chat history.

        Args:
            message: The user's new question.
            chat_history: Prior turns as (user, assistant) pairs, as supplied
                by the tuple-format gr.Chatbot; None is treated as empty.

        Returns:
            list: The prior turns plus one new (message, reply) pair. Errors are
            reported as the reply text rather than raised. A blank message
            returns the prior turns unchanged.
        """
        history = list(chat_history or [])
        # Don't spend a model call on an empty prompt.
        if not message or not message.strip():
            return history
        if not self.is_initialized:
            if self.initialization_error:
                return history + [(message, f"System Error: {self.initialization_error}")]
            return history + [(message, "Error: Please initialize the model first")]
        try:
            # Convert the (user, assistant) pairs to role/content messages.
            messages = []
            for user, assistant in history:
                messages.append({"role": "user", "content": user})
                messages.append({"role": "assistant", "content": assistant})
            messages.append({"role": "user", "content": message})
            # NOTE(review): upstream TxAgent's run_gradio_chat appears to take
            # (message: str, history, temperature, max_new_tokens, max_token,
            # call_agent, conversation, max_round) and to yield chat-history
            # lists rather than str chunks; confirm these keyword names and the
            # chunk type against the installed txagent version.
            chunks = self.agent.run_gradio_chat(
                messages,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            )
            response = "".join(chunks)
            return history + [(message, response)]
        except Exception as e:
            return history + [(message, f"Error during processing: {str(e)}")]
# ========== Gradio Interface ==========
def create_interface():
    """Build the TxAgent Gradio Blocks UI.

    Returns:
        gr.Blocks: The assembled, not-yet-launched demo. A single
        TxAgentApplication instance, created here, backs every event handler.
    """
    app = TxAgentApplication()
    with gr.Blocks(title="TxAgent", theme=gr.themes.Soft()) as demo:
        gr.Markdown(UI_CONFIG['description'])

        # Initialization
        with gr.Row():
            init_btn = gr.Button("Initialize TxAgent", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        # Chat Interface
        chatbot = gr.Chatbot(
            height=600,
            label="Conversation",
            show_label=True,
            show_copy_button=True
        )
        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about drug interactions or treatments...",
                scale=4,
                container=False
            )
            submit_btn = gr.Button("Submit", variant="primary", scale=1)

        # Examples
        gr.Examples(
            examples=EXAMPLE_QUESTIONS,
            inputs=msg,
            label="Try these examples:",
            examples_per_page=3
        )
        gr.Markdown(UI_CONFIG['disclaimer'])

        # Event Handlers
        init_btn.click(
            app.initialize_agent,
            outputs=init_status
        )
        # Pressing Enter and clicking Submit behave identically: send the
        # question, then clear the textbox once the answer is in. Previously
        # only the button cleared it.
        for trigger in (msg.submit, submit_btn.click):
            trigger(
                app.chat,
                [msg, chatbot],
                chatbot
            ).then(
                lambda: "", None, msg
            )
    return demo
# ========== Main Execution ==========
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 and also
    # request a public Gradio share link.
    demo = create_interface()
    launch_kwargs = dict(server_name='0.0.0.0', server_port=7860, share=True)
    try:
        # First attempt: enable the request queue before launching.
        demo.queue().launch(**launch_kwargs)
    except Exception as exc:
        print(f"Error launching interface: {exc}")
        # Second attempt with the same settings, this time without calling
        # queue(); any error here propagates.
        demo.launch(**launch_kwargs)