File size: 5,212 Bytes
1155704
8c814cd
9aeb1dd
3f69fbe
9c9d2f8
 
16f16a5
537f975
6604d0d
1729ddc
4b98818
3f76413
537f975
0cec600
1155704
0cec600
47f0902
 
0cec600
47f0902
0cec600
 
47f0902
ecaf6bd
91d1d93
47f0902
0cec600
ecaf6bd
 
 
 
 
 
 
 
 
 
3ef6302
3f69fbe
91d1d93
537f975
ecaf6bd
6309d92
537f975
 
 
 
6309d92
 
 
ecaf6bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537f975
 
84b4115
 
ecaf6bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537f975
84b4115
ecaf6bd
 
6309d92
 
84b4115
6309d92
ecaf6bd
 
84b4115
 
537f975
ecaf6bd
84b4115
6309d92
 
 
3f69fbe
 
 
9c9d2f8
 
3f69fbe
 
 
 
 
ecaf6bd
3f69fbe
 
9c9d2f8
91d1d93
ecaf6bd
9c9d2f8
3f69fbe
 
9c9d2f8
3f69fbe
ecaf6bd
9c9d2f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect

# Absolute directory of this script; used for the bundled src/ package and data files.
current_dir = os.path.abspath(os.path.dirname(__file__))

# Runtime knobs are exported BEFORE importing txagent. MKL reads
# MKL_THREADING_LAYER when its runtime is first loaded, so setting it after the
# model stack has been imported may have no effect.
# NOTE(review): assumes txagent transitively loads an MKL-linked library
# (e.g. torch) — confirm.
os.environ["MKL_THREADING_LAYER"] = "GNU"
# Presumably disables HuggingFace tokenizers' internal parallelism to avoid its
# fork-related warnings — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the bundled `src/` package importable without installing it.
sys.path.insert(0, os.path.join(current_dir, "src"))
import txagent.txagent

# Re-execute the module even if it is already cached in sys.modules; presumably
# so source edits are picked up when the host re-runs this script inside the
# same interpreter (e.g. a reload/dev mode) — TODO confirm intent.
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# Model identifiers passed to TxAgent in the __main__ block (they look like
# Hugging Face Hub repo ids).
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
# Extra tool definition files, keyed by a tool-set name.
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}

# Example prompts shown under the input box (one input value per row).
question_examples = [
    ["Given a patient with WHIM syndrome on antibiotics, is Xolremdi + fluconazole advisable?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]

def extract_sections(
    content,
    section_names=("Summary", "Clinical Studies", "Drug Interactions", "Pharmacokinetics"),
):
    """Split an agent answer into named sections based on its headings.

    A line is treated as a section heading when it consists of one of
    *section_names* (case-insensitive). The name may be wrapped in Markdown
    heading markers (``#``) and/or emphasis (``*``/``_``), and may carry a
    trailing colon. A heading written as ``**Name:** text`` starts the section
    and keeps ``text`` as its first line. Everything after a heading, up to the
    next recognised heading, belongs to that section. Text before the first
    heading, or the whole answer when no heading is present, goes to the first
    name in *section_names*.

    Nothing is invented: a section that is not present in *content* is returned
    as an empty string.

    Args:
        content: Raw answer text. ``None`` is treated as empty; other
            non-str values are converted with ``str()``.
        section_names: Section titles to recognise, in display order. The
            first title also receives any untitled text.

    Returns:
        A dict mapping every name in *section_names* to its stripped text
        (``""`` when absent). Returns ``{}`` when *section_names* is empty.
    """
    import re  # local import keeps this block self-contained

    text = "" if content is None else str(content)
    names = list(section_names)
    if not names:
        return {}

    # Map matched text (any case) back to the canonical display name.
    canonical = {name.lower(): name for name in names}
    # Longest names first so a title that is a prefix of another cannot win.
    alternation = "|".join(re.escape(name) for name in sorted(names, key=len, reverse=True))
    heading_re = re.compile(
        r"^\s*#{0,6}\s*[*_]*\s*"                 # optional '#' heading and emphasis
        rf"(?P<name>{alternation})"
        r"\s*(?P<c1>:?)\s*[*_]*\s*(?P<c2>:?)\s*"  # colon inside or outside emphasis
        r"(?P<rest>.*?)\s*$",
        re.IGNORECASE,
    )

    buckets = {name: [] for name in names}
    current = names[0]
    for line in text.splitlines():
        match = heading_re.match(line)
        # Only a bare title, or a title followed by a colon, counts as a heading;
        # prose such as "Summary statistics show ..." stays in the current section.
        if match and (match["c1"] or match["c2"] or not match["rest"]):
            current = canonical[match["name"].lower()]
            if match["rest"]:
                buckets[current].append(match["rest"])
            continue
        buckets[current].append(line)

    return {name: "\n".join(lines).strip() for name, lines in buckets.items()}

def _assistant_text(messages):
    """Concatenate the text of every assistant message in one chat snapshot.

    Each message may be a dict (``{"role": ..., "content": ...}``) or an object
    exposing ``role``/``content`` attributes. A missing role is treated as
    ``"assistant"``, and a missing content as ``""``. Only ``str`` contents are
    used, because ``None`` or non-text payloads cannot be concatenated. Each
    included message is followed by a newline.

    Args:
        messages: Iterable of messages, or ``None``.

    Returns:
        The joined assistant text (``""`` when there is none).
    """
    parts = []
    for message in messages or ():
        if isinstance(message, dict):
            role = message.get("role", "assistant")
            content = message.get("content", "")
        else:
            role = getattr(message, "role", "assistant")
            content = getattr(message, "content", "")
        if role == "assistant" and isinstance(content, str):
            parts.append(content + "\n")
    return "".join(parts)


def create_ui(agent):
    """Build the TxAgent Gradio interface bound to *agent*.

    The page shows generation controls and four tabs (Summary, Clinical
    Studies, Drug Interactions, Pharmacokinetics). A question submitted via the
    Send button or Enter is answered by ``agent.run_gradio_chat``. The final
    answer is split into those tabs with ``extract_sections``.

    Args:
        agent: Object providing ``run_gradio_chat(...)`` that returns an
            iterable of chat snapshots (lists of messages).

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask therapeutic or biomedical questions. Results are categorized for readability.")

        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # Per-session conversation object handed to the agent on every call.
        conversation_state = gr.State([])

        # Output components must be created inside their TabItem context to be
        # rendered in that tab.
        with gr.Tabs():
            with gr.TabItem("Summary"):
                summary_box = gr.Markdown(label="Summary")
            with gr.TabItem("Clinical Studies"):
                studies_box = gr.Markdown(label="Clinical Studies")
            with gr.TabItem("Drug Interactions"):
                interactions_box = gr.Markdown(label="Drug Interactions")
            with gr.TabItem("Pharmacokinetics"):
                kinetics_box = gr.Markdown(label="Pharmacokinetics")

        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        def handle_chat(message, temp, new_tokens, total_tokens, use_multi_agent, conversation, rounds):
            """Run the agent on *message* and return the four tab contents."""
            if not message or not message.strip():
                # Do not spend a model run on an empty question.
                return "Please enter a question.", "", "", ""

            generator = agent.run_gradio_chat(
                message=message,
                history=[],  # fresh chat history for every question
                temperature=temp,
                # Sliders may deliver floats; token/round limits are counts.
                max_new_tokens=int(new_tokens),
                max_token=int(total_tokens),
                call_agent=use_multi_agent,
                conversation=conversation,
                max_round=int(rounds),
            )

            # Each yielded update is treated as a full chat snapshot. Only the
            # last one is used, so earlier partial snapshots are not duplicated.
            # NOTE(review): assumes run_gradio_chat yields cumulative history — confirm.
            last_update = None
            for last_update in generator:
                pass

            sections = extract_sections(_assistant_text(last_update))
            return (
                sections["Summary"],
                sections["Clinical Studies"],
                sections["Drug Interactions"],
                sections["Pharmacokinetics"],
            )

        # Every listener input must be a Gradio component.
        chat_inputs = [
            message_input,
            temperature,
            max_new_tokens,
            max_tokens,
            multi_agent,
            conversation_state,
            max_round,
        ]
        chat_outputs = [summary_box, studies_box, interactions_box, kinetics_box]

        send_button.click(fn=handle_chat, inputs=chat_inputs, outputs=chat_outputs)
        message_input.submit(fn=handle_chat, inputs=chat_inputs, outputs=chat_outputs)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: For research only. Not medical advice.")

    return demo

if __name__ == "__main__":
    # Per the multiprocessing docs, this is a no-op unless running as a frozen
    # Windows executable.
    freeze_support()
    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]
        )

        # Check the UI's required entry point before init_model() loads the
        # model. This fails fast instead of after an expensive load.
        # NOTE(review): assumes run_gradio_chat is defined on the class rather
        # than attached by init_model() — confirm.
        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("❌ TxAgent missing `run_gradio_chat`")

        agent.init_model()

        demo = create_ui(agent)
        demo.launch(show_error=True)

    except Exception as e:
        # Report on stderr, then re-raise so the traceback and exit status are kept.
        print(f"❌ App failed to start: {e}", file=sys.stderr)
        raise