File size: 5,913 Bytes
1bb8be7
 
 
 
 
 
 
fe2e04a
1bb8be7
eaac969
1bb8be7
 
eaac969
1bb8be7
 
 
 
fe2e04a
1bb8be7
 
 
fe2e04a
 
 
eaac969
1bb8be7
 
 
 
eaac969
1bb8be7
 
 
 
 
 
eaac969
1bb8be7
 
 
 
 
fe2e04a
 
9dfa570
 
f9a6a36
fe2e04a
 
 
9dfa570
f9a6a36
 
 
 
9dfa570
 
 
 
 
 
 
 
 
 
 
 
 
f9a6a36
9dfa570
4800765
f9a6a36
fe2e04a
 
eaac969
fe2e04a
eaac969
ccb72b9
fe2e04a
 
 
ccb72b9
1bb8be7
fe2e04a
1bb8be7
eaac969
fe2e04a
 
896d369
1bb8be7
fe2e04a
d2631b6
eaac969
1bb8be7
d2631b6
1bb8be7
 
 
d2631b6
 
 
 
1bb8be7
d2631b6
1bb8be7
 
fe2e04a
1bb8be7
fe2e04a
1bb8be7
fe2e04a
eaac969
fe2e04a
 
 
 
1bb8be7
d2631b6
1bb8be7
 
 
 
fe2e04a
1bb8be7
 
 
fe2e04a
1bb8be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaac969
1bb8be7
 
eaac969
 
 
fe2e04a
eaac969
1bb8be7
eaac969
781bcb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import sys
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
import json
import logging

# === Environment
# Resolve the app directory once (absolute) and reuse it for every path below.
current_dir = os.path.abspath(os.path.dirname(__file__))
# Set these BEFORE importing TxAgent: libraries read such variables when they are
# first loaded/initialised (MKL selects its threading layer at load time), and the
# import below presumably pulls in torch/tokenizers — TODO confirm.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# === Logging
# Configure before the project import: basicConfig() is a no-op once the root
# logger already has handlers, and import-time records should use this setup.
logging.basicConfig(level=logging.INFO)

# === Fix path to include src/txagent
sys.path.insert(0, os.path.join(current_dir, "src"))

# === Import and reload to ensure correct file
# NOTE(review): reload() re-executes txagent.txagent a second time but re-resolves
# it through the same (already imported) package path, so it cannot switch to a
# different file. The debug print below is what shows which file was used.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# === Debug info
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# === Model config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definitions handed to TxAgent(tool_files_dict=...), keyed by name.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}

# === Example prompts
# Each example is a one-element row, matching the single input of gr.Examples.
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]

# === Extract tool name and format output
def extract_tool_name_and_clean_content(msg):
    """Derive a display title and printable content for one chat message.

    ``msg`` may be a plain dict or any object exposing ``content`` /
    ``tool_calls`` attributes. The tool name is looked up in two places, the
    later one winning when both succeed:

    1. the first entry of ``tool_calls`` (a list, or a JSON string encoding one);
    2. a ``[TOOL_CALLS][...]`` JSON array embedded in the message content.

    Args:
        msg: The chat message to inspect.

    Returns:
        A ``(title, content)`` tuple. ``title`` is ``"Tool: <name>"``, falling
        back to ``"Tool: Tool Result"`` when no name is found. ``content`` is the
        message content, pretty-printed as JSON when it is a dict or list, and
        ``""`` when the message carries no content.

    Parse failures are logged as warnings and never raised.
    """
    default_name = "Tool Result"
    marker = "[TOOL_CALLS]"
    tool_name = default_name

    if isinstance(msg, dict):
        content = msg.get("content")
        tool_calls = msg.get("tool_calls")
    else:
        content = getattr(msg, "content", "")
        tool_calls = getattr(msg, "tool_calls", None)
    # A missing/None content would otherwise be rendered literally as "None".
    if content is None:
        content = ""

    # Structured tool_calls (possibly JSON-encoded).
    if tool_calls:
        try:
            if isinstance(tool_calls, str):
                tool_calls = json.loads(tool_calls)
            if isinstance(tool_calls, list) and tool_calls:
                tool_name = tool_calls[0].get("name", default_name)
        except (ValueError, TypeError, AttributeError) as e:
            logging.warning("[extract_tool_name] Failed tool_calls parsing: %s", e)

    # Fallback: a "[TOOL_CALLS]" marker immediately followed by a JSON array in
    # the raw content. raw_decode() consumes exactly one complete JSON value, so
    # nested brackets and newlines inside the array are handled correctly
    # (a non-greedy regex would stop at the first "]").
    text = str(content)
    marker_pos = text.find(marker)
    if marker_pos != -1:
        payload = text[marker_pos + len(marker):]
        if payload.startswith("["):
            try:
                embedded, _ = json.JSONDecoder().raw_decode(payload)
            except ValueError as e:
                logging.warning("[extract_tool_name] Failed TOOL_CALLS content parse: %s", e)
            else:
                if isinstance(embedded, list) and embedded and isinstance(embedded[0], dict):
                    tool_name = embedded[0].get("name", default_name)

    if isinstance(content, (dict, list)):
        # Output is shown to humans, so keep non-ASCII text readable.
        content = json.dumps(content, indent=2, ensure_ascii=False)
    return f"Tool: {tool_name}", content

# === Format answer in collapsible box
def format_collapsible(content, title="Answer"):
    """Wrap ``content`` in a styled, collapsible HTML ``<details>`` box.

    Args:
        content: Text shown inside the box; converted with ``str()``. ``None``
            renders as an empty box.
        title: Text for the ``<summary>`` line.

    Returns:
        An HTML string. ``title`` and ``content`` are HTML-escaped because they
        carry model/tool output: an unescaped ``<`` (e.g. "dose <5 mg") would
        otherwise be parsed as markup and silently dropped. Content that is
        already a box produced by this function is returned unchanged.
    """
    import html

    opening = (
        "<details style='border: 1px solid #ccc; border-radius: 8px; padding: 10px; margin-top: 10px;'>"
    )
    # The chat history (which holds previously formatted boxes) is passed back
    # into the agent, so an already formatted message may arrive here again.
    # Re-wrapping it would nest boxes and, with escaping, show raw HTML.
    if isinstance(content, str) and content.startswith(opening):
        return content

    body = "" if content is None else html.escape(str(content))
    heading = html.escape(str(title))
    return (
        f"{opening}"
        f"<summary style='font-size: 16px; font-weight: bold; color: #3B82F6;'>{heading}</summary>"
        f"<div style='margin-top: 8px; font-size: 15px; line-height: 1.6; white-space: pre-wrap;'>{body}</div></details>"
    )

def _format_chat_messages(messages):
    """Convert agent chat messages into Gradio "messages"-format dicts.

    Each item may be a dict or an object with ``role``/``content`` attributes.
    A missing role is treated as ``"assistant"`` for both kinds. Assistant
    messages are rendered as collapsible boxes titled with the tool name.
    Other messages keep their raw content, with ``None`` replaced by ``""``.
    """
    formatted = []
    for m in messages:
        is_dict = isinstance(m, dict)
        role = m.get("role", "assistant") if is_dict else getattr(m, "role", "assistant")
        if role == "assistant":
            title, clean = extract_tool_name_and_clean_content(m)
            content = format_collapsible(clean, title)
        else:
            content = m.get("content") if is_dict else getattr(m, "content", "")
            if content is None:
                content = ""
        formatted.append({"role": role, "content": content})
    return formatted

# === Build UI
def create_ui(agent):
    """Build the TxAgent Gradio Blocks app.

    Args:
        agent: Object exposing ``run_gradio_chat(...)``. Its result is iterated,
            and each yielded item is a list of chat messages (dicts or objects
            with ``role``/``content``).

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>💊 TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by tool-augmented reasoning.")

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask a biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        conversation_state = gr.State([])

        def handle_chat(message, history, conversation):
            # Nothing to ask: keep the chat as-is instead of starting a
            # (potentially 30-round) agent run on an empty prompt.
            if not (message or "").strip():
                yield history
                return
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=conversation,
                max_round=30
            )
            # Stream every intermediate update to the chatbot as it arrives.
            for update in generator:
                yield _format_chat_messages(update)

        inputs = [message_input, chatbot, conversation_state]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("<small style='color: gray;'>DISCLAIMER: This demo is for research purposes only and does not provide medical advice.</small>")

    return demo

# === Main
if __name__ == "__main__":
    # Required for multiprocessing in frozen executables; a no-op otherwise.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]
        )

        # Fail fast: verify the entry point the UI depends on before calling
        # init_model(), which presumably loads the models configured above, so
        # a wrong TxAgent version is reported without waiting for that.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent is missing `run_gradio_chat`.")

        agent.init_model()

        demo = create_ui(agent)
        # queue() enables streaming of the generator-based chat handler.
        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    except Exception as e:
        # Print a short summary, then re-raise so the full traceback is kept.
        print(f"❌ App failed to start: {e}")
        raise