File size: 4,628 Bytes
1155704
8c814cd
47f0902
9aeb1dd
3f69fbe
9c9d2f8
 
16f16a5
84b4115
16f16a5
 
84b4115
6604d0d
1729ddc
4b98818
3f76413
84b4115
3f76413
 
4b98818
84b4115
0cec600
 
1155704
0cec600
84b4115
0cec600
 
40ad293
0cec600
 
47f0902
 
84b4115
0cec600
47f0902
84b4115
0cec600
 
84b4115
 
0cec600
 
84b4115
47f0902
 
0cec600
47f0902
0cec600
 
84b4115
47f0902
 
 
 
0cec600
84b4115
3f69fbe
6309d92
 
 
 
 
 
 
 
 
 
 
84b4115
 
6309d92
84b4115
 
 
 
9c9d2f8
bb37713
 
 
 
 
 
 
 
 
 
84b4115
 
 
 
 
6309d92
 
84b4115
6309d92
84b4115
 
 
 
 
 
6309d92
84b4115
6309d92
 
 
84b4115
6309d92
 
9c9d2f8
6309d92
3f69fbe
 
 
9c9d2f8
 
3f69fbe
 
 
 
 
 
 
 
9c9d2f8
3f69fbe
84b4115
9c9d2f8
3f69fbe
 
9c9d2f8
3f69fbe
84b4115
9c9d2f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import sys
import random
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect

# === Env vars
# Resolve this script's directory once; it anchors both the `src` import path
# below and the data paths used later in the file.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Set these BEFORE importing txagent: libraries it may pull in at import time
# (NOTE(review): presumably torch / numpy / HF tokenizers — confirm) read them
# when they initialise, so assigning them after the import can be too late.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# === Fix import path BEFORE loading TxAgent
# Prepend the local `src` directory so the txagent package found there takes
# priority over any other copy later on sys.path.
sys.path.insert(0, os.path.join(current_dir, "src"))

# === Reload to avoid stale cache
# NOTE(review): in a freshly started interpreter the module was only just
# imported, so this reload re-executes it; it mainly matters when the script
# is re-run inside a long-lived process.
import txagent.txagent  # noqa: E402
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent  # noqa: E402

# === Debug confirmation
# Print where TxAgent was actually loaded from and whether it exposes the chat
# entry point the UI relies on.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# === UI Text
# HTML header rendered (as Markdown) at the top of the page.
DESCRIPTION = (
    "\n"
    "<div>\n"
    '<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic '
    "Reasoning Across a Universe of Tools</h1>\n"
    "</div>\n"
)
# Short introduction shown under the header.
INTRO = "Precision therapeutics require multimodal adaptive models..."
# Disclaimer rendered at the bottom of the page.
LICENSE = "DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE..."

# Page-level stylesheet passed to gr.Blocks.
css = "\n".join([
    "",
    "h1 { text-align: center; }",
    ".gradio-accordion { margin-top: 0 !important; margin-bottom: 0 !important; }",
    "",
])
# Button styling for the chat area (not referenced by create_ui in this file).
chat_css = "\n".join([
    "",
    ".gr-button { font-size: 18px !important; }",
    ".gr-button svg { width: 26px !important; height: 26px !important; }",
    "",
])

# === Model setup
# Model identifiers in "org/name" form (presumably Hugging Face Hub repos):
# the reasoning LLM and the tool-retrieval (RAG) model.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definitions handed to TxAgent via tool_files_dict, keyed by name.
new_tool_files = dict(
    new_tool=os.path.join(current_dir, "data", "new_tool.json"),
)

# === Sample prompts
# gr.Examples expects one row per example; each row holds a single value for
# the message textbox.
_example_questions = (
    "Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering moderate hepatic impairment?",
    "A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?",
)
question_examples = [[question] for question in _example_questions]

# === UI
def create_ui(agent):
    """Build the Gradio Blocks interface for chatting with *agent*.

    Args:
        agent: An initialised TxAgent, or any object exposing a compatible
            ``run_gradio_chat`` method.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        # Generation / agent-loop controls forwarded to run_gradio_chat.
        temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # Per-session conversation list passed through to the agent.
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=700)
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)

        send_btn = gr.Button("Send", variant="primary")

        # === Streaming handler
        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Relay chatbot updates produced by ``agent.run_gradio_chat``.

            This must itself be a generator function: Gradio only streams
            (re-rendering the outputs on every ``yield``) when the event
            handler is a generator function. Returning the agent's generator
            object from a plain function would hand it to the Chatbot as a
            single output value instead of iterating it.
            """
            result = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round
            )
            if inspect.isgenerator(result):
                yield from result
            else:
                # Non-streaming implementation: emit its single value once.
                yield result

        # Order must match handle_chat's positional parameters.
        chat_inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]

        # === Submit handlers (button click and Enter in the textbox)
        send_btn.click(
            fn=handle_chat,
            inputs=chat_inputs,
            outputs=chatbot,
        )

        message_input.submit(
            fn=handle_chat,
            inputs=chat_inputs,
            outputs=chatbot,
        )

        # === Example buttons
        gr.Examples(
            examples=question_examples,
            inputs=message_input
        )

        gr.Markdown(LICENSE)

    return demo

# === App start
if __name__ == "__main__":
    # Required for multiprocessing in frozen Windows executables; a no-op otherwise.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=["DirectResponse", "RequireClarification"]
        )

        # Fail fast: confirm the UI's entry point exists before init_model()
        # (presumably the expensive model-loading step) rather than after it.
        if not hasattr(agent, 'run_gradio_chat'):
            raise AttributeError("TxAgent is missing `run_gradio_chat`. Make sure the correct txagent.py is used.")

        agent.init_model()

        demo = create_ui(agent)
        demo.launch(show_error=True)

    except Exception as e:
        # Print a short notice, then re-raise so the full traceback and a
        # non-zero exit status are preserved.
        print(f"🚨 App failed to start: {e}")
        raise