Ali2206 committed on
Commit
fc30674
·
verified ·
1 Parent(s): ae94627

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -97
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import random
 
 
2
  import os
3
  import torch
4
  import logging
@@ -15,15 +17,16 @@ logging.basicConfig(
15
  )
16
  logger = logging.getLogger(__name__)
17
 
 
18
  current_dir = os.path.dirname(os.path.abspath(__file__))
19
  os.environ["MKL_THREADING_LAYER"] = "GNU"
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
 
22
- # Configuration - Update paths as needed
23
  CONFIG = {
24
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
25
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
26
- "embedding_filename": "path_to_your_embeddings.pt", # Update this path
27
  "tool_files": {
28
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
29
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
@@ -33,62 +36,166 @@ CONFIG = {
33
  }
34
  }
35
 
36
- def safe_load_embeddings(filepath: str):
37
- """Handle embedding loading with fallbacks"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
- # Try with weights_only=True first
40
  return torch.load(filepath, weights_only=True)
41
  except Exception as e:
42
- logger.warning(f"Secure load failed, trying without weights_only: {str(e)}")
43
  try:
 
44
  return torch.load(filepath, weights_only=False)
45
  except Exception as e:
46
  logger.error(f"Failed to load embeddings: {str(e)}")
47
  return None
48
 
49
- def get_tools_from_universe(tooluniverse):
50
- """Flexible tool extraction from ToolUniverse"""
51
- if hasattr(tooluniverse, 'get_all_tools'):
52
- return tooluniverse.get_all_tools()
53
- elif hasattr(tooluniverse, 'tools'):
54
- return tooluniverse.tools
55
- elif hasattr(tooluniverse, 'list_tools'):
56
- return tooluniverse.list_tools()
57
- else:
58
- logger.error("Could not find any tool access method in ToolUniverse")
59
- # Try to load from files directly as fallback
60
- tools = []
61
- for tool_file in CONFIG["tool_files"].values():
62
- if os.path.exists(tool_file):
63
- with open(tool_file, 'r') as f:
64
- tools.extend(json.load(f))
65
- return tools if tools else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def prepare_tool_files():
68
  """Ensure tool files exist and are populated"""
69
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
70
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
71
- logger.info("Generating tool list...")
72
  try:
73
  tu = ToolUniverse()
74
- tools = get_tools_from_universe(tu)
75
- if tools:
76
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
77
- json.dump(tools, f, indent=2)
78
- logger.info(f"Saved {len(tools)} tools")
79
  else:
80
- logger.error("No tools could be loaded")
 
 
 
 
 
81
  except Exception as e:
82
- logger.error(f"Tool file preparation failed: {str(e)}")
83
 
84
  def create_agent():
85
- """Create and initialize the TxAgent with robust error handling"""
 
 
86
  prepare_tool_files()
87
-
 
88
  try:
89
  agent = TxAgent(
90
- model_name=CONFIG["model_name"],
91
- rag_model_name=CONFIG["rag_model_name"],
92
  tool_files_dict=CONFIG["tool_files"],
93
  force_finish=True,
94
  enable_checker=True,
@@ -99,82 +206,227 @@ def create_agent():
99
  agent.init_model()
100
  return agent
101
  except Exception as e:
102
- logger.error(f"Agent creation failed: {str(e)}")
103
  raise
104
 
105
- def format_response(history, message):
106
- """Properly format responses for Gradio Chatbot"""
107
- if isinstance(message, (str, dict)):
108
- return history + [[None, str(message)]]
109
- elif hasattr(message, '__iter__'):
110
- full_response = ""
111
- for chunk in message:
112
- if isinstance(chunk, dict):
113
- full_response += chunk.get("content", "")
114
- else:
115
- full_response += str(chunk)
116
- return history + [[None, full_response]]
117
- return history + [[None, str(message)]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def create_demo(agent):
120
- """Create the Gradio interface with proper message handling"""
121
- with gr.Blocks() as demo:
122
- chatbot = gr.Chatbot(
123
- height=800,
124
- label='TxAgent',
125
- show_copy_button=True,
126
- type="messages" # Use the modern message format
127
- )
 
 
 
 
 
 
 
 
 
 
128
 
129
- msg = gr.Textbox(label="Input", placeholder="Type your question...")
130
- clear = gr.ClearButton([msg, chatbot])
 
 
131
 
132
- def respond(message, chat_history):
133
- try:
134
- # Convert Gradio history to agent format
135
- agent_history = []
136
- for user_msg, bot_msg in chat_history:
137
- if user_msg:
138
- agent_history.append({"role": "user", "content": user_msg})
139
- if bot_msg:
140
- agent_history.append({"role": "assistant", "content": bot_msg})
141
-
142
- # Get response from agent
143
- response = agent.run_gradio_chat(
144
- agent_history + [{"role": "user", "content": message}],
145
- temperature=0.3,
146
- max_new_tokens=1024,
147
- max_tokens=81920,
148
- multi_agent=False,
149
- conversation=[],
150
- max_round=30
151
- )
152
-
153
- # Format the response properly
154
- full_response = ""
155
- for chunk in response:
156
- if isinstance(chunk, dict):
157
- full_response += chunk.get("content", "")
158
- else:
159
- full_response += str(chunk)
160
-
161
- return chat_history + [(message, full_response)]
162
-
163
- except Exception as e:
164
- logger.error(f"Error in response handling: {str(e)}")
165
- return chat_history + [(message, f"Error: {str(e)}")]
166
 
167
- msg.submit(respond, [msg, chatbot], [chatbot])
168
- clear.click(lambda: [], None, [chatbot])
 
 
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  return demo
171
 
172
  def main():
173
- """Main application entry point"""
174
  try:
175
  agent = create_agent()
176
  demo = create_demo(agent)
177
- demo.launch(server_name="0.0.0.0", server_port=7860)
178
  except Exception as e:
179
  logger.error(f"Application failed to start: {str(e)}")
180
  raise
 
1
  import random
2
+ import datetime
3
+ import sys
4
  import os
5
  import torch
6
  import logging
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Determine the directory where the current file is located
21
  current_dir = os.path.dirname(os.path.abspath(__file__))
22
  os.environ["MKL_THREADING_LAYER"] = "GNU"
23
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
 
25
+ # Configuration
26
  CONFIG = {
27
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
28
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
29
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
30
  "tool_files": {
31
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
32
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
 
36
  }
37
  }
38
 
39
# Page header shown at the top of the Gradio app.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
</div>
'''

# Short project blurb rendered under the header.
INTRO = """
Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
"""

# Footer text: collaboration contact plus the medical-advice disclaimer.
LICENSE = """
We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
in collaboration, please email Marinka Zitnik and Shanghua Gao.

### Medical Advice Disclaimer
DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
The information, including but not limited to, text, graphics, images and other material contained on this
website are for informational purposes only. No material on this site is intended to be a substitute for
professional medical advice, diagnosis or treatment.
"""

# Placeholder HTML shown in the empty chatbot pane (usage tips).
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
</div>
"""

# Global stylesheet passed to gr.Blocks(css=...) in create_demo.
css = """
h1 {
text-align: center;
display: block;
}

#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
.small-button button {
font-size: 12px !important;
padding: 4px 8px !important;
height: 6px !important;
width: 4px !important;
}
.gradio-accordion {
margin-top: 0px !important;
margin-bottom: 0px !important;
}
"""

# Extra chat-specific styles.
# NOTE(review): chat_css is defined here but not referenced anywhere in the
# visible code — confirm it is applied before removing it.
chat_css = """
.gr-button { font-size: 20px !important; }
.gr-button svg { width: 32px !important; height: 32px !important; }
"""
100
+
101
def safe_load_embeddings(filepath: str):
    """Load a torch embedding file, preferring the safe ``weights_only`` mode.

    Tries ``torch.load(..., weights_only=True)`` first (which refuses to
    execute arbitrary pickled code), then falls back to an unsafe load for
    checkpoints saved in older formats.

    Args:
        filepath: Path to the ``.pt`` embedding file.

    Returns:
        The loaded object on success, or ``None`` if both attempts fail.
    """
    # FIX: the original annotated the return as ``-> any`` — lowercase ``any``
    # is the builtin function, not a type — so the bogus annotation is removed.
    try:
        # Safe path: only tensor/container payloads are deserialized.
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        try:
            # NOTE(review): weights_only=False can execute pickle code —
            # only acceptable for files shipped with / trusted by this app.
            return torch.load(filepath, weights_only=False)
        except Exception as inner:
            # FIX: renamed from ``e`` to avoid shadowing the outer exception.
            logger.error(f"Failed to load embeddings: {str(inner)}")
            return None
114
 
115
def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding`` to load the
    precomputed embedding file named in CONFIG instead of recomputing it.

    The replacement loader also reconciles a mismatch between the number of
    tools and the number of embedding rows by truncating or padding the
    embedding tensor.

    Raises:
        Exception: re-raised if the patch itself cannot be applied
            (e.g. the ``txagent`` package is missing).
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse):
            # Returns True on success, False on any recoverable failure.
            try:
                if not os.path.exists(CONFIG["embedding_filename"]):
                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
                    return False

                self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
                # FIX: safe_load_embeddings returns None when both load modes
                # fail; the original proceeded and crashed on len(None).
                if self.tool_desc_embedding is None:
                    logger.error(f"Embedding file could not be loaded: {CONFIG['embedding_filename']}")
                    return False

                # The tool-listing API differs across ToolUniverse versions.
                if hasattr(tooluniverse, 'get_all_tools'):
                    tools = tooluniverse.get_all_tools()
                elif hasattr(tooluniverse, 'tools'):
                    tools = tooluniverse.tools
                else:
                    logger.error("No method found to access tools from ToolUniverse")
                    return False

                current_count = len(tools)
                embedding_count = len(self.tool_desc_embedding)

                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")

                    if current_count < embedding_count:
                        self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        # BUG FIX: the original concatenated 1-D rows
                        # (embedding[-1]) onto the 2-D embedding tensor, which
                        # makes torch.cat raise a dimension-mismatch error.
                        # Repeat the last row as a 2-D block instead.
                        # assumes tool_desc_embedding is (count, dim) — TODO confirm
                        deficit = current_count - embedding_count
                        padding = self.tool_desc_embedding[-1:].repeat(deficit, 1)
                        self.tool_desc_embedding = torch.cat([self.tool_desc_embedding, padding])
                        logger.info(f"Padded embeddings to match {current_count} tools")

                return True

            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        # FIX: dropped the unused ``original_load`` local the original kept.
        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
166
 
167
def prepare_tool_files():
    """Ensure the generated tool-list file exists and is populated.

    Creates the local ``data`` directory, and if the ``new_tool`` JSON file is
    missing, extracts the tool list from a fresh ToolUniverse and writes it.
    Failures are logged, never raised — callers treat this as best-effort.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    if os.path.exists(CONFIG["tool_files"]["new_tool"]):
        return

    logger.info("Generating tool list using ToolUniverse...")
    try:
        tu = ToolUniverse()
        # The tool-listing API differs across ToolUniverse versions.
        if hasattr(tu, 'get_all_tools'):
            tools = tu.get_all_tools()
        elif hasattr(tu, 'tools'):
            tools = tu.tools
        else:
            tools = []
            logger.error("Could not access tools from ToolUniverse")

        # BUG FIX: the original wrote the file even when ``tools`` was empty,
        # so the existence check above would skip regeneration on every later
        # run, permanently caching an empty tool list.
        if tools:
            with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                json.dump(tools, f, indent=2)
            logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
        else:
            logger.error("No tools available; skipping tool file write")
    except Exception as e:
        logger.error(f"Failed to prepare tool files: {str(e)}")
 
188
  def create_agent():
189
+ """Create and initialize the TxAgent"""
190
+ # Apply the embedding patch before creating the agent
191
+ patch_embedding_loading()
192
  prepare_tool_files()
193
+
194
+ # Initialize the agent
195
  try:
196
  agent = TxAgent(
197
+ CONFIG["model_name"],
198
+ CONFIG["rag_model_name"],
199
  tool_files_dict=CONFIG["tool_files"],
200
  force_finish=True,
201
  enable_checker=True,
 
206
  agent.init_model()
207
  return agent
208
  except Exception as e:
209
+ logger.error(f"Failed to create agent: {str(e)}")
210
  raise
211
 
212
def handle_chat_response(history, message, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Flatten streamed agent output into a single chatbot history entry.

    ``message`` is an iterable of chunks (dicts with a "content" key, or
    anything stringifiable); the concatenation is appended to ``history``
    as a ``(None, text)`` pair. Returns the mutated history. The generation
    parameters are accepted for signature compatibility but unused here.
    """
    parts = []
    for chunk in message:
        if isinstance(chunk, dict):
            parts.append(chunk.get("content", ""))
        else:
            parts.append(str(chunk))
    history.append((None, "".join(parts)))
    return history
222
+
223
def update_model_parameters(agent, enable_finish, enable_rag, enable_summary,
                            init_rag_num, step_rag_num, skip_last_k,
                            summary_mode, summary_skip_last_k, summary_context_length,
                            force_finish, seed):
    """Forward the advanced-settings values to the agent.

    Packs every setting into a keyword mapping and hands it to
    ``agent.update_parameters``; returns whatever the agent reports back.
    """
    settings = {
        "enable_finish": enable_finish,
        "enable_rag": enable_rag,
        "enable_summary": enable_summary,
        "init_rag_num": init_rag_num,
        "step_rag_num": step_rag_num,
        "skip_last_k": skip_last_k,
        "summary_mode": summary_mode,
        "summary_skip_last_k": summary_skip_last_k,
        "summary_context_length": summary_context_length,
        "force_finish": force_finish,
        "seed": seed,
    }
    return agent.update_parameters(**settings)
242
+
243
def update_seed(agent):
    """Draw a fresh random seed in [0, 10000] and push it to the agent.

    Returns the agent's updated parameters so callers can log them.
    """
    fresh_seed = random.randint(0, 10000)
    return agent.update_parameters(seed=fresh_seed)
248
+
249
def handle_retry(agent, history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Re-run the prompt at ``retry_data.index`` with a freshly drawn seed.

    Truncates the history back to the retried turn, re-submits its content
    to the agent, and yields the re-formatted history entries.
    """
    print("Updated seed:", update_seed(agent))
    trimmed_history = history[:retry_data.index]
    previous_prompt = history[retry_data.index]['content']
    print("previous_prompt", previous_prompt)
    response = agent.run_gradio_chat(
        trimmed_history + [{"role": "user", "content": previous_prompt}],
        temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
    yield from handle_chat_response(trimmed_history, response, temperature,
                                    max_new_tokens, max_tokens, multi_agent,
                                    conversation, max_round)
258
+
259
# NOTE(review): hardcoded credential committed to source — this should be
# loaded from an environment variable or secret store instead.
PASSWORD = "mypassword"

def check_password(input_password):
    """Reveal the protected settings accordion when the password matches.

    Returns a visibility update for the accordion plus an error-message
    string (empty on success).
    """
    matched = input_password == PASSWORD
    if matched:
        return gr.update(visible=True), ""
    return gr.update(visible=False), "Incorrect password, try again!"
267
 
268
def create_demo(agent):
    """Create the Gradio interface"""
    # Default generation parameters; mirrored into gr.State holders below so
    # the "Settings" sliders can update them per-session.
    default_temperature = 0.3
    default_max_new_tokens = 1024
    default_max_tokens = 81920
    default_max_round = 30

    # NOTE(review): question_examples is defined but never wired to a
    # gr.Examples component in this function — confirm whether it should be.
    question_examples = [
        ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
        ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
        ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
    ]

    # NOTE(review): the chatbot is instantiated OUTSIDE the gr.Blocks context
    # below; Gradio components created outside a Blocks typically need an
    # explicit .render() call inside it — verify the chat area appears.
    chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
                         label='TxAgent', show_copy_button=True)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        # Per-session state mirrors of the default generation parameters.
        temperature_state = gr.State(value=default_temperature)
        max_new_tokens_state = gr.State(value=default_max_new_tokens)
        max_tokens_state = gr.State(value=default_max_tokens)
        max_round_state = gr.State(value=default_max_round)

        # Retry re-runs the selected prompt (handle_retry re-randomizes the
        # seed); chatbot is passed twice to fill both the history and the
        # retry-source inputs of handle_retry.
        chatbot.retry(
            lambda *args: handle_retry(agent, *args),
            inputs=[chatbot, chatbot, temperature_state, max_new_tokens_state,
                    max_tokens_state, gr.Checkbox(value=False, render=False),
                    gr.State([]), max_round_state]
        )

        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(label="Input", placeholder="Type your question here...")
            with gr.Column(scale=1):
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Row():
            clear_btn = gr.ClearButton([msg, chatbot])

        def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            # Forward the running history plus the new user turn to the agent,
            # then flatten its streamed chunks into one chatbot entry.
            response = agent.run_gradio_chat(
                chat_history + [{"role": "user", "content": message}],
                temperature,
                max_new_tokens,
                max_tokens,
                multi_agent,
                conversation,
                max_round
            )
            return handle_chat_response(chat_history, response, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)

        # Both the button click and pressing Enter in the textbox submit.
        submit_btn.click(
            respond,
            inputs=[msg, chatbot, temperature_state, max_new_tokens_state,
                    max_tokens_state, gr.Checkbox(value=False, render=False),
                    gr.State([]), max_round_state],
            outputs=[chatbot]
        )
        msg.submit(
            respond,
            inputs=[msg, chatbot, temperature_state, max_new_tokens_state,
                    max_tokens_state, gr.Checkbox(value=False, render=False),
                    gr.State([]), max_round_state],
            outputs=[chatbot]
        )

        with gr.Accordion("Settings", open=False):
            # Each slider writes through to its gr.State holder via .change.
            temperature_slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=default_temperature,
                label="Temperature"
            )
            max_new_tokens_slider = gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=default_max_new_tokens,
                label="Max new tokens"
            )
            max_tokens_slider = gr.Slider(
                minimum=128,
                maximum=32000,
                step=1,
                value=default_max_tokens,
                label="Max tokens"
            )
            max_round_slider = gr.Slider(
                minimum=0,
                maximum=50,
                step=1,
                value=default_max_round,
                label="Max round")

            temperature_slider.change(
                lambda x: x, inputs=temperature_slider, outputs=temperature_state)
            max_new_tokens_slider.change(
                lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
            max_tokens_slider.change(
                lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
            max_round_slider.change(
                lambda x: x, inputs=max_round_slider, outputs=max_round_state)

        # Password gate for the advanced accordion below (see check_password).
        password_input = gr.Textbox(
            label="Enter Password for More Settings", type="password")
        incorrect_message = gr.Textbox(visible=False, interactive=False)

        with gr.Accordion("⚙️ Advanced Settings", open=False, visible=False) as protected_accordion:
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Accordion("Model Settings", open=False):
                        model_name_input = gr.Textbox(
                            label="Enter model path", value=CONFIG["model_name"])
                        load_model_btn = gr.Button(value="Load Model")
                        load_model_btn.click(
                            agent.load_models,
                            inputs=model_name_input,
                            outputs=gr.Textbox(label="Status"))
                with gr.Column(scale=1):
                    with gr.Accordion("Functional Parameters", open=False):
                        enable_finish = gr.Checkbox(label="Enable Finish", value=True)
                        enable_rag = gr.Checkbox(label="Enable RAG", value=True)
                        enable_summary = gr.Checkbox(label="Enable Summary", value=False)
                        init_rag_num = gr.Number(label="Initial RAG Num", value=0)
                        step_rag_num = gr.Number(label="Step RAG Num", value=10)
                        skip_last_k = gr.Number(label="Skip Last K", value=0)
                        summary_mode = gr.Textbox(label="Summary Mode", value='step')
                        summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
                        summary_context_length = gr.Number(label="Summary Context Length", value=None)
                        force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
                        seed = gr.Number(label="Seed", value=100)
                        # NOTE(review): this re-binds the name ``submit_btn``,
                        # shadowing the main "Submit" button above. The earlier
                        # .click/.submit handlers were already registered so
                        # they keep working, but the name reuse is fragile.
                        submit_btn = gr.Button("Update Parameters")
                        updated_parameters_output = gr.JSON()
                        submit_btn.click(
                            lambda *args: update_model_parameters(agent, *args),
                            inputs=[enable_finish, enable_rag, enable_summary,
                                    init_rag_num, step_rag_num, skip_last_k,
                                    summary_mode, summary_skip_last_k,
                                    summary_context_length, force_finish, seed],
                            outputs=updated_parameters_output
                        )

        # Validates the password and toggles the protected accordion.
        submit_button = gr.Button("Submit")
        submit_button.click(
            check_password,
            inputs=password_input,
            outputs=[protected_accordion, incorrect_message]
        )

        gr.Markdown(LICENSE)

    return demo
423
 
424
def main():
    """Entry point: build the agent, assemble the UI, and serve it."""
    try:
        demo = create_demo(create_agent())
        demo.launch(share=True)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise