File size: 4,373 Bytes
9aeb1dd
0b9e159
 
dcb29df
229805b
410d25f
81ad366
79fb3cd
 
696fd36
37d892a
 
08baaf7
f15352f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696fd36
f15352f
 
 
 
696fd36
 
 
37d892a
f15352f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d892a
 
f15352f
 
 
696fd36
83c8341
 
 
08baaf7
696fd36
83c8341
 
 
 
 
 
 
 
 
 
 
 
 
f15352f
83c8341
 
 
 
 
 
 
f15352f
83c8341
 
 
 
696fd36
 
08baaf7
696fd36
f15352f
 
08baaf7
 
696fd36
08baaf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import logging
import threading
from importlib.resources import files

import gradio as gr
from tooluniverse import ToolUniverse
from txagent import TxAgent

# Install a root handler at INFO so the progress/error messages below show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Shared TxAgent; remains None until initialize_agent() stores the result of init_txagent().
tx_app = None

def init_txagent():
    """Construct a TxAgent over the bundled ToolUniverse tool catalogs and load its model.

    Returns:
        The TxAgent instance, after ``init_model()`` has been called on it.
    """
    logger.info("🔥 Initializing TxAgent...")

    # Catalog key -> JSON file shipped inside the ``tooluniverse.data`` package.
    bundled_catalogs = {
        "opentarget": 'opentarget_tools.json',
        "fda_drug_label": 'fda_drug_labeling_tools.json',
        "special_tools": 'special_tools.json',
        "monarch": 'monarch_tools.json',
    }
    data_pkg = files('tooluniverse.data')
    catalog_paths = {
        key: str(data_pkg.joinpath(filename))
        for key, filename in bundled_catalogs.items()
    }

    # Fixed agent configuration, forwarded verbatim as keyword arguments.
    settings = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode='step',
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"],
    )
    txagent = TxAgent(tool_files_dict=catalog_paths, **settings)

    txagent.init_model()
    logger.info("✅ TxAgent fully initialized")
    return txagent

def _to_messages(chat_history):
    """Return the role/content entries of a Gradio "messages" history as a new list.

    Dicts with both ``role`` and ``content`` keys are kept as-is, so extra keys
    such as Gradio ``metadata`` survive a re-render. Objects exposing
    ``role``/``content`` attributes are converted to dicts. Anything else is
    dropped. ``None`` is treated as an empty history.
    """
    messages = []
    for item in chat_history or []:
        if isinstance(item, dict):
            if "role" in item and "content" in item:
                messages.append(item)
        elif hasattr(item, "role") and hasattr(item, "content"):
            messages.append({"role": item.role, "content": item.content})
    return messages


def _chunk_text(chunk):
    """Return the text carried by one streamed agent chunk (dict, str, or other)."""
    if isinstance(chunk, dict):
        # `or ""` guards against an explicit None content value.
        return chunk.get("content") or ""
    if isinstance(chunk, str):
        return chunk
    return str(chunk)


def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
    """Stream a TxAgent answer into the Gradio chatbot.

    This is a generator: Gradio renders each yielded value as the new chatbot
    contents. The Chatbot uses ``type="messages"``, so every yielded value is a
    list of ``{"role": ..., "content": ...}`` dicts.

    Args:
        message: Text typed by the user.
        chat_history: Current chatbot contents (messages-format entries).
        temperature: Sampling temperature forwarded to the agent.
        max_new_tokens: Forwarded to the agent as ``max_new_tokens`` (cast to int).
        max_tokens: Forwarded to the agent as ``max_token`` (cast to int).
        multi_agent: Forwarded to the agent as ``call_agent``.
        conversation_state: Value of the ``gr.State``, forwarded as ``conversation``.
        max_round: Forwarded to the agent as ``max_round`` (cast to int).

    Yields:
        The history plus the user turn and the accumulated assistant reply
        after each streamed chunk. If the model is not loaded, the message is
        too short, or the agent raises, it yields a single history ending in an
        assistant warning instead.
    """
    history = _to_messages(chat_history)

    # Status messages must be *yielded*: a `return <value>` inside a generator
    # is discarded, and Gradio would display nothing.
    if tx_app is None:
        yield history + [{"role": "assistant", "content": "⚠️ Model not ready yet. Please wait a few seconds and try again."}]
        return

    if not isinstance(message, str) or len(message.strip()) <= 10:
        yield history + [{"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}]
        return

    user_turn = {"role": "user", "content": message}
    # NOTE(review): the agent has always received (role, content) tuples here;
    # kept unchanged. Confirm this is the format run_gradio_chat expects.
    agent_history = [(m["role"], m["content"]) for m in history]

    response = ""
    try:
        for chunk in tx_app.run_gradio_chat(
            message=message.strip(),
            history=agent_history,
            temperature=temperature,
            # Slider values may arrive as floats; token/round counts are integers.
            max_new_tokens=int(max_new_tokens),
            max_token=int(max_tokens),
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=int(max_round),
            seed=42,
        ):
            response += _chunk_text(chunk)
            yield history + [user_turn, {"role": "assistant", "content": response}]
    except Exception as e:
        # UI boundary: keep the traceback in the logs and surface the error in the chat.
        logger.exception("Error in respond function: %s", e)
        yield history + [user_turn, {"role": "assistant", "content": f"⚠️ Error: {e}"}]

# Serializes the one-time model load in case several page loads trigger
# initialize_agent() at the same time.
_init_lock = threading.Lock()


def initialize_agent():
    """Load the global TxAgent on the first page load; later calls do nothing.

    Registered with ``app.load`` below, so it runs on every page load and must
    be idempotent: ``init_txagent()`` (a full model load) runs at most once per
    process. If initialization raises, the error is logged and re-raised, and
    ``tx_app`` stays None so a later page load retries.
    """
    global tx_app
    if tx_app is not None:
        return
    with _init_lock:
        # Re-check under the lock: a concurrent call may have finished first.
        if tx_app is None:
            try:
                tx_app = init_txagent()
            except Exception:
                logger.exception("TxAgent initialization failed")
                raise


# ✅ Top-level app object that HF Spaces can detect
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")

    chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
    msg = gr.Textbox(label="Your medical query", placeholder="Enter your biomedical question...", lines=3)

    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Multi-Agent Mode")

    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    conversation_state = gr.State([])

    # The Submit button and pressing Enter in the textbox share one handler and input list.
    chat_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds]
    submit.click(respond, chat_inputs, chatbot)
    msg.submit(respond, chat_inputs, chatbot)

    # Reset the agent's conversation state along with the visible chat.
    # Presumably the agent keeps context in conversation_state; confirm.
    clear.click(lambda: ([], []), None, [chatbot, conversation_state])

    # Load the model when a page is first opened rather than at import time,
    # so the Space becomes reachable quickly. initialize_agent is idempotent.
    app.load(initialize_agent)