File size: 4,193 Bytes
1155704
8c814cd
47f0902
9aeb1dd
3f69fbe
9c9d2f8
 
16f16a5
91d1d93
16f16a5
 
91d1d93
6604d0d
1729ddc
4b98818
91d1d93
3f76413
91d1d93
3f76413
 
4b98818
84b4115
0cec600
 
1155704
0cec600
91d1d93
0cec600
91d1d93
0cec600
91d1d93
 
 
 
47f0902
 
0cec600
47f0902
0cec600
 
47f0902
91d1d93
 
47f0902
0cec600
91d1d93
3f69fbe
91d1d93
6309d92
 
 
 
 
 
 
 
 
 
91d1d93
84b4115
 
 
 
91d1d93
9c9d2f8
bb37713
 
 
 
 
 
 
 
 
 
84b4115
 
 
91d1d93
6309d92
 
84b4115
6309d92
84b4115
91d1d93
84b4115
 
 
6309d92
84b4115
6309d92
 
 
84b4115
6309d92
 
91d1d93
6309d92
3f69fbe
 
 
9c9d2f8
 
3f69fbe
 
 
 
 
 
 
 
9c9d2f8
91d1d93
 
9c9d2f8
3f69fbe
 
9c9d2f8
3f69fbe
91d1d93
9c9d2f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
import random
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect

# === Path fix: prepend the bundled src/ directory so the local txagent
# checkout shadows any installed txagent package.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))

# === Reload to avoid stale module: force a fresh import of txagent.txagent
# in case an older copy is already cached in sys.modules (e.g. hot reload).
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent
from gradio import ChatMessage

# === Debug: print which file TxAgent was actually loaded from and whether
# the streaming chat entrypoint exists, before any heavy model init.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# === Env vars
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"  # presumably avoids an MKL threading clash — TODO confirm
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warning

# === UI text rendered by the Gradio interface
DESCRIPTION = '''
<h1 style="text-align: center;">TxAgent: AI for Therapeutic Reasoning</h1>
'''
INTRO = "Ask biomedical or therapeutic questions. Results are powered by tools and reasoning."
LICENSE = "DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE."

# === Model & tool config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"  # main reasoning model
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"  # tool-retrieval model
# Extra tool definitions, passed to TxAgent as tool_files_dict at startup.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}

# Example prompts surfaced in the UI via gr.Examples (list-of-lists shape
# is what gr.Examples expects for a single input component).
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]

# === Gradio UI
def create_ui(agent):
    """Build the Gradio Blocks UI for TxAgent.

    Args:
        agent: An initialized TxAgent instance exposing a ``run_gradio_chat``
            generator method (yields message lists for streaming).

    Returns:
        gr.Blocks: the assembled demo, ready for ``.launch()``.
    """
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        # Generation controls exposed to the user.
        temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)
        send_btn = gr.Button("Send", variant="primary")

        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            # run_gradio_chat is expected to be a generator yielding lists of
            # {role, content} messages; returning it lets Gradio stream them.
            # NOTE(review): keyword is `max_token` (singular) — assumed to
            # match TxAgent's signature; confirm against txagent.txagent.
            return agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round
            )

        # Both the Send button and pressing Enter in the textbox trigger the
        # same handler with the same inputs; wire them from one shared list
        # so the two bindings cannot drift apart (previously duplicated).
        chat_inputs = [
            message_input, chatbot, temperature, max_new_tokens,
            max_tokens, multi_agent, conversation_state, max_round,
        ]
        for bind in (send_btn.click, message_input.submit):
            bind(fn=handle_chat, inputs=chat_inputs, outputs=chatbot)

        gr.Examples(
            examples=question_examples,
            inputs=message_input
        )

        gr.Markdown(LICENSE)

    return demo

# === App startup
if __name__ == "__main__":
    freeze_support()  # safe no-op outside frozen/Windows multiprocessing
    try:
        # Build the agent from the configured models and extra tool files.
        tx_agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=["DirectResponse", "RequireClarification"],
        )
        tx_agent.init_model()

        # Fail fast if a stale module was loaded without the chat entrypoint.
        if not hasattr(tx_agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")

        create_ui(tx_agent).launch(show_error=True)

    except Exception as e:
        print(f"🚨 Startup error: {e}")
        raise