Ali2206 committed on
Commit 8e533b3 · verified · 1 Parent(s): 2567e17

Update app.py

Files changed (1)
  1. app.py +148 -166
app.py CHANGED
@@ -5,7 +5,6 @@ from txagent import TxAgent
 import spaces
 import gradio as gr
 import os
-import os
 
 # Determine the directory where the current file is located
 current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -14,7 +13,6 @@ os.environ["MKL_THREADING_LAYER"] = "GNU"
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-
 DESCRIPTION = '''
 <div>
 <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
@@ -36,9 +34,8 @@ PLACEHOLDER = """
 <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
 <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
 <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
-<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
-(top-right) to remove previous context before sumbmitting a new question.</p>
-<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
+<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
+<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
 </div>
 """
 
@@ -71,13 +68,11 @@ chat_css = """
 .gr-button svg { width: 32px !important; height: 32px !important; } /* Enlarges SVG icons */
 """
 
-# model_name = '/n/holylfs06/LABS/mzitnik_lab/Lab/shgao/bioagent/bio/alignment-handbook/data_new/L8-qlora-biov49v9v7v16_32k_chat01_merged'
+# Configuration variables (safe to keep at module level)
 model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
 rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
-
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-
 question_examples = [
     ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
     ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
@@ -88,20 +83,9 @@ new_tool_files = {
     'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
 }
 
-agent = TxAgent(model_name,
-                rag_model_name,
-                tool_files_dict=new_tool_files,
-                force_finish=True,
-                enable_checker=True,
-                step_rag_num=10,
-                seed=100,
-                additional_default_tools=['DirectResponse', 'RequireClarification'])
-agent.init_model()
-
-
 def update_model_parameters(enable_finish, enable_rag, enable_summary,
-                            init_rag_num, step_rag_num, skip_last_k,
-                            summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
+                            init_rag_num, step_rag_num, skip_last_k,
+                            summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
     # Update model instance parameters dynamically
     updated_params = agent.update_parameters(
         enable_finish=enable_finish,
@@ -116,10 +100,8 @@ def update_model_parameters(enable_finish, enable_rag, enable_summary,
         force_finish=force_finish,
         seed=seed,
     )
-
     return updated_params
 
-
 def update_seed():
     # Update model instance parameters dynamically
     seed = random.randint(0, 10000)
@@ -128,166 +110,166 @@ def update_seed():
     )
     return updated_params
 
-
 def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
     print("Updated seed:", update_seed())
     new_history = history[:retry_data.index]
     previous_prompt = history[retry_data.index]['content']
-
     print("previous_prompt", previous_prompt)
-
     yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
 
-
 PASSWORD = "mypassword"
 
-# Function to check if the password is correct
-
-
 def check_password(input_password):
     if input_password == PASSWORD:
         return gr.update(visible=True), ""
     else:
         return gr.update(visible=False), "Incorrect password, try again!"
 
-
-conversation_state = gr.State([])
-
-# Gradio block
-chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
-                     label='TxAgent', type="messages", show_copy_button=True)
-
-with gr.Blocks(css=css) as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.Markdown(INTRO)
-    default_temperature = 0.3
-    default_max_new_tokens = 1024
-    default_max_tokens = 81920
-    default_max_round = 30
-    temperature_state = gr.State(value=default_temperature)
-    max_new_tokens_state = gr.State(value=default_max_new_tokens)
-    max_tokens_state = gr.State(value=default_max_tokens)
-    max_round_state = gr.State(value=default_max_round)
-    chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
-                  max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
-
-    gr.ChatInterface(
-        fn=agent.run_gradio_chat,
-        chatbot=chatbot,
-        fill_height=True, fill_width=True, stop_btn=True,
-        additional_inputs_accordion=gr.Accordion(
-            label="⚙️ Inference Parameters", open=False, render=False),
-        additional_inputs=[
-            temperature_state, max_new_tokens_state, max_tokens_state,
-            gr.Checkbox(
-                label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).", value=False, render=False),
-            conversation_state,
-            max_round_state,
-            gr.Number(label="Seed", value=100, render=False)
-        ],
-        examples=question_examples,
-        cache_examples=False,
-        css=chat_css,
-    )
-
-    with gr.Accordion("Settings", open=False):
-
-        # Define the sliders
-        temperature_slider = gr.Slider(
-            minimum=0,
-            maximum=1,
-            step=0.1,
-            value=default_temperature,
-            label="Temperature"
-        )
-        max_new_tokens_slider = gr.Slider(
-            minimum=128,
-            maximum=4096,
-            step=1,
-            value=default_max_new_tokens,
-            label="Max new tokens"
-        )
-        max_tokens_slider = gr.Slider(
-            minimum=128,
-            maximum=32000,
-            step=1,
-            value=default_max_tokens,
-            label="Max tokens"
+# Create the Gradio interface
+def create_interface(agent):
+    conversation_state = gr.State([])
+    chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
+                         label='TxAgent', type="messages", show_copy_button=True)
+
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(DESCRIPTION)
+        gr.Markdown(INTRO)
+
+        default_temperature = 0.3
+        default_max_new_tokens = 1024
+        default_max_tokens = 81920
+        default_max_round = 30
+
+        temperature_state = gr.State(value=default_temperature)
+        max_new_tokens_state = gr.State(value=default_max_new_tokens)
+        max_tokens_state = gr.State(value=default_max_tokens)
+        max_round_state = gr.State(value=default_max_round)
+
+        chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
+                      max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
+
+        gr.ChatInterface(
+            fn=agent.run_gradio_chat,
+            chatbot=chatbot,
+            fill_height=True, fill_width=True, stop_btn=True,
+            additional_inputs_accordion=gr.Accordion(
+                label="⚙️ Inference Parameters", open=False, render=False),
+            additional_inputs=[
+                temperature_state, max_new_tokens_state, max_tokens_state,
+                gr.Checkbox(
+                    label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).",
+                    value=False, render=False),
+                conversation_state,
+                max_round_state,
+                gr.Number(label="Seed", value=100, render=False)
+            ],
+            examples=question_examples,
+            cache_examples=False,
+            css=chat_css,
         )
-        max_round_slider = gr.Slider(
-            minimum=0,
-            maximum=50,
-            step=1,
-            value=default_max_round,
-            label="Max round")
-
-        # Automatically update states when slider values change
-        temperature_slider.change(
-            lambda x: x, inputs=temperature_slider, outputs=temperature_state)
-        max_new_tokens_slider.change(
-            lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
-        max_tokens_slider.change(
-            lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
-        max_round_slider.change(
-            lambda x: x, inputs=max_round_slider, outputs=max_round_state)
-
-        password_input = gr.Textbox(
-            label="Enter Password for More Settings", type="password")
-        incorrect_message = gr.Textbox(visible=False, interactive=False)
-        with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
-            with gr.Row():
-                with gr.Column(scale=1):
-                    with gr.Accordion("⚙️ Model Loading", open=False):
-                        model_name_input = gr.Textbox(
-                            label="Enter model path", value=model_name)
-                        load_model_btn = gr.Button(value="Load Model")
-                        load_model_btn.click(
-                            agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
-                with gr.Column(scale=1):
-                    with gr.Accordion("⚙️ Functional Parameters", open=False):
-                        # Create Gradio components for parameter inputs
-                        enable_finish = gr.Checkbox(
-                            label="Enable Finish", value=True)
-                        enable_rag = gr.Checkbox(
-                            label="Enable RAG", value=True)
-                        enable_summary = gr.Checkbox(
-                            label="Enable Summary", value=False)
-                        init_rag_num = gr.Number(
-                            label="Initial RAG Num", value=0)
-                        step_rag_num = gr.Number(
-                            label="Step RAG Num", value=10)
-                        skip_last_k = gr.Number(label="Skip Last K", value=0)
-                        summary_mode = gr.Textbox(
-                            label="Summary Mode", value='step')
-                        summary_skip_last_k = gr.Number(
-                            label="Summary Skip Last K", value=0)
-                        summary_context_length = gr.Number(
-                            label="Summary Context Length", value=None)
-                        force_finish = gr.Checkbox(
-                            label="Force FinalAnswer", value=True)
-                        seed = gr.Number(label="Seed", value=100)
-                        # Button to submit and update parameters
-                        submit_btn = gr.Button("Update Parameters")
-
-                        # Display the updated parameters
-                        updated_parameters_output = gr.JSON()
-
-                        # When button is clicked, update parameters
-                        submit_btn.click(fn=update_model_parameters,
-                                         inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
-                                                 summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
-                                         outputs=updated_parameters_output)
-        # Button to submit the password
-        submit_button = gr.Button("Submit")
-
-        # When the button is clicked, check if the password is correct
-        submit_button.click(
-            check_password,
-            inputs=password_input,
-            outputs=[protected_accordion, incorrect_message]
-        )
-    gr.Markdown(LICENSE)
+
+        with gr.Accordion("Settings", open=False):
+            temperature_slider = gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.1,
+                value=default_temperature,
+                label="Temperature"
+            )
+            max_new_tokens_slider = gr.Slider(
+                minimum=128,
+                maximum=4096,
+                step=1,
+                value=default_max_new_tokens,
+                label="Max new tokens"
+            )
+            max_tokens_slider = gr.Slider(
+                minimum=128,
+                maximum=32000,
+                step=1,
+                value=default_max_tokens,
+                label="Max tokens"
+            )
+            max_round_slider = gr.Slider(
+                minimum=0,
+                maximum=50,
+                step=1,
+                value=default_max_round,
+                label="Max round")
+
+            temperature_slider.change(
+                lambda x: x, inputs=temperature_slider, outputs=temperature_state)
+            max_new_tokens_slider.change(
+                lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
+            max_tokens_slider.change(
+                lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
+            max_round_slider.change(
+                lambda x: x, inputs=max_round_slider, outputs=max_round_state)
+
+            password_input = gr.Textbox(
+                label="Enter Password for More Settings", type="password")
+            incorrect_message = gr.Textbox(visible=False, interactive=False)
+
+            with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        with gr.Accordion("⚙️ Model Loading", open=False):
+                            model_name_input = gr.Textbox(
+                                label="Enter model path", value=model_name)
+                            load_model_btn = gr.Button(value="Load Model")
+                            load_model_btn.click(
+                                agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
+
+                    with gr.Column(scale=1):
+                        with gr.Accordion("⚙️ Functional Parameters", open=False):
+                            enable_finish = gr.Checkbox(label="Enable Finish", value=True)
+                            enable_rag = gr.Checkbox(label="Enable RAG", value=True)
+                            enable_summary = gr.Checkbox(label="Enable Summary", value=False)
+                            init_rag_num = gr.Number(label="Initial RAG Num", value=0)
+                            step_rag_num = gr.Number(label="Step RAG Num", value=10)
+                            skip_last_k = gr.Number(label="Skip Last K", value=0)
+                            summary_mode = gr.Textbox(label="Summary Mode", value='step')
+                            summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
+                            summary_context_length = gr.Number(label="Summary Context Length", value=None)
+                            force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
+                            seed = gr.Number(label="Seed", value=100)
+
+                            submit_btn = gr.Button("Update Parameters")
+                            updated_parameters_output = gr.JSON()
+
+                            submit_btn.click(
+                                fn=update_model_parameters,
+                                inputs=[enable_finish, enable_rag, enable_summary, init_rag_num,
+                                        step_rag_num, skip_last_k, summary_mode, summary_skip_last_k,
+                                        summary_context_length, force_finish, seed],
+                                outputs=updated_parameters_output
+                            )
+
+            submit_button = gr.Button("Submit")
+            submit_button.click(
+                check_password,
+                inputs=password_input,
+                outputs=[protected_accordion, incorrect_message]
+            )
+
+        gr.Markdown(LICENSE)
+
+    return demo
 
 if __name__ == "__main__":
+    # Initialize the agent only when running directly
+    agent = TxAgent(
+        model_name,
+        rag_model_name,
+        tool_files_dict=new_tool_files,
+        force_finish=True,
+        enable_checker=True,
+        step_rag_num=10,
+        seed=100,
+        additional_default_tools=['DirectResponse', 'RequireClarification']
+    )
+    agent.init_model()
+
+    # Create and launch the interface
+    demo = create_interface(agent)
     demo.launch(share=True)
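
The practical effect of this change is that TxAgent construction and model loading now happen only under the `if __name__ == "__main__":` guard, so importing app.py no longer triggers any model initialization. Below is a minimal smoke-test sketch of how the refactored module could be exercised with a lightweight stand-in agent; `DummyAgent` and its methods are illustrative assumptions (not part of this commit), and the import still requires the txagent, spaces, and gradio packages to be installed.

# Hypothetical smoke test for the refactored app.py (not part of the commit).
import app  # module-level code now only defines config strings, CSS, and functions

class DummyAgent:
    """Illustrative stand-in so the UI can be built without loading TxAgent weights."""

    def run_gradio_chat(self, *args, **kwargs):
        # Yield one canned assistant message in the "messages" chat format.
        yield [{"role": "assistant", "content": "stub reply"}]

    def load_models(self, model_name):
        # Mirrors the click handler wiring in create_interface without real loading.
        return f"Pretending to load {model_name}"

if __name__ == "__main__":
    demo = app.create_interface(DummyAgent())  # builds the Blocks UI only
    demo.launch()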