Update app.py
Browse files
app.py
CHANGED
@@ -1,275 +1,202 @@
|
|
|
|
1 |
import random
|
2 |
-
import datetime
|
3 |
-
import sys
|
4 |
-
from txagent import TxAgent
|
5 |
-
import spaces
|
6 |
import gradio as gr
|
7 |
-
import
|
8 |
|
9 |
-
#
|
10 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
11 |
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
|
|
12 |
|
13 |
-
#
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
''
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
LICENSE = """
|
26 |
-
We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested in collaboration, please email Marinka Zitnik and Shanghua Gao.
|
27 |
-
|
28 |
-
### Medical Advice Disclaimer
|
29 |
-
DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
|
30 |
-
The information, including but not limited to, text, graphics, images and other material contained on this website are for informational purposes only. No material on this site is intended to be a substitute for professional medical advice, diagnosis or treatment. Always seek the advice of your physician or other qualified health care provider with any questions you may have regarding a medical condition or treatment and before undertaking a new health care regimen, and never disregard professional medical advice or delay in seeking it because of something you have read on this website.
|
31 |
-
"""
|
32 |
-
|
33 |
-
PLACEHOLDER = """
|
34 |
-
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
|
35 |
-
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
|
36 |
-
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
|
37 |
-
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
|
38 |
-
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
|
39 |
-
</div>
|
40 |
-
"""
|
41 |
-
|
42 |
-
css = """
|
43 |
-
h1 {
|
44 |
-
text-align: center;
|
45 |
-
display: block;
|
46 |
}
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
margin-top: 0px !important;
|
62 |
-
margin-bottom: 0px !important;
|
63 |
}
|
64 |
-
"""
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
""
|
70 |
-
|
71 |
-
# Configuration variables (safe to keep at module level)
|
72 |
-
model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
|
73 |
-
rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
|
74 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
75 |
-
|
76 |
-
question_examples = [
|
77 |
-
['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
|
78 |
-
['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
|
79 |
-
['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
|
80 |
]
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
|
89 |
-
# Update model instance parameters dynamically
|
90 |
-
updated_params = agent.update_parameters(
|
91 |
-
enable_finish=enable_finish,
|
92 |
-
enable_rag=enable_rag,
|
93 |
-
enable_summary=enable_summary,
|
94 |
-
init_rag_num=init_rag_num,
|
95 |
-
step_rag_num=step_rag_num,
|
96 |
-
skip_last_k=skip_last_k,
|
97 |
-
summary_mode=summary_mode,
|
98 |
-
summary_skip_last_k=summary_skip_last_k,
|
99 |
-
summary_context_length=summary_context_length,
|
100 |
-
force_finish=force_finish,
|
101 |
-
seed=seed,
|
102 |
-
)
|
103 |
-
return updated_params
|
104 |
-
|
105 |
-
def update_seed():
|
106 |
-
# Update model instance parameters dynamically
|
107 |
-
seed = random.randint(0, 10000)
|
108 |
-
updated_params = agent.update_parameters(
|
109 |
-
seed=seed,
|
110 |
-
)
|
111 |
-
return updated_params
|
112 |
|
113 |
-
def
|
114 |
-
|
115 |
-
|
116 |
-
previous_prompt = history[retry_data.index]['content']
|
117 |
-
print("previous_prompt", previous_prompt)
|
118 |
-
yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
|
119 |
-
|
120 |
-
PASSWORD = "mypassword"
|
121 |
-
|
122 |
-
def check_password(input_password):
|
123 |
-
if input_password == PASSWORD:
|
124 |
-
return gr.update(visible=True), ""
|
125 |
-
else:
|
126 |
-
return gr.update(visible=False), "Incorrect password, try again!"
|
127 |
-
|
128 |
-
# Create the Gradio interface
|
129 |
-
def create_interface(agent):
|
130 |
-
conversation_state = gr.State([])
|
131 |
-
chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
|
132 |
-
label='TxAgent', type="messages", show_copy_button=True)
|
133 |
-
|
134 |
-
with gr.Blocks(css=css) as demo:
|
135 |
-
gr.Markdown(DESCRIPTION)
|
136 |
-
gr.Markdown(INTRO)
|
137 |
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
chatbot=chatbot,
|
154 |
-
fill_height=True, fill_width=True, stop_btn=True,
|
155 |
-
additional_inputs_accordion=gr.Accordion(
|
156 |
-
label="⚙️ Inference Parameters", open=False, render=False),
|
157 |
-
additional_inputs=[
|
158 |
-
temperature_state, max_new_tokens_state, max_tokens_state,
|
159 |
-
gr.Checkbox(
|
160 |
-
label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).",
|
161 |
-
value=False, render=False),
|
162 |
-
conversation_state,
|
163 |
-
max_round_state,
|
164 |
-
gr.Number(label="Seed", value=100, render=False)
|
165 |
-
],
|
166 |
-
examples=question_examples,
|
167 |
-
cache_examples=False,
|
168 |
-
css=chat_css,
|
169 |
)
|
170 |
-
|
171 |
-
with gr.
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
value=default_temperature,
|
177 |
-
label="Temperature"
|
178 |
-
)
|
179 |
-
max_new_tokens_slider = gr.Slider(
|
180 |
-
minimum=128,
|
181 |
-
maximum=4096,
|
182 |
-
step=1,
|
183 |
-
value=default_max_new_tokens,
|
184 |
-
label="Max new tokens"
|
185 |
-
)
|
186 |
-
max_tokens_slider = gr.Slider(
|
187 |
-
minimum=128,
|
188 |
-
maximum=32000,
|
189 |
-
step=1,
|
190 |
-
value=default_max_tokens,
|
191 |
-
label="Max tokens"
|
192 |
)
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
with gr.Row():
|
215 |
-
with gr.Column(scale=1):
|
216 |
-
with gr.Accordion("⚙️ Model Loading", open=False):
|
217 |
-
model_name_input = gr.Textbox(
|
218 |
-
label="Enter model path", value=model_name)
|
219 |
-
load_model_btn = gr.Button(value="Load Model")
|
220 |
-
load_model_btn.click(
|
221 |
-
agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
|
222 |
-
|
223 |
-
with gr.Column(scale=1):
|
224 |
-
with gr.Accordion("⚙️ Functional Parameters", open=False):
|
225 |
-
enable_finish = gr.Checkbox(label="Enable Finish", value=True)
|
226 |
-
enable_rag = gr.Checkbox(label="Enable RAG", value=True)
|
227 |
-
enable_summary = gr.Checkbox(label="Enable Summary", value=False)
|
228 |
-
init_rag_num = gr.Number(label="Initial RAG Num", value=0)
|
229 |
-
step_rag_num = gr.Number(label="Step RAG Num", value=10)
|
230 |
-
skip_last_k = gr.Number(label="Skip Last K", value=0)
|
231 |
-
summary_mode = gr.Textbox(label="Summary Mode", value='step')
|
232 |
-
summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
|
233 |
-
summary_context_length = gr.Number(label="Summary Context Length", value=None)
|
234 |
-
force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
|
235 |
-
seed = gr.Number(label="Seed", value=100)
|
236 |
-
|
237 |
-
submit_btn = gr.Button("Update Parameters")
|
238 |
-
updated_parameters_output = gr.JSON()
|
239 |
-
|
240 |
-
submit_btn.click(
|
241 |
-
fn=update_model_parameters,
|
242 |
-
inputs=[enable_finish, enable_rag, enable_summary, init_rag_num,
|
243 |
-
step_rag_num, skip_last_k, summary_mode, summary_skip_last_k,
|
244 |
-
summary_context_length, force_finish, seed],
|
245 |
-
outputs=updated_parameters_output
|
246 |
-
)
|
247 |
-
|
248 |
-
submit_button = gr.Button("Submit")
|
249 |
-
submit_button.click(
|
250 |
-
check_password,
|
251 |
-
inputs=password_input,
|
252 |
-
outputs=[protected_accordion, incorrect_message]
|
253 |
-
)
|
254 |
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
return demo
|
258 |
|
|
|
259 |
if __name__ == "__main__":
|
260 |
-
# Initialize the agent only when running directly
|
261 |
-
agent = TxAgent(
|
262 |
-
model_name,
|
263 |
-
rag_model_name,
|
264 |
-
tool_files_dict=new_tool_files,
|
265 |
-
force_finish=True,
|
266 |
-
enable_checker=True,
|
267 |
-
step_rag_num=10,
|
268 |
-
seed=100,
|
269 |
-
additional_default_tools=['DirectResponse', 'RequireClarification']
|
270 |
-
)
|
271 |
-
agent.init_model()
|
272 |
-
|
273 |
# Create and launch the interface
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
import random
|
|
|
|
|
|
|
|
|
3 |
import gradio as gr
|
4 |
+
from txagent import TxAgent
|
5 |
|
6 |
+
# ========== Configuration ==========
|
7 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
8 |
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
9 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
10 |
|
11 |
+
# Model configuration
|
12 |
+
MODEL_CONFIG = {
|
13 |
+
'model_name': 'mims-harvard/TxAgent-T1-Llama-3.1-8B',
|
14 |
+
'rag_model_name': 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B',
|
15 |
+
'tool_files': {'new_tool': os.path.join(current_dir, 'data', 'new_tool.json')},
|
16 |
+
'additional_tools': ['DirectResponse', 'RequireClarification'],
|
17 |
+
'default_params': {
|
18 |
+
'force_finish': True,
|
19 |
+
'enable_checker': True,
|
20 |
+
'step_rag_num': 10,
|
21 |
+
'seed': 100
|
22 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
}
|
24 |
|
25 |
+
# UI Configuration
|
26 |
+
UI_CONFIG = {
|
27 |
+
'description': '''
|
28 |
+
<div>
|
29 |
+
<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning</h1>
|
30 |
+
<p style="text-align: center;">Precision therapeutics AI with multi-step reasoning and real-time biomedical knowledge</p>
|
31 |
+
</div>
|
32 |
+
''',
|
33 |
+
'disclaimer': '''
|
34 |
+
<div style="color: #666; font-size: 0.9em; margin-top: 20px;">
|
35 |
+
<strong>Disclaimer:</strong> This tool provides informational support only and is not a substitute for professional medical advice.
|
36 |
+
</div>
|
37 |
+
'''
|
|
|
|
|
38 |
}
|
|
|
39 |
|
40 |
+
# Example questions
|
41 |
+
EXAMPLE_QUESTIONS = [
|
42 |
+
"How should dosage be adjusted for a 50-year-old with hepatic impairment taking Journavx?",
|
43 |
+
"Is Xolremdi suitable for a patient with WHIM syndrome already taking Prozac?",
|
44 |
+
"What are the contraindications for combining Warfarin with Amiodarone?"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
]
|
46 |
|
47 |
+
# ========== Application Class ==========
|
48 |
+
class TxAgentApplication:
|
49 |
+
def __init__(self):
|
50 |
+
self.agent = None
|
51 |
+
self.is_initialized = False
|
52 |
+
self.loading = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
def initialize_agent(self, progress=gr.Progress()):
|
55 |
+
if self.is_initialized:
|
56 |
+
return True, "Model already initialized"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
self.loading = True
|
59 |
+
try:
|
60 |
+
progress(0.1, desc="Initializing TxAgent...")
|
61 |
+
|
62 |
+
# Initialize the agent
|
63 |
+
self.agent = TxAgent(
|
64 |
+
MODEL_CONFIG['model_name'],
|
65 |
+
MODEL_CONFIG['rag_model_name'],
|
66 |
+
tool_files_dict=MODEL_CONFIG['tool_files'],
|
67 |
+
**MODEL_CONFIG['default_params']
|
68 |
+
)
|
69 |
+
|
70 |
+
progress(0.3, desc="Loading language model...")
|
71 |
+
self.agent.init_model()
|
72 |
+
|
73 |
+
progress(0.8, desc="Finalizing setup...")
|
74 |
+
self.is_initialized = True
|
75 |
+
self.loading = False
|
76 |
+
|
77 |
+
return True, "TxAgent initialized successfully"
|
78 |
+
except Exception as e:
|
79 |
+
self.loading = False
|
80 |
+
return False, f"Initialization failed: {str(e)}"
|
81 |
+
|
82 |
+
def chat(self, message, chat_history, temperature, max_new_tokens):
|
83 |
+
if not self.is_initialized:
|
84 |
+
yield "Error: Model not initialized. Please initialize first."
|
85 |
+
return
|
86 |
+
|
87 |
+
try:
|
88 |
+
# Convert Gradio chat history to agent format
|
89 |
+
messages = []
|
90 |
+
for turn in chat_history:
|
91 |
+
messages.append({"role": "user", "content": turn[0]})
|
92 |
+
messages.append({"role": "assistant", "content": turn[1]})
|
93 |
+
messages.append({"role": "user", "content": message})
|
94 |
+
|
95 |
+
# Stream the response
|
96 |
+
full_response = ""
|
97 |
+
for chunk in self.agent.run_gradio_chat(
|
98 |
+
messages,
|
99 |
+
temperature=temperature,
|
100 |
+
max_new_tokens=max_new_tokens,
|
101 |
+
max_tokens=8192,
|
102 |
+
multi_agent=False,
|
103 |
+
conversation=[],
|
104 |
+
max_round=30
|
105 |
+
):
|
106 |
+
full_response += chunk
|
107 |
+
yield full_response
|
108 |
+
|
109 |
+
except Exception as e:
|
110 |
+
yield f"Error during chat: {str(e)}"
|
111 |
+
|
112 |
+
# ========== Gradio Interface ==========
|
113 |
+
def create_interface():
|
114 |
+
app = TxAgentApplication()
|
115 |
+
|
116 |
+
with gr.Blocks(title="TxAgent", theme=gr.themes.Soft()) as demo:
|
117 |
+
# Header Section
|
118 |
+
gr.Markdown(UI_CONFIG['description'])
|
119 |
|
120 |
+
# Initialization Section
|
121 |
+
with gr.Row():
|
122 |
+
init_btn = gr.Button("Initialize TxAgent", variant="primary")
|
123 |
+
init_status = gr.Textbox(label="Initialization Status", interactive=False)
|
124 |
|
125 |
+
# Chat Interface
|
126 |
+
chatbot = gr.Chatbot(
|
127 |
+
height=600,
|
128 |
+
bubble_full_width=False,
|
129 |
+
placeholder="Type your medical question below..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
)
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
msg = gr.Textbox(
|
134 |
+
label="Your Question",
|
135 |
+
placeholder="Ask about drug interactions, dosage, or treatment options...",
|
136 |
+
scale=4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
)
|
138 |
+
submit_btn = gr.Button("Submit", variant="primary", scale=1)
|
139 |
+
|
140 |
+
# Settings
|
141 |
+
with gr.Accordion("Advanced Settings", open=False):
|
142 |
+
with gr.Row():
|
143 |
+
temperature = gr.Slider(
|
144 |
+
minimum=0.1, maximum=1.0, value=0.3, step=0.1,
|
145 |
+
label="Temperature (higher = more creative)"
|
146 |
+
)
|
147 |
+
max_new_tokens = gr.Slider(
|
148 |
+
minimum=128, maximum=4096, value=1024, step=128,
|
149 |
+
label="Max Response Length"
|
150 |
+
)
|
151 |
+
|
152 |
+
# Examples
|
153 |
+
gr.Examples(
|
154 |
+
examples=EXAMPLE_QUESTIONS,
|
155 |
+
inputs=msg,
|
156 |
+
label="Example Questions (click to try)"
|
157 |
+
)
|
158 |
+
|
159 |
+
# Footer
|
160 |
+
gr.Markdown(UI_CONFIG['disclaimer'])
|
161 |
|
162 |
+
# Event Handlers
|
163 |
+
init_btn.click(
|
164 |
+
app.initialize_agent,
|
165 |
+
outputs=init_status
|
166 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
+
msg.submit(
|
169 |
+
app.chat,
|
170 |
+
[msg, chatbot, temperature, max_new_tokens],
|
171 |
+
[chatbot]
|
172 |
+
)
|
173 |
+
|
174 |
+
submit_btn.click(
|
175 |
+
app.chat,
|
176 |
+
[msg, chatbot, temperature, max_new_tokens],
|
177 |
+
[chatbot]
|
178 |
+
).then(
|
179 |
+
lambda: "", None, msg
|
180 |
+
)
|
181 |
|
182 |
return demo
|
183 |
|
184 |
+
# ========== Main Execution ==========
|
185 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
# Create and launch the interface
|
187 |
+
interface = create_interface()
|
188 |
+
|
189 |
+
# Launch settings
|
190 |
+
launch_config = {
|
191 |
+
'server_name': '0.0.0.0',
|
192 |
+
'server_port': 7860,
|
193 |
+
'share': True,
|
194 |
+
'favicon_path': None,
|
195 |
+
'auth': None,
|
196 |
+
'auth_message': None,
|
197 |
+
'enable_queue': True,
|
198 |
+
'max_threads': 40,
|
199 |
+
'show_error': True
|
200 |
+
}
|
201 |
+
|
202 |
+
interface.queue(concurrency_count=5).launch(**launch_config)
|