File size: 4,255 Bytes
9aeb1dd
0b9e159
 
dcb29df
229805b
410d25f
66e2fa0
81ad366
79fb3cd
 
f15352f
37d892a
 
 
f15352f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d892a
f15352f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d892a
 
f15352f
 
 
 
83c8341
 
 
f15352f
83c8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f15352f
83c8341
 
 
 
 
 
 
f15352f
83c8341
 
 
 
f15352f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import logging
from txagent import TxAgent
from tooluniverse import ToolUniverse
from importlib.resources import files

# Root logging at INFO so start-up and model-loading progress reach the console.
# ("INFO" is the level name; basicConfig resolves it to logging.INFO.)
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Shared TxAgent instance. It stays None until load_model() assigns it;
# respond() reads it via `global tx_app`.
tx_app = None

def init_txagent():
    """Build a TxAgent wired to ToolUniverse's bundled tool files and load its model.

    The tool definition JSON files are resolved from the installed
    ``tooluniverse.data`` package and passed to ``TxAgent`` as
    ``tool_files_dict``.

    Returns:
        The ``TxAgent`` instance, after ``init_model()`` has been called on it.
    """
    logger.info("Initializing TxAgent...")

    # (tool-set key, JSON file inside tooluniverse.data). The order here is
    # the insertion order of the dict handed to TxAgent.
    tool_file_table = (
        ("opentarget", "opentarget_tools.json"),
        ("fda_drug_label", "fda_drug_labeling_tools.json"),
        ("special_tools", "special_tools.json"),
        ("monarch", "monarch_tools.json"),
    )
    data_package = files('tooluniverse.data')
    tool_files = {
        key: str(data_package.joinpath(json_name))
        for key, json_name in tool_file_table
    }

    logger.info(f"Using tool files at: {tool_files}")

    # Remaining TxAgent options, passed through unchanged.
    # NOTE(review): their exact semantics live in the txagent package; not
    # documented here beyond their names.
    agent_options = {
        "enable_finish": True,
        "enable_rag": True,
        "enable_summary": False,
        "init_rag_num": 0,
        "step_rag_num": 10,
        "summary_mode": 'step',
        "summary_skip_last_k": 0,
        "summary_context_length": None,
        "force_finish": True,
        "avoid_repeat": True,
        "seed": 42,
        "enable_checker": True,
        "enable_chat": False,
        "additional_default_tools": ["DirectResponse", "RequireClarification"],
    }

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        **agent_options,
    )

    agent.init_model()
    logger.info("Model loading complete")
    return agent

def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
    """Stream a TxAgent answer into the Gradio chatbot.

    This is a generator: every yielded value is the complete chatbot history
    to display, as ``(user_message, bot_message)`` pairs.

    Args:
        message: Raw text from the query box. It must be longer than 10
            characters once stripped.
        chat_history: Current chatbot value; ``None`` is treated as empty.
        temperature: Forwarded to ``tx_app.run_gradio_chat``.
        max_new_tokens: Forwarded to ``tx_app.run_gradio_chat``.
        max_tokens: Forwarded as ``max_token``.
        multi_agent: Forwarded as ``call_agent``.
        conversation_state: Forwarded as ``conversation``.
        max_round: Forwarded to ``tx_app.run_gradio_chat``.

    Yields:
        The updated history after each streamed chunk, or one history with a
        validation / not-ready / error message appended as a bot reply.
    """
    global tx_app
    chat_history = list(chat_history or [])

    # This function is a generator, so a `return <value>` would be silently
    # discarded by Gradio: anything the user should see must be yielded.
    if not isinstance(message, str) or len(message.strip()) <= 10:
        yield chat_history + [("", "Please provide a valid message longer than 10 characters.")]
        return

    # The model is loaded asynchronously by load_model(); a query that arrives
    # first would otherwise fail with an opaque AttributeError on None.
    if tx_app is None:
        yield chat_history + [("", "⚠️ The model is still loading. Please try again in a moment.")]
        return

    try:
        if chat_history and isinstance(chat_history[0], dict):
            # NOTE(review): this turns messages-format entries into
            # (role, content) tuples, which is not the (user, bot) pair shape
            # yielded below — confirm what run_gradio_chat expects as history.
            chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]

        response = ""
        for chunk in tx_app.run_gradio_chat(
            message=message.strip(),
            history=chat_history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=max_round,
            seed=42
        ):
            if isinstance(chunk, dict):
                # `or ""` also covers an explicit {"content": None}.
                response += chunk.get("content") or ""
            elif isinstance(chunk, str):
                response += chunk
            else:
                response += str(chunk)

            # One (user, bot) pair, matching the ("", text) pairs yielded on the
            # other paths, instead of two pairs labelled "user"/"assistant".
            yield chat_history + [(message, response)]

    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, show the error in chat.
        logger.exception(f"Error in respond function: {str(e)}")
        yield chat_history + [("", f"⚠️ Error: {str(e)}")]

# Define Gradio UI
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")

    chatbot = gr.Chatbot(label="Conversation", height=600)

    msg = gr.Textbox(
        label="Your medical query",
        placeholder="Enter your biomedical question...",
        lines=3
    )

    # Generation controls; forwarded to respond() via chat_inputs below.
    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Multi-Agent Mode")

    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    # Per-session value forwarded to respond() as `conversation_state`.
    # NOTE(review): "Clear" below resets only the chatbot, not this state —
    # confirm that is intended.
    conversation_state = gr.State([])

    # Positional order must match respond()'s parameters. The list is shared by
    # the button and the textbox's Enter handler so the two cannot drift apart.
    chat_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds]

    submit.click(respond, chat_inputs, chatbot)

    clear.click(lambda: [], None, chatbot)

    msg.submit(respond, chat_inputs, chatbot)

    def load_model():
        """Create the shared TxAgent on the first page load; later loads are no-ops."""
        global tx_app
        if tx_app is not None:
            return
        logger.info("🔥 Loading TxAgent model in Gradio app.load")
        tx_app = init_txagent()

    # Gradio's Blocks API has no `on_start` decorator; start-up work is wired
    # through the `load` event, which fires each time a client opens the page,
    # hence the already-loaded guard in load_model().
    # NOTE(review): overlapping first visits are presumably serialized by
    # Gradio's default per-listener concurrency limit — confirm for your version.
    app.load(load_model)