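"""Gradio front end for the TxAgent biomedical assistant.

The UI is defined at module level so Hugging Face Spaces can discover it;
vLLM and the TxAgent model are initialized only inside __main__.
"""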
import gradio as gr
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

tx_app = None  # global TxAgent instance; set in __main__ once the model loads

def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
    global tx_app
    if tx_app is None:
        # respond() is a generator, so messages must be yielded, not returned
        yield chat_history + [{"role": "assistant", "content": "⚠️ Model is still loading. Please wait a few seconds and try again."}]
        return

    try:
        if not isinstance(message, str) or len(message.strip()) < 10:
            yield chat_history + [{"role": "assistant", "content": "Please enter a longer message."}]
            return

        # Pass TxAgent a (role, content) tuple history, but keep the original
        # messages-format list for the Chatbot display.
        history_for_agent = chat_history
        if chat_history and isinstance(chat_history[0], dict):
            history_for_agent = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]

        response = ""
        for chunk in tx_app.run_gradio_chat(
            message=message.strip(),
            history=history_for_agent,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=max_round,
            seed=42,
        ):
            # Chunks may arrive as dicts, strings, or other objects; normalize to text
            if isinstance(chunk, dict):
                response += chunk.get("content", "")
            elif isinstance(chunk, str):
                response += chunk
            else:
                response += str(chunk)

            # The Chatbot uses type="messages", so stream role/content dicts
            yield chat_history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]
    except Exception as e:
        logger.error(f"Respond error: {e}")
        yield chat_history + [{"role": "assistant", "content": f"⚠️ Error: {e}"}]
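# Because respond() is a generator, Gradio streams each yielded history to the
# Chatbot as partial output instead of waiting for the full response.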

# Define Gradio app at module level so Hugging Face Spaces can find it
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")

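    # With type="messages", the Chatbot expects a list of {"role", "content"} dicts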
    chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
    msg = gr.Textbox(label="Your medical query", placeholder="Type here...", lines=3)

    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Multi-Agent Mode")

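    # State passed through to run_gradio_chat's `conversation` argument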
    conversation_state = gr.State([])
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    submit.click(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot
    )
    # Reset both the visible chat and the agent's conversation state
    clear.click(lambda: ([], []), None, [chatbot, conversation_state])
    msg.submit(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot
    )

# 🔥 Safely initialize vLLM inside __main__
if __name__ == "__main__":
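    # "spawn" avoids CUDA re-initialization errors in vLLM's forked workers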
    import multiprocessing
    multiprocessing.set_start_method("spawn", force=True)

    import torch
    from txagent import TxAgent
    from importlib.resources import files

    logger.info("🔥 Initializing TxAgent safely in __main__")

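    # Tool definition files bundled with the tooluniverse package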
    tool_files = {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
    }

    tx_app = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode='step',
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"]
    )

    tx_app.init_model()
    logger.info("✅ TxAgent initialized.")

    # When run directly (python app.py), start the Gradio server; on Spaces the
    # module-level `app` object may be served by the platform instead.
    app.launch()