Ali2206 committed on
Commit
a65af0c
·
verified ·
1 Parent(s): db4b178

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -281
app.py CHANGED
@@ -36,98 +36,34 @@ CONFIG = {
36
  }
37
  }
38
 
39
- DESCRIPTION = '''
40
- <div>
41
- <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
42
- </div>
43
- '''
44
-
45
- INTRO = """
46
- Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
47
- We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
48
- retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
49
- contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
50
- """
51
-
52
- LICENSE = """
53
- We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
54
- in collaboration, please email Marinka Zitnik and Shanghua Gao.
55
-
56
- ### Medical Advice Disclaimer
57
- DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
58
- The information, including but not limited to, text, graphics, images and other material contained on this
59
- website are for informational purposes only. No material on this site is intended to be a substitute for
60
- professional medical advice, diagnosis or treatment.
61
- """
62
-
63
- PLACEHOLDER = """
64
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
65
- <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
66
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
67
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
68
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
69
- </div>
70
- """
71
-
72
- css = """
73
- h1 {
74
- text-align: center;
75
- display: block;
76
- }
77
-
78
- #duplicate-button {
79
- margin: auto;
80
- color: white;
81
- background: #1565c0;
82
- border-radius: 100vh;
83
- }
84
- .small-button button {
85
- font-size: 12px !important;
86
- padding: 4px 8px !important;
87
- height: 6px !important;
88
- width: 4px !important;
89
- }
90
- .gradio-accordion {
91
- margin-top: 0px !important;
92
- margin-bottom: 0px !important;
93
- }
94
- """
95
-
96
  chat_css = """
97
  .gr-button { font-size: 20px !important; }
98
  .gr-button svg { width: 32px !important; height: 32px !important; }
99
  """
100
 
101
  def safe_load_embeddings(filepath: str) -> any:
102
- """Safely load embeddings with proper weights_only handling"""
103
  try:
104
- # First try with weights_only=True
105
  return torch.load(filepath, weights_only=True)
106
  except Exception as e:
107
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
108
  try:
109
- # Fallback to unsafe load if needed
110
  return torch.load(filepath, weights_only=False)
111
  except Exception as e:
112
  logger.error(f"Failed to load embeddings: {str(e)}")
113
  return None
114
 
115
  def patch_embedding_loading():
116
- """Monkey-patch the embedding loading functionality"""
117
  try:
118
  from txagent.toolrag import ToolRAGModel
119
-
120
- original_load = ToolRAGModel.load_tool_desc_embedding
121
-
122
  def patched_load(self, tooluniverse):
123
  try:
124
  if not os.path.exists(CONFIG["embedding_filename"]):
125
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
126
  return False
127
-
128
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
129
-
130
- # Updated tool loading approach
131
  if hasattr(tooluniverse, 'get_all_tools'):
132
  tools = tooluniverse.get_all_tools()
133
  elif hasattr(tooluniverse, 'tools'):
@@ -135,13 +71,13 @@ def patch_embedding_loading():
135
  else:
136
  logger.error("No method found to access tools from ToolUniverse")
137
  return False
138
-
139
  current_count = len(tools)
140
  embedding_count = len(self.tool_desc_embedding)
141
-
142
  if current_count != embedding_count:
143
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
144
-
145
  if current_count < embedding_count:
146
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
147
  logger.info(f"Truncated embeddings to match {current_count} tools")
@@ -150,22 +86,21 @@ def patch_embedding_loading():
150
  padding = [last_embedding] * (current_count - embedding_count)
151
  self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
152
  logger.info(f"Padded embeddings to match {current_count} tools")
153
-
154
  return True
155
-
156
  except Exception as e:
157
  logger.error(f"Failed to load embeddings: {str(e)}")
158
  return False
159
-
160
  ToolRAGModel.load_tool_desc_embedding = patched_load
161
  logger.info("Successfully patched embedding loading")
162
-
163
  except Exception as e:
164
  logger.error(f"Failed to patch embedding loading: {str(e)}")
165
  raise
166
 
167
  def prepare_tool_files():
168
- """Ensure tool files exist and are populated"""
169
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
170
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
171
  logger.info("Generating tool list using ToolUniverse...")
@@ -178,7 +113,7 @@ def prepare_tool_files():
178
  else:
179
  tools = []
180
  logger.error("Could not access tools from ToolUniverse")
181
-
182
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
183
  json.dump(tools, f, indent=2)
184
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
@@ -186,12 +121,9 @@ def prepare_tool_files():
186
  logger.error(f"Failed to prepare tool files: {str(e)}")
187
 
188
  def create_agent():
189
- """Create and initialize the TxAgent"""
190
- # Apply the embedding patch before creating the agent
191
  patch_embedding_loading()
192
  prepare_tool_files()
193
 
194
- # Initialize the agent
195
  try:
196
  agent = TxAgent(
197
  CONFIG["model_name"],
@@ -209,224 +141,45 @@ def create_agent():
209
  logger.error(f"Failed to create agent: {str(e)}")
210
  raise
211
 
212
- def handle_chat_response(history, message, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
213
- """Convert generator output to Gradio-compatible format"""
214
- full_response = ""
215
- for chunk in message:
 
216
  if isinstance(chunk, dict):
217
- full_response += chunk.get("content", "")
218
  else:
219
- full_response += str(chunk)
220
- history.append((None, full_response))
221
- return history
222
-
223
- def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
224
- init_rag_num, step_rag_num, skip_last_k,
225
- summary_mode, summary_skip_last_k, summary_context_length,
226
- force_finish, seed):
227
- """Update model parameters"""
228
- updated_params = agent.update_parameters(
229
- enable_finish=enable_finish,
230
- enable_rag=enable_rag,
231
- enable_summary=enable_summary,
232
- init_rag_num=init_rag_num,
233
- step_rag_num=step_rag_num,
234
- skip_last_k=skip_last_k,
235
- summary_mode=summary_mode,
236
- summary_skip_last_k=summary_skip_last_k,
237
- summary_context_length=summary_context_length,
238
- force_finish=force_finish,
239
- seed=seed,
240
- )
241
- return updated_params
242
-
243
- def update_seed(agent):
244
- """Update random seed"""
245
- seed = random.randint(0, 10000)
246
- updated_params = agent.update_parameters(seed=seed)
247
- return updated_params
248
-
249
- def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
250
- """Handle retry functionality"""
251
- print("Updated seed:", update_seed(agent))
252
- new_history = history[:retry_data.index]
253
- previous_prompt = history[retry_data.index]['content']
254
- print("previous_prompt", previous_prompt)
255
- response = agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}],
256
- temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
257
- yield from handle_chat_response(new_history, response, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
258
-
259
- PASSWORD = "mypassword"
260
-
261
- def check_password(input_password):
262
- """Check password for protected settings"""
263
- if input_password == PASSWORD:
264
- return gr.update(visible=True), ""
265
- else:
266
- return gr.update(visible=False), "Incorrect password, try again!"
267
 
268
  def create_demo(agent):
269
- """Create the Gradio interface"""
270
- default_temperature = 0.3
271
- default_max_new_tokens = 1024
272
- default_max_tokens = 81920
273
- default_max_round = 30
274
-
275
- question_examples = [
276
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
277
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
278
- ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
279
- ]
280
-
281
- chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
282
- label='TxAgent', show_copy_button=True)
283
-
284
- with gr.Blocks(css=css) as demo:
285
- gr.Markdown(DESCRIPTION)
286
- gr.Markdown(INTRO)
287
-
288
- temperature_state = gr.State(value=default_temperature)
289
- max_new_tokens_state = gr.State(value=default_max_new_tokens)
290
- max_tokens_state = gr.State(value=default_max_tokens)
291
- max_round_state = gr.State(value=default_max_round)
292
-
293
- chatbot.retry(
294
- lambda *args: handle_retry(agent, *args),
295
- inputs=[chatbot, chatbot, temperature_state, max_new_tokens_state,
296
- max_tokens_state, gr.Checkbox(value=False, render=False),
297
- gr.State([]), max_round_state]
298
- )
299
-
300
  with gr.Row():
301
- with gr.Column(scale=4):
302
- msg = gr.Textbox(label="Input", placeholder="Type your question here...")
303
- with gr.Column(scale=1):
304
- submit_btn = gr.Button("Submit", variant="primary")
305
-
306
  with gr.Row():
307
- clear_btn = gr.ClearButton([msg, chatbot])
308
-
309
- def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
310
- response = agent.run_gradio_chat(
311
- chat_history + [{"role": "user", "content": message}],
312
- temperature,
313
- max_new_tokens,
314
- max_tokens,
315
- multi_agent,
316
- conversation,
317
- max_round
318
- )
319
- return handle_chat_response(chat_history, response, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
320
 
321
- submit_btn.click(
322
- respond,
323
- inputs=[msg, chatbot, temperature_state, max_new_tokens_state,
324
- max_tokens_state, gr.Checkbox(value=False, render=False),
325
- gr.State([]), max_round_state],
326
- outputs=[chatbot]
327
- )
328
- msg.submit(
329
  respond,
330
- inputs=[msg, chatbot, temperature_state, max_new_tokens_state,
331
- max_tokens_state, gr.Checkbox(value=False, render=False),
332
- gr.State([]), max_round_state],
333
  outputs=[chatbot]
334
  )
335
-
336
- with gr.Accordion("Settings", open=False):
337
- temperature_slider = gr.Slider(
338
- minimum=0,
339
- maximum=1,
340
- step=0.1,
341
- value=default_temperature,
342
- label="Temperature"
343
- )
344
- max_new_tokens_slider = gr.Slider(
345
- minimum=128,
346
- maximum=4096,
347
- step=1,
348
- value=default_max_new_tokens,
349
- label="Max new tokens"
350
- )
351
- max_tokens_slider = gr.Slider(
352
- minimum=128,
353
- maximum=32000,
354
- step=1,
355
- value=default_max_tokens,
356
- label="Max tokens"
357
- )
358
- max_round_slider = gr.Slider(
359
- minimum=0,
360
- maximum=50,
361
- step=1,
362
- value=default_max_round,
363
- label="Max round")
364
-
365
- temperature_slider.change(
366
- lambda x: x, inputs=temperature_slider, outputs=temperature_state)
367
- max_new_tokens_slider.change(
368
- lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
369
- max_tokens_slider.change(
370
- lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
371
- max_round_slider.change(
372
- lambda x: x, inputs=max_round_slider, outputs=max_round_state)
373
-
374
- password_input = gr.Textbox(
375
- label="Enter Password for More Settings", type="password")
376
- incorrect_message = gr.Textbox(visible=False, interactive=False)
377
-
378
- with gr.Accordion("⚙��� Advanced Settings", open=False, visible=False) as protected_accordion:
379
- with gr.Row():
380
- with gr.Column(scale=1):
381
- with gr.Accordion("Model Settings", open=False):
382
- model_name_input = gr.Textbox(
383
- label="Enter model path", value=CONFIG["model_name"])
384
- load_model_btn = gr.Button(value="Load Model")
385
- load_model_btn.click(
386
- agent.load_models,
387
- inputs=model_name_input,
388
- outputs=gr.Textbox(label="Status"))
389
- with gr.Column(scale=1):
390
- with gr.Accordion("Functional Parameters", open=False):
391
- enable_finish = gr.Checkbox(label="Enable Finish", value=True)
392
- enable_rag = gr.Checkbox(label="Enable RAG", value=True)
393
- enable_summary = gr.Checkbox(label="Enable Summary", value=False)
394
- init_rag_num = gr.Number(label="Initial RAG Num", value=0)
395
- step_rag_num = gr.Number(label="Step RAG Num", value=10)
396
- skip_last_k = gr.Number(label="Skip Last K", value=0)
397
- summary_mode = gr.Textbox(label="Summary Mode", value='step')
398
- summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
399
- summary_context_length = gr.Number(label="Summary Context Length", value=None)
400
- force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
401
- seed = gr.Number(label="Seed", value=100)
402
- submit_btn = gr.Button("Update Parameters")
403
- updated_parameters_output = gr.JSON()
404
- submit_btn.click(
405
- lambda *args: update_model_parameters(agent, *args),
406
- inputs=[enable_finish, enable_rag, enable_summary,
407
- init_rag_num, step_rag_num, skip_last_k,
408
- summary_mode, summary_skip_last_k,
409
- summary_context_length, force_finish, seed],
410
- outputs=updated_parameters_output
411
- )
412
-
413
- submit_button = gr.Button("Submit")
414
- submit_button.click(
415
- check_password,
416
- inputs=password_input,
417
- outputs=[protected_accordion, incorrect_message]
418
- )
419
-
420
- gr.Markdown(LICENSE)
421
-
422
  return demo
423
 
424
  def main():
425
- """Main function to run the application"""
426
  try:
 
427
  agent = create_agent()
428
  demo = create_demo(agent)
429
- demo.launch(share=True)
430
  except Exception as e:
431
  logger.error(f"Application failed to start: {str(e)}")
432
  raise
 
36
  }
37
  }
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
# Extra CSS injected into the Gradio Blocks UI: enlarges chat buttons and
# their SVG icons for readability.
chat_css = """
.gr-button { font-size: 20px !important; }
.gr-button svg { width: 32px !important; height: 32px !important; }
"""
43
 
44
def safe_load_embeddings(filepath: str) -> "torch.Tensor | None":
    """Load a serialized embedding tensor from ``filepath``.

    Tries the restricted ``weights_only=True`` loader first, then falls back
    to a full unpickling load for checkpoints containing non-tensor objects.

    Args:
        filepath: Path to a file produced by ``torch.save``.

    Returns:
        The loaded embeddings, or ``None`` if both load attempts fail.
    """
    try:
        # Preferred path: weights_only=True restricts unpickling to tensors
        # and plain containers, so a tampered file cannot execute code.
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        try:
            # SECURITY NOTE: weights_only=False runs arbitrary pickle payloads.
            # Only acceptable because this file is generated by this app itself.
            return torch.load(filepath, weights_only=False)
        except Exception as e:
            logger.error(f"Failed to load embeddings: {str(e)}")
            return None
54
 
55
  def patch_embedding_loading():
 
56
  try:
57
  from txagent.toolrag import ToolRAGModel
58
+
 
 
59
  def patched_load(self, tooluniverse):
60
  try:
61
  if not os.path.exists(CONFIG["embedding_filename"]):
62
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
63
  return False
64
+
65
  self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
66
+
 
67
  if hasattr(tooluniverse, 'get_all_tools'):
68
  tools = tooluniverse.get_all_tools()
69
  elif hasattr(tooluniverse, 'tools'):
 
71
  else:
72
  logger.error("No method found to access tools from ToolUniverse")
73
  return False
74
+
75
  current_count = len(tools)
76
  embedding_count = len(self.tool_desc_embedding)
77
+
78
  if current_count != embedding_count:
79
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
80
+
81
  if current_count < embedding_count:
82
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
83
  logger.info(f"Truncated embeddings to match {current_count} tools")
 
86
  padding = [last_embedding] * (current_count - embedding_count)
87
  self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
88
  logger.info(f"Padded embeddings to match {current_count} tools")
89
+
90
  return True
91
+
92
  except Exception as e:
93
  logger.error(f"Failed to load embeddings: {str(e)}")
94
  return False
95
+
96
  ToolRAGModel.load_tool_desc_embedding = patched_load
97
  logger.info("Successfully patched embedding loading")
98
+
99
  except Exception as e:
100
  logger.error(f"Failed to patch embedding loading: {str(e)}")
101
  raise
102
 
103
  def prepare_tool_files():
 
104
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
105
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
106
  logger.info("Generating tool list using ToolUniverse...")
 
113
  else:
114
  tools = []
115
  logger.error("Could not access tools from ToolUniverse")
116
+
117
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
118
  json.dump(tools, f, indent=2)
119
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
 
121
  logger.error(f"Failed to prepare tool files: {str(e)}")
122
 
123
  def create_agent():
 
 
124
  patch_embedding_loading()
125
  prepare_tool_files()
126
 
 
127
  try:
128
  agent = TxAgent(
129
  CONFIG["model_name"],
 
141
  logger.error(f"Failed to create agent: {str(e)}")
142
  raise
143
 
144
def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Run one chat turn through the TxAgent and return the updated history.

    Args:
        message: The user's new question.
        history: Prior chat history in Gradio "messages" format
            (list of ``{"role": ..., "content": ...}`` dicts).
        temperature: Sampling temperature forwarded to the agent.
        max_new_tokens: Per-step generation limit forwarded to the agent.
        max_tokens: Total token budget forwarded to the agent.
        multi_agent: Whether multi-agent mode is enabled.
        conversation: Auxiliary conversation state forwarded to the agent.
        max_round: Maximum number of reasoning rounds.

    Returns:
        A new history list with the user message and the assistant's
        fully-collected reply appended.
    """
    updated_history = history + [{"role": "user", "content": message}]
    # NOTE: `agent` is the module-level TxAgent instance assigned in main().
    response_generator = agent.run_gradio_chat(
        updated_history, temperature, max_new_tokens, max_tokens,
        multi_agent, conversation, max_round,
    )
    # Collect streamed chunks in a list and join once — avoids quadratic
    # string concatenation on long streams.
    parts = []
    for chunk in response_generator:
        if isinstance(chunk, dict):
            # `or ""` guards against an explicit {"content": None} chunk.
            parts.append(chunk.get("content", "") or "")
        else:
            parts.append(str(chunk))
    updated_history.append({"role": "assistant", "content": "".join(parts)})
    return updated_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
def create_demo(agent):
    """Build and return the Gradio Blocks UI for TxAgent.

    Args:
        agent: The TxAgent instance. NOTE(review): currently unused here —
            the `respond` callback reads the module-level `agent` set in
            main(). Kept as a parameter for interface compatibility.

    Returns:
        The assembled ``gr.Blocks`` demo, ready for ``launch()``.
    """
    with gr.Blocks(css=chat_css) as demo:
        chatbot = gr.Chatbot(label="TxAgent", type="messages")
        with gr.Row():
            msg = gr.Textbox(label="Your question")
        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")
        with gr.Row():
            submit = gr.Button("Ask TxAgent")

        # One shared input list so button and textbox trigger identically.
        chat_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens,
                       multi_agent, gr.State([]), max_rounds]
        submit.click(respond, inputs=chat_inputs, outputs=[chatbot])
        # Restore Enter-to-submit on the textbox (parity with the button;
        # the earlier version of this UI had this binding).
        msg.submit(respond, inputs=chat_inputs, outputs=[chatbot])
    return demo
176
 
177
def main():
    """Application entry point: build the agent, assemble the UI, serve it."""
    # `respond` looks up the module-level `agent` at call time, so the
    # instance created here must be published globally.
    global agent
    try:
        agent = create_agent()
        demo = create_demo(agent)
        demo.launch()
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise