|
import os |
|
import gradio as gr |
|
from txagent import TxAgent |
|
|
|
|
|
# Absolute path of the directory that contains this file.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Process-wide environment switches: MKL threading layer and the
# tokenizers parallelism flag.
os.environ.update(
    MKL_THREADING_LAYER="GNU",
    TOKENIZERS_PARALLELISM="false",
)
|
|
|
|
|
# Settings read when the TxAgent instance is constructed.
MODEL_CONFIG = dict(
    model_name='mims-harvard/TxAgent-T1-Llama-3.1-8B',
    rag_model_name='mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B',
    # Extra tool definitions, looked up under data/ next to this file.
    tool_files=dict(
        new_tool=os.path.join(current_dir, 'data', 'new_tool.json'),
    ),
    additional_tools=['DirectResponse', 'RequireClarification'],
    # Keyword arguments unpacked into the TxAgent constructor.
    default_params=dict(
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
    ),
)
|
|
|
|
|
# Static HTML fragments that the interface renders through gr.Markdown.
UI_CONFIG = dict(
    # Page header: title and tagline.
    description="""
<div>
<h1 style="text-align: center;">TxAgent: Therapeutic Reasoning AI</h1>
<p style="text-align: center;">Precision therapeutics with multi-step reasoning</p>
</div>
""",
    # Footer notice shown underneath the examples.
    disclaimer="""
<div style="color: #666; font-size: 0.9em; margin-top: 20px;">
<strong>Disclaimer:</strong> For informational purposes only, not medical advice.
</div>
""",
)
|
|
|
|
|
# Sample prompts offered to the user through gr.Examples.
EXAMPLE_QUESTIONS = [
    "How should dosage be adjusted for hepatic impairment with Journavx?",
    "Is Xolremdi suitable with Prozac for WHIM syndrome?",
    "What are Warfarin-Amiodarone contraindications?",
]
|
|
|
|
|
class TxAgentApplication:
    """Hold one lazily created TxAgent and expose it as Gradio callbacks.

    The agent is not built in ``__init__``; ``initialize_agent`` creates it
    on demand and ``chat`` refuses to run until that has succeeded.
    """

    def __init__(self):
        # Created later by initialize_agent().
        self.agent = None
        # Becomes True only after agent.init_model() returned without raising.
        self.is_initialized = False

    def initialize_agent(self):
        """Construct the TxAgent from MODEL_CONFIG and load its model.

        Returns:
            A human-readable status string. Errors raised while constructing
            or loading the agent are caught and reported in that string
            instead of propagating, so the UI callback never raises.
        """
        if self.is_initialized:
            return "Model already initialized"

        try:
            # NOTE(review): MODEL_CONFIG['additional_tools'] is defined but is
            # not forwarded to TxAgent here — confirm whether it should be.
            self.agent = TxAgent(
                MODEL_CONFIG['model_name'],
                MODEL_CONFIG['rag_model_name'],
                tool_files_dict=MODEL_CONFIG['tool_files'],
                **MODEL_CONFIG['default_params']
            )
            self.agent.init_model()
            self.is_initialized = True
            return "TxAgent initialized successfully"
        except Exception as e:
            return f"Initialization failed: {str(e)}"

    def chat(self, message, chat_history):
        """Stream the agent's reply to *message* as Chatbot history updates.

        Args:
            message: The new user question.
            chat_history: Earlier turns as ``(user, assistant)`` pairs.
                ``None`` is treated as an empty conversation.

        Yields:
            The whole conversation as a list of ``(user, assistant)`` pairs:
            every earlier turn followed by the current turn, whose assistant
            text grows with each chunk received. Failures (including calling
            this before initialization) are yielded as an ``"Error: ..."``
            reply in the same shape, so the output always has one format and
            earlier turns are never dropped from the display.
        """
        # Copy so the caller's list is never mutated.
        history = list(chat_history or [])

        if not self.is_initialized:
            yield history + [(message, "Error: Please initialize the model first")]
            return

        try:
            # Flatten the pair-based history into role-tagged messages.
            messages = []
            for user, assistant in history:
                messages.append({"role": "user", "content": user})
                messages.append({"role": "assistant", "content": assistant})
            messages.append({"role": "user", "content": message})

            full_response = ""
            # NOTE(review): the keyword names below must match the installed
            # txagent's run_gradio_chat signature — confirm.
            for chunk in self.agent.run_gradio_chat(
                messages,
                temperature=0.3,
                max_new_tokens=1024,
                max_tokens=8192,
                multi_agent=False,
                conversation=[],
                max_round=30
            ):
                full_response += chunk
                yield history + [(message, full_response)]

        except Exception as e:
            # UI boundary: surface the failure in the chat instead of raising.
            yield history + [(message, f"Error: {str(e)}")]
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI and wire it to a new TxAgentApplication.

    Returns:
        The constructed ``gr.Blocks`` object. It is not launched here.
    """
    app = TxAgentApplication()

    with gr.Blocks(title="TxAgent", theme=gr.themes.Soft()) as demo:
        gr.Markdown(UI_CONFIG['description'])

        # Model loading is triggered explicitly; its status string is shown
        # in a read-only textbox.
        with gr.Row():
            init_btn = gr.Button("Initialize TxAgent", variant="primary")
            init_status = gr.Textbox(label="Status", interactive=False)

        chatbot = gr.Chatbot(
            height=600,
            label="Conversation",
            # NOTE(review): example.com is a reserved placeholder domain —
            # replace these with real avatar images.
            avatar_images=(
                "https://example.com/user.png",
                "https://example.com/bot.png",
            ),
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about drug interactions or treatments...",
                scale=4,
            )
            submit_btn = gr.Button("Submit", variant="primary", scale=1)

        gr.Examples(
            examples=EXAMPLE_QUESTIONS,
            inputs=msg,
            label="Try these examples:",
        )

        gr.Markdown(UI_CONFIG['disclaimer'])

        init_btn.click(
            app.initialize_agent,
            outputs=init_status,
        )

        # Pressing Enter and clicking Submit behave identically: stream the
        # reply into the chatbot, then clear the input box.
        for trigger in (msg.submit, submit_btn.click):
            trigger(
                app.chat,
                [msg, chatbot],
                [chatbot],
            ).then(
                lambda: "", None, msg
            )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    interface = create_interface()

    # Enable the request queue through queue() rather than the
    # launch(enable_queue=True) keyword, which newer Gradio releases no
    # longer accept. The chat callback is a generator, so it relies on the
    # queue to stream.
    interface.queue()

    # Listens on all network interfaces on port 7860; share=True also asks
    # Gradio for a public share link.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )