File size: 4,810 Bytes
9aeb1dd
0b9e159
37d892a
0b9e159
dcb29df
229805b
410d25f
66e2fa0
81ad366
79fb3cd
 
37d892a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc06321
 
37d892a
dc06321
90e4214
0b9e159
dc06321
0b9e159
dcb29df
0b9e159
dcb29df
 
 
 
 
0b9e159
90e4214
0b9e159
 
 
 
 
 
 
dcb29df
 
 
dc06321
dcb29df
0b9e159
dcb29df
83c8341
dcb29df
0b9e159
dc06321
0b9e159
dcb29df
81ad366
37d892a
83c8341
 
37d892a
83c8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d892a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import logging
import multiprocessing
from txagent import TxAgent
from tooluniverse import ToolUniverse
from importlib.resources import files

# Module-level logger, with root logging requested at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Placeholder for the app instance the Gradio callbacks use; it is
# reassigned once TxAgentApp is constructed later in this module.
tx_app = None

def init_txagent():
    """Build a TxAgent, load its model, and return the ready agent.

    Forces the multiprocessing start method to "spawn", resolves the tool
    definition JSON files inside the ``tooluniverse.data`` package,
    constructs the agent with those files and calls ``init_model()`` on it.

    Returns:
        The TxAgent instance after ``init_model()`` has been called.

    Raises:
        Exception: any error raised during setup is logged and re-raised.
    """
    try:
        multiprocessing.set_start_method("spawn", force=True)
        logger.info("Initializing TxAgent...")

        # Resolve every tool definition relative to the packaged data directory.
        data_package = files('tooluniverse.data')
        tool_json_names = (
            ("opentarget", 'opentarget_tools.json'),
            ("fda_drug_label", 'fda_drug_labeling_tools.json'),
            ("special_tools", 'special_tools.json'),
            ("monarch", 'monarch_tools.json'),
        )
        tool_files = {
            tool_key: str(data_package.joinpath(json_name))
            for tool_key, json_name in tool_json_names
        }

        logger.info(f"Using tool files at: {tool_files}")

        # Agent construction options, collected in one place.
        agent_settings = {
            "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
            "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            "tool_files_dict": tool_files,
            "enable_finish": True,
            "enable_rag": True,
            "enable_summary": False,
            "init_rag_num": 0,
            "step_rag_num": 10,
            "summary_mode": 'step',
            "summary_skip_last_k": 0,
            "summary_context_length": None,
            "force_finish": True,
            "avoid_repeat": True,
            "seed": 42,
            "enable_checker": True,
            "enable_chat": False,
            "additional_default_tools": ["DirectResponse", "RequireClarification"],
        }
        agent = TxAgent(**agent_settings)

        agent.init_model()
        logger.info("Model loading complete")
        return agent

    except Exception as e:
        # Log for the operator, then let the caller see the original failure.
        logger.error(f"Initialization failed: {e!s}")
        raise

class TxAgentApp:
    """Gradio-facing wrapper around a TxAgent instance.

    Constructing the app initializes the agent through ``init_txagent()``;
    ``respond`` is the generator callback wired to the UI events.
    """

    # Reply shown when the incoming message is not a string or is too short.
    _INVALID_MESSAGE_REPLY = "Please provide a valid message longer than 10 characters."

    def __init__(self):
        self.agent = init_txagent()

    def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
        """Stream the agent's answer to ``message`` as chat-history updates.

        Args:
            message: The user's text. It must be a ``str`` whose stripped
                length exceeds 10 characters; otherwise a validation reply
                is yielded and the agent is not called.
            chat_history: Prior chat entries. ``None`` is treated as an
                empty history. If the first entry is a dict, entries that
                have both "role" and "content" keys are converted to
                ``(role, content)`` tuples and the rest are dropped.
            temperature: Forwarded to ``run_gradio_chat`` as ``temperature``.
            max_new_tokens: Forwarded as ``max_new_tokens``.
            max_tokens: Forwarded as ``max_token``.
            multi_agent: Forwarded as ``call_agent``.
            conversation_state: Forwarded as ``conversation``.
            max_round: Forwarded as ``max_round``.

        Yields:
            The history list extended with the current exchange. One update
            is yielded per chunk from the agent, each carrying the response
            text accumulated so far. On invalid input or on an exception, a
            single update carrying the corresponding notice is yielded.
        """
        # A missing history would otherwise break every `chat_history + [...]`
        # below, including the one in the error handler.
        if chat_history is None:
            chat_history = []

        try:
            if not isinstance(message, str) or len(message.strip()) <= 10:
                # This function is a generator: a `return <value>` here would
                # be discarded and never reach the UI, so the reply is yielded.
                yield chat_history + [("", self._INVALID_MESSAGE_REPLY)]
                return

            if chat_history and isinstance(chat_history[0], dict):
                chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]

            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message.strip(),
                history=chat_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=max_round,
                seed=42
            ):
                # Chunks may be dicts (text under "content"), plain strings,
                # or anything else, which is stringified.
                if isinstance(chunk, dict):
                    response += chunk.get("content", "")
                elif isinstance(chunk, str):
                    response += chunk
                else:
                    response += str(chunk)

                # NOTE(review): entries are emitted as ("user", message) and
                # ("assistant", response). A tuple-format gr.Chatbot reads each
                # tuple as (user text, bot text) -- confirm the intended format.
                yield chat_history + [("user", message), ("assistant", response)]

        except Exception as e:
            # Report the failure in the chat instead of breaking the UI event;
            # logger.exception keeps the traceback in the logs.
            logger.exception("Error in respond function: %s", e)
            yield chat_history + [("", f"⚠️ Error: {str(e)}")]

# Build the application object at import time; its constructor
# initializes the agent through init_txagent().
tx_app = TxAgentApp()

# Gradio UI layout and event wiring
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown(value="# 🧠 TxAgent Biomedical Assistant")

    chatbot = gr.Chatbot(height=600, label="Conversation")

    msg = gr.Textbox(
        lines=3,
        placeholder="Enter your biomedical question...",
        label="Your medical query",
    )

    # Generation controls, laid out side by side.
    with gr.Row():
        temp = gr.Slider(minimum=0, maximum=1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(minimum=128, maximum=81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(minimum=1, maximum=30, value=10, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Multi-Agent Mode")

    submit = gr.Button(value="Submit")
    clear = gr.Button(value="Clear")

    conversation_state = gr.State(value=[])

    # Components passed to TxAgentApp.respond, listed in its parameter order.
    respond_inputs = [
        msg,
        chatbot,
        temp,
        max_new_tokens,
        max_tokens,
        multi_agent,
        conversation_state,
        max_rounds,
    ]

    # The Submit button and pressing Enter in the textbox trigger the same
    # handler; Clear resets the chatbot to an empty list.
    submit.click(fn=tx_app.respond, inputs=respond_inputs, outputs=chatbot)

    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

    msg.submit(fn=tx_app.respond, inputs=respond_inputs, outputs=chatbot)

# This `app` will be served by Hugging Face automatically