Ali2206 commited on
Commit
e3711be
·
verified ·
1 Parent(s): d8c321d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -214
app.py CHANGED
@@ -36,88 +36,21 @@ CONFIG = {
36
  }
37
  }
38
 
39
- DESCRIPTION = '''
40
- <div>
41
- <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
42
- </div>
43
- '''
44
-
45
- INTRO = """
46
- Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
47
- We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
48
- retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
49
- contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
50
- """
51
-
52
- LICENSE = """
53
- We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
54
- in collaboration, please email Marinka Zitnik and Shanghua Gao.
55
-
56
- ### Medical Advice Disclaimer
57
- DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
58
- The information, including but not limited to, text, graphics, images and other material contained on this
59
- website are for informational purposes only. No material on this site is intended to be a substitute for
60
- professional medical advice, diagnosis or treatment.
61
- """
62
-
63
- PLACEHOLDER = """
64
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
65
- <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
66
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
67
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
68
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
69
- </div>
70
- """
71
-
72
- css = """
73
- h1 {
74
- text-align: center;
75
- display: block;
76
- }
77
-
78
- #duplicate-button {
79
- margin: auto;
80
- color: white;
81
- background: #1565c0;
82
- border-radius: 100vh;
83
- }
84
- .small-button button {
85
- font-size: 12px !important;
86
- padding: 4px 8px !important;
87
- height: 6px !important;
88
- width: 4px !important;
89
- }
90
- .gradio-accordion {
91
- margin-top: 0px !important;
92
- margin-bottom: 0px !important;
93
- }
94
- """
95
-
96
- chat_css = """
97
- .gr-button { font-size: 20px !important; }
98
- .gr-button svg { width: 32px !important; height: 32px !important; }
99
- """
100
-
101
  def safe_load_embeddings(filepath: str) -> any:
102
- """Safely load embeddings with proper weights_only handling"""
103
  try:
104
  return torch.load(filepath, weights_only=True)
105
  except Exception as e:
106
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
107
  try:
108
- with torch.serialization.safe_globals([torch.serialization._reconstruct]):
109
- return torch.load(filepath, weights_only=False)
110
  except Exception as e:
111
  logger.error(f"Failed to load embeddings even with safe_globals: {str(e)}")
112
  return None
113
 
114
  def patch_embedding_loading():
115
- """Monkey-patch the embedding loading functionality"""
116
  try:
117
  from txagent.toolrag import ToolRAGModel
118
 
119
- original_load = ToolRAGModel.load_tool_desc_embedding
120
-
121
  def patched_load(self, tooluniverse):
122
  try:
123
  if not os.path.exists(CONFIG["embedding_filename"]):
@@ -129,7 +62,6 @@ def patch_embedding_loading():
129
  logger.error("Embedding is None, aborting.")
130
  return False
131
 
132
- # Ensure tools is a list (in case it's a generator)
133
  tools = list(tooluniverse.get_all_tools()) if hasattr(tooluniverse, 'get_all_tools') else []
134
  current_count = len(tools)
135
  embedding_count = len(self.tool_desc_embedding)
@@ -159,11 +91,37 @@ def patch_embedding_loading():
159
  logger.error(f"Failed to patch embedding loading: {str(e)}")
160
  raise
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
163
  init_rag_num, step_rag_num, skip_last_k,
164
  summary_mode, summary_skip_last_k, summary_context_length,
165
  force_finish, seed):
166
- """Update model parameters"""
167
  updated_params = agent.update_parameters(
168
  enable_finish=enable_finish,
169
  enable_rag=enable_rag,
@@ -180,13 +138,11 @@ def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
180
  return updated_params
181
 
182
  def update_seed(agent):
183
- """Update random seed"""
184
  seed = random.randint(0, 10000)
185
  updated_params = agent.update_parameters(seed=seed)
186
  return updated_params
187
 
188
  def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
189
- """Handle retry functionality"""
190
  print("Updated seed:", update_seed(agent))
191
  new_history = history[:retry_data.index]
192
  previous_prompt = history[retry_data.index]['content']
@@ -197,160 +153,42 @@ def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_
197
  PASSWORD = "mypassword"
198
 
199
  def check_password(input_password):
200
- """Check password for protected settings"""
201
  if input_password == PASSWORD:
202
  return gr.update(visible=True), ""
203
  else:
204
  return gr.update(visible=False), "Incorrect password, try again!"
205
 
206
  def create_demo(agent):
207
- """Create the Gradio interface"""
208
- default_temperature = 0.3
209
- default_max_new_tokens = 1024
210
- default_max_tokens = 81920
211
- default_max_round = 30
212
-
213
- question_examples = [
214
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
215
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
216
- ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
217
- ]
218
-
219
- chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
220
- label='TxAgent', type="messages", show_copy_button=True)
221
-
222
- with gr.Blocks(css=css) as demo:
223
- gr.Markdown(DESCRIPTION)
224
- gr.Markdown(INTRO)
225
-
226
- temperature_state = gr.State(value=default_temperature)
227
- max_new_tokens_state = gr.State(value=default_max_new_tokens)
228
- max_tokens_state = gr.State(value=default_max_tokens)
229
- max_round_state = gr.State(value=default_max_round)
230
-
231
- chatbot.retry(
232
- lambda *args: handle_retry(agent, *args),
233
- inputs=[chatbot, chatbot, temperature_state, max_new_tokens_state,
234
- max_tokens_state, gr.Checkbox(value=False, render=False),
235
- gr.State([]), max_round_state]
236
- )
237
-
238
- gr.ChatInterface(
239
- fn=lambda *args: agent.run_gradio_chat(*args),
240
- chatbot=chatbot,
241
- fill_height=True,
242
- fill_width=True,
243
- stop_btn=True,
244
- additional_inputs_accordion=gr.Accordion(
245
- label="⚙️ Inference Parameters", open=False, render=False),
246
- additional_inputs=[
247
- temperature_state, max_new_tokens_state, max_tokens_state,
248
- gr.Checkbox(
249
- label="Activate multi-agent reasoning mode",
250
- value=False,
251
- render=False),
252
- gr.State([]),
253
- max_round_state,
254
- gr.Number(label="Seed", value=100, render=False)
255
- ],
256
- examples=question_examples,
257
- cache_examples=False,
258
- css=chat_css,
259
- )
260
-
261
- with gr.Accordion("Settings", open=False):
262
- temperature_slider = gr.Slider(
263
- minimum=0,
264
- maximum=1,
265
- step=0.1,
266
- value=default_temperature,
267
- label="Temperature"
268
- )
269
- max_new_tokens_slider = gr.Slider(
270
- minimum=128,
271
- maximum=4096,
272
- step=1,
273
- value=default_max_new_tokens,
274
- label="Max new tokens"
275
- )
276
- max_tokens_slider = gr.Slider(
277
- minimum=128,
278
- maximum=32000,
279
- step=1,
280
- value=default_max_tokens,
281
- label="Max tokens"
282
- )
283
- max_round_slider = gr.Slider(
284
- minimum=0,
285
- maximum=50,
286
- step=1,
287
- value=default_max_round,
288
- label="Max round")
289
-
290
- temperature_slider.change(
291
- lambda x: x, inputs=temperature_slider, outputs=temperature_state)
292
- max_new_tokens_slider.change(
293
- lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
294
- max_tokens_slider.change(
295
- lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
296
- max_round_slider.change(
297
- lambda x: x, inputs=max_round_slider, outputs=max_round_state)
298
 
299
- password_input = gr.Textbox(
300
- label="Enter Password for More Settings", type="password")
301
- incorrect_message = gr.Textbox(visible=False, interactive=False)
302
-
303
- with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
304
- with gr.Row():
305
- with gr.Column(scale=1):
306
- with gr.Accordion("⚙️ Model Loading", open=False):
307
- model_name_input = gr.Textbox(
308
- label="Enter model path", value=CONFIG["model_name"])
309
- load_model_btn = gr.Button(value="Load Model")
310
- load_model_btn.click(
311
- agent.load_models,
312
- inputs=model_name_input,
313
- outputs=gr.Textbox(label="Status"))
314
- with gr.Column(scale=1):
315
- with gr.Accordion("⚙️ Functional Parameters", open=False):
316
- enable_finish = gr.Checkbox(label="Enable Finish", value=True)
317
- enable_rag = gr.Checkbox(label="Enable RAG", value=True)
318
- enable_summary = gr.Checkbox(label="Enable Summary", value=False)
319
- init_rag_num = gr.Number(label="Initial RAG Num", value=0)
320
- step_rag_num = gr.Number(label="Step RAG Num", value=10)
321
- skip_last_k = gr.Number(label="Skip Last K", value=0)
322
- summary_mode = gr.Textbox(label="Summary Mode", value='step')
323
- summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
324
- summary_context_length = gr.Number(label="Summary Context Length", value=None)
325
- force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
326
- seed = gr.Number(label="Seed", value=100)
327
- submit_btn = gr.Button("Update Parameters")
328
- updated_parameters_output = gr.JSON()
329
- submit_btn.click(
330
- lambda *args: update_model_parameters(agent, *args),
331
- inputs=[enable_finish, enable_rag, enable_summary,
332
- init_rag_num, step_rag_num, skip_last_k,
333
- summary_mode, summary_skip_last_k,
334
- summary_context_length, force_finish, seed],
335
- outputs=updated_parameters_output
336
- )
337
-
338
- submit_button = gr.Button("Submit")
339
- submit_button.click(
340
- check_password,
341
- inputs=password_input,
342
- outputs=[protected_accordion, incorrect_message]
343
- )
344
-
345
- gr.Markdown(LICENSE)
346
-
347
  return demo
348
 
349
  def main():
350
- """Main function to run the application"""
351
  agent = create_agent()
352
  demo = create_demo(agent)
353
  demo.launch(share=True)
354
 
355
  if __name__ == "__main__":
356
- main()
 
36
  }
37
  }
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def safe_load_embeddings(filepath: str) -> any:
 
40
  try:
41
  return torch.load(filepath, weights_only=True)
42
  except Exception as e:
43
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
44
  try:
45
+ return torch.load(filepath, weights_only=False)
 
46
  except Exception as e:
47
  logger.error(f"Failed to load embeddings even with safe_globals: {str(e)}")
48
  return None
49
 
50
  def patch_embedding_loading():
 
51
  try:
52
  from txagent.toolrag import ToolRAGModel
53
 
 
 
54
  def patched_load(self, tooluniverse):
55
  try:
56
  if not os.path.exists(CONFIG["embedding_filename"]):
 
62
  logger.error("Embedding is None, aborting.")
63
  return False
64
 
 
65
  tools = list(tooluniverse.get_all_tools()) if hasattr(tooluniverse, 'get_all_tools') else []
66
  current_count = len(tools)
67
  embedding_count = len(self.tool_desc_embedding)
 
91
  logger.error(f"Failed to patch embedding loading: {str(e)}")
92
  raise
93
 
94
+ def prepare_tool_files():
95
+ os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
96
+ if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
97
+ logger.info("Generating tool list using ToolUniverse...")
98
+ tu = ToolUniverse()
99
+ tools = list(tu.get_all_tools()) if hasattr(tu, 'get_all_tools') else []
100
+ with open(CONFIG["tool_files"]["new_tool"], "w") as f:
101
+ json.dump(tools, f, indent=2)
102
+ logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
103
+
104
+ def create_agent():
105
+ patch_embedding_loading()
106
+ prepare_tool_files()
107
+
108
+ agent = TxAgent(
109
+ CONFIG["model_name"],
110
+ CONFIG["rag_model_name"],
111
+ tool_files_dict=CONFIG["tool_files"],
112
+ force_finish=True,
113
+ enable_checker=True,
114
+ step_rag_num=10,
115
+ seed=100,
116
+ additional_default_tools=['DirectResponse', 'RequireClarification']
117
+ )
118
+ agent.init_model()
119
+ return agent
120
+
121
  def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
122
  init_rag_num, step_rag_num, skip_last_k,
123
  summary_mode, summary_skip_last_k, summary_context_length,
124
  force_finish, seed):
 
125
  updated_params = agent.update_parameters(
126
  enable_finish=enable_finish,
127
  enable_rag=enable_rag,
 
138
  return updated_params
139
 
140
  def update_seed(agent):
 
141
  seed = random.randint(0, 10000)
142
  updated_params = agent.update_parameters(seed=seed)
143
  return updated_params
144
 
145
  def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
 
146
  print("Updated seed:", update_seed(agent))
147
  new_history = history[:retry_data.index]
148
  previous_prompt = history[retry_data.index]['content']
 
153
  PASSWORD = "mypassword"
154
 
155
  def check_password(input_password):
 
156
  if input_password == PASSWORD:
157
  return gr.update(visible=True), ""
158
  else:
159
  return gr.update(visible=False), "Incorrect password, try again!"
160
 
161
  def create_demo(agent):
162
+ chatbot = gr.Chatbot()
163
+ with gr.Blocks(css=chat_css) as demo:
164
+ with gr.Row():
165
+ gr.Markdown("""
166
+ # TxAgent Interface
167
+ Ask biomedical questions and get reasoning-based answers.
168
+ """)
169
+
170
+ user_input = gr.Textbox(label="Your question")
171
+ temperature = gr.Slider(0, 1, value=0.3, step=0.1, label="Temperature")
172
+ max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
173
+ max_tokens = gr.Slider(128, 81920, value=81920, step=1, label="Max Tokens")
174
+ max_round = gr.Slider(1, 30, value=30, step=1, label="Max Rounds")
175
+ multi_agent = gr.Checkbox(label="Multi-Agent Mode")
176
+
177
+ submit = gr.Button("Run TxAgent")
178
+
179
+ def run_agent(message, history, temperature, max_new_tokens, max_tokens, multi_agent, max_round):
180
+ return agent.run_gradio_chat(history + [{"role": "user", "content": message}],
181
+ temperature, max_new_tokens, max_tokens,
182
+ multi_agent, [], max_round)
183
+
184
+ submit.click(run_agent, inputs=[user_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, max_round], outputs=chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  return demo
187
 
188
  def main():
 
189
  agent = create_agent()
190
  demo = create_demo(agent)
191
  demo.launch(share=True)
192
 
193
  if __name__ == "__main__":
194
+ main()