File size: 6,498 Bytes
1bb8be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896d369
 
ccb72b9
 
 
 
 
 
 
 
 
 
 
 
1bb8be7
 
 
ccb72b9
896d369
ccb72b9
 
 
 
 
 
1bb8be7
 
 
896d369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb8be7
896d369
1bb8be7
 
896d369
1bb8be7
896d369
1bb8be7
 
 
 
 
 
 
 
 
 
 
896d369
1bb8be7
 
 
 
 
896d369
 
1bb8be7
896d369
 
1bb8be7
 
 
896d369
1bb8be7
 
 
 
 
896d369
1bb8be7
 
 
896d369
1bb8be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896d369
1bb8be7
 
896d369
1bb8be7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import html
import importlib
import inspect
import json
import os
import sys
from multiprocessing import freeze_support

import gradio as gr

# Resolve this script's directory once, as an absolute path, so both the
# sys.path entry below and the data-file paths built later stay valid even if
# the working directory changes (os.path.dirname(__file__) alone may be a
# relative path).
current_dir = os.path.abspath(os.path.dirname(__file__))

# Fix path to include src: put the bundled ./src first on sys.path so the
# local txagent package takes precedence over any installed copy.
sys.path.insert(0, os.path.join(current_dir, "src"))

# Env vars. These must be set BEFORE importing txagent: threading settings
# are typically read when the underlying native libraries are first loaded
# (presumably txagent pulls in torch/tokenizers — TODO confirm), so setting
# them after the import, as before, could have no effect.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Reload TxAgent from txagent.py. importlib.reload re-executes the module
# after the first import; presumably kept for interactive/dev sessions where
# an older copy of the module may already be loaded.
import txagent.txagent  # noqa: E402
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent  # noqa: E402

# Debug info: show which file TxAgent actually came from and whether it
# exposes the chat entry point the UI relies on.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# --- Model configuration ----------------------------------------------------
# Identifiers of the main reasoning model and of the tool-retrieval (RAG)
# model; both are passed to the TxAgent constructor in the entry point.
model_name, rag_model_name = (
    "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
)

# Extra tool-definition files, keyed by tool-set name. The JSON file is
# resolved relative to this script's directory (current_dir).
new_tool_files = dict(
    new_tool=os.path.join(current_dir, "data", "new_tool.json"),
)

# Sample questions for gr.Examples. Each inner list is one example row holding
# a single value, because the examples feed a single input (the message box).
question_examples = [
    [prompt]
    for prompt in (
        "Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?",
        "What treatment options exist for HER2+ breast cancer resistant to trastuzumab?",
    )
]

# Helper: collapsible format with tool name
def format_collapsible(content, tool_name=None):
    """Render a chat message payload as a collapsible HTML ``<details>`` block.

    Args:
        content: Message payload.
            - A dict whose ``"results"`` value is a list or tuple is rendered as
              numbered "Result N:" sections. Each dict entry becomes a
              ``- Key: value`` line: an ``openfda.`` prefix is dropped from
              the key, underscores become spaces and the key is capitalized.
              List values are joined with ", ". A non-dict entry is shown as
              ``- <entry>``.
            - Any other dict or list is pretty-printed as JSON.
            - Anything else is converted with ``str()``.
        tool_name: Optional tool name. The summary title is
            "<tool_name> Result" when truthy, otherwise "Answer".

    Returns:
        An HTML string starting with ``<details``. The title and body are
        HTML-escaped, so tool or model output containing ``<``, ``>`` or ``&``
        is displayed literally instead of being interpreted as markup.
    """
    results = content.get("results") if isinstance(content, dict) else None
    if isinstance(results, (list, tuple)):
        sections = []
        for index, result in enumerate(results, 1):
            lines = [f"Result {index}:"]
            if isinstance(result, dict):
                for key, value in result.items():
                    label = str(key).replace("openfda.", "").replace("_", " ").capitalize()
                    # map(str, ...) so lists holding numbers/None don't raise
                    # TypeError inside str.join.
                    text = ", ".join(map(str, value)) if isinstance(value, list) else str(value)
                    lines.append(f"- {label}: {text}")
            else:
                lines.append(f"- {result}")
            sections.append("\n".join(lines))
        formatted = "\n\n".join(sections).strip()
    elif isinstance(content, (dict, list)):
        # Display-only serialization: default=str keeps a stray non-JSON value
        # from crashing the UI, and ensure_ascii=False shows non-ASCII text
        # (e.g. Greek letters in drug names) as-is instead of \uXXXX escapes.
        formatted = json.dumps(content, indent=2, ensure_ascii=False, default=str)
    else:
        formatted = str(content)

    title = f"{tool_name} Result" if tool_name else "Answer"

    # Escape element text only (quote=False): the values are not placed inside
    # attributes, so quotes can stay readable.
    safe_title = html.escape(title, quote=False)
    safe_body = html.escape(formatted, quote=False)

    return (
        "<details style='border: 1px solid #aaa; border-radius: 8px; padding: 10px; margin: 12px 0; background-color: #f8f8f8;'>"
        f"<summary style='font-weight: bold; font-size: 16px; color: #333;'>{safe_title}</summary>"
        f"<pre style='white-space: pre-wrap; font-family: monospace; color: #222; padding-top: 6px;'>{safe_body}</pre>"
        "</details>"
    )

# UI setup
# Every string produced by format_collapsible starts with this marker; it is
# used to recognise history messages that were already wrapped on an earlier
# turn.
_COLLAPSIBLE_PREFIX = "<details"


def _unpack_message(message):
    """Return ``(role, content, tool_name)`` for one chat message.

    Accepts either a dict (the form Gradio uses for ``type="messages"``
    chatbot values) or an object exposing ``role`` / ``content`` /
    ``metadata`` attributes. A missing role defaults to ``"assistant"`` and
    missing content to ``""``. ``tool_name`` is taken from
    ``metadata["tool"]``, or from a ``tool`` attribute when metadata is not a
    dict. It is ``None`` when absent, including when metadata is ``None``.
    """
    if isinstance(message, dict):
        role = message.get("role", "assistant")
        content = message.get("content", "")
        metadata = message.get("metadata")
    else:
        role = getattr(message, "role", "assistant")
        content = getattr(message, "content", "")
        metadata = getattr(message, "metadata", None)
    if isinstance(metadata, dict):
        tool_name = metadata.get("tool")
    else:
        tool_name = getattr(metadata, "tool", None)
    return role, content, tool_name


def _format_history(update):
    """Convert one agent update (a message sequence) into chatbot dicts.

    Assistant messages are wrapped with ``format_collapsible`` unless their
    content already starts with the collapsible marker. The chatbot's own,
    already formatted, value is passed back to the agent as ``history``. If
    the agent echoes that history in its updates (presumably it does — TODO
    confirm against TxAgent.run_gradio_chat), those messages would otherwise
    be nested in another ``<details>`` on every turn.
    """
    formatted = []
    for message in update:
        role, content, tool_name = _unpack_message(message)
        already_wrapped = isinstance(content, str) and content.startswith(_COLLAPSIBLE_PREFIX)
        if role == "assistant" and not already_wrapped:
            content = format_collapsible(content, tool_name)
        formatted.append({"role": role, "content": content})
    return formatted


def create_ui(agent):
    """Build the Gradio Blocks chat interface around ``agent``.

    Args:
        agent: Object exposing ``run_gradio_chat(...)``, which returns an
            iterable of message-list updates (TxAgent in this script).

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    custom_css = """
        body {
            font-family: Inter, sans-serif;
            background-color: #121212;
            color: #ffffff;
        }
        .gradio-container {
            max-width: 900px;
            margin: auto;
        }
        textarea, input, .gr-button {
            font-size: 16px;
        }
        .gr-button {
            background: linear-gradient(to right, #37B6E9, #4B4CED);
            color: white;
            border-radius: 8px;
            font-weight: bold;
        }
        .gr-button:hover {
            background: linear-gradient(to right, #4B4CED, #37B6E9);
        }
        .gr-chatbot {
            background-color: #1e1e1e;
            border-radius: 10px;
        }
    """

    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown("<h1 style='text-align: center; color: #4B4CED;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and intelligent tool use.")

        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # Passed to the agent on every call but never written back here.
        # Presumably the agent mutates it in place — TODO confirm.
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send")

        # Chat logic
        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            """Stream formatted chat updates from the agent into the chatbot."""
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                # Token and round limits are counts. Cast defensively in case
                # the sliders deliver floats.
                max_new_tokens=int(max_new_tokens),
                max_token=int(max_tokens),
                call_agent=multi_agent,
                conversation=conversation,
                max_round=int(max_round),
            )

            for update in generator:
                yield _format_history(update)

        # Events: the button and pressing Enter in the textbox both send.
        inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("<div style='text-align: center; font-size: 0.9em; color: #999;'>This demo is for research purposes only and does not provide medical advice.</div>")

    return demo

# Main
if __name__ == "__main__":
    # Needed only when this script is frozen into a Windows executable.
    # It has no effect otherwise.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]
        )

        # Check the interface the UI depends on BEFORE init_model(). That call
        # presumably loads the model weights (expensive), so fail fast rather
        # than after the load.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("TxAgent missing run_gradio_chat")

        agent.init_model()

        demo = create_ui(agent)
        # queue() enables streaming of the generator handler. share=True also
        # requests a public share link in addition to binding 0.0.0.0:7860.
        demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)

    except Exception as e:
        # Top-level boundary: report on stderr, then re-raise so the process
        # exits non-zero with the full traceback.
        print(f"❌ App failed to start: {e}", file=sys.stderr)
        raise