franceth commited on
Commit
c377998
·
verified ·
1 Parent(s): 6673c87

Table selection, upload bug fix, and metrics bug fix

Browse files
Files changed (1) hide show
  1. app.py +175 -119
app.py CHANGED
@@ -40,6 +40,7 @@ function refresh() {
40
  }
41
  }
42
  """
 
43
 
44
  with open('style.css', 'r') as file:
45
  css = file.read()
@@ -80,8 +81,9 @@ def load_data(file, path, use_default):
80
  try:
81
  input_data["input_method"] = 'uploaded_file'
82
  input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
 
83
  #input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
84
- input_data["data_path"] = f"{input_data['db_name']}.sqlite"
85
  input_data["data"] = us.load_data(file, input_data["db_name"])
86
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
87
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
@@ -295,10 +297,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
295
  ]
296
  )
297
 
298
- ######################################
299
  # TABLE SELECTION PART #
300
  ######################################
301
  with select_table_acc:
 
302
  table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
303
  table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
304
  selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
@@ -310,37 +313,69 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
310
  """Dynamically updates the list of available tables."""
311
  if isinstance(data, dict) and data:
312
  table_names = []
313
- table_names.append("All")
314
- table_names.extend(data.keys()) # Concatena data.keys() alla lista
 
 
 
 
 
 
315
  return gr.update(choices=table_names, value=[]) # Reset selections
316
- return gr.update(choices=[], value=[])
317
 
 
 
318
  def show_selected_tables(data, selected_tables):
319
- """Displays only the tables selected by the user and enables the button."""
320
  updates = []
321
- if isinstance(data, dict) and data:
322
-
323
- available_tables = list(data.keys()) # Actually available names
324
- if "All" in selected_tables:
325
- selected_tables = available_tables
 
 
 
 
 
 
 
 
326
  else:
327
- selected_tables = [t for t in selected_tables if t in available_tables] # Filter valid selections
328
-
329
- tables = {name: data[name] for name in selected_tables} # Filter the DataFrames
330
-
331
- for i, (name, df) in enumerate(tables.items()):
332
- updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
333
 
334
- # If there are fewer than 5 tables, hide the other DataFrames
335
- for _ in range(len(tables), 50):
336
- updates.append(gr.update(visible=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  else:
338
- updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(50)]
 
 
 
 
 
 
 
 
 
339
 
340
- # Enable/disable the button based on selections
341
- button_state = bool(selected_tables) # True if at least one table is selected, False otherwise
342
- updates.append(gr.update(interactive=button_state)) # Update button state
343
-
344
  return updates
345
 
346
  def show_selected_table_names(data, selected_tables):
@@ -357,8 +392,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
357
  data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
358
 
359
  # Updates the visible tables and the button state based on user selections
360
- table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
361
-
 
 
 
 
362
  # Shows the list of selected tables when "Choose your models" is clicked
363
  open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
364
  open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
@@ -568,6 +607,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
568
 
569
  def qatch_flow():
570
  #caching
 
571
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
572
  metrics_conc = pd.DataFrame()
573
  columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
@@ -577,35 +617,36 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
577
  target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
578
  target_df = target_df[target_df["model"].isin(input_data['models'])]
579
  predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
580
-
581
  for model in input_data['models']:
582
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
583
  yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
584
  count=1
585
  for _, row in predictions_dict[model].iterrows():
586
  #for index, row in target_df.iterrows():
587
- percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
588
- count=count+1
589
- load_text = f"{generate_loading_text(percent_complete)}"
590
- question = row['question']
591
-
592
- display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
593
- <div style='display: flex; align-items: center;'>
594
- <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
595
- <div style='font-size: 3rem'>➡️</div>
596
- </div>
597
- """
598
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
599
- #time.sleep(0.02)
600
- prediction = row['predicted_sql']
601
-
602
- display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
603
- <div style='display: flex; align-items: center;'>
604
- <div style='font-size: 3rem'>➡️</div>
605
- <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
606
- </div>
607
- """
608
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
 
609
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
610
  metrics_conc = target_df
611
  if 'valid_efficiency_score' not in metrics_conc.columns:
@@ -622,74 +663,74 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
622
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
623
 
624
  predictor = ModelPrediction()
625
-
626
  for model in input_data["models"]:
627
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
628
  yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
629
  count=0
630
  for index, row in target_df.iterrows():
631
-
632
- percent_complete = round(((index+1) / len(target_df)) * 100, 2)
633
- load_text = f"{generate_loading_text(percent_complete)}"
634
-
635
- question = row['question']
636
- display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
637
- <div style='display: flex; align-items: center;'>
638
- <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
639
- <div style='font-size: 3rem'>➡️</div>
640
- </div>
641
- """
642
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
643
- start_time = time.time()
644
- samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
645
-
646
- schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
647
- db_id = input_data["db_name"],
648
- base_path = input_data["data_path"],
649
- normalize=False,
650
- sql=row["query"]
651
- )
652
-
653
- prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
654
- #PREDICTION SQL
655
- if prompt_to_send == prompt_default:
656
- prompt_to_send = None
657
- response = predictor.make_prediction(question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}")
658
- prediction = response['response_parsed']
659
- price = response['cost']
660
- answer = response['response']
661
-
662
- end_time = time.time()
663
- display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>>Predicted SQL:</div>
664
- <div style='display: flex; align-items: center;'>
665
- <div style='font-size: 3rem'>➡️</div>
666
- <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
667
- </div>
668
- """
669
- # Create a new row as dataframe
670
- new_row = pd.DataFrame([{
671
- 'id': index,
672
- 'question': question,
673
- 'predicted_sql': prediction,
674
- 'time': end_time - start_time,
675
- 'query': row["query"],
676
- 'db_path': input_data["data_path"],
677
- 'price':price,
678
- 'answer':answer,
679
- 'number_question':count
680
- }]).dropna(how="all") # Remove only completely empty rows
681
- count=count+1
682
- # TODO: use a for loop
683
- for col in target_df.columns:
684
- if col not in new_row.columns:
685
- new_row[col] = row[col]
686
-
687
- # Update model's prediction dataframe incrementally
688
- if not new_row.empty:
689
- predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
690
-
691
- # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
692
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
693
 
694
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
695
  # END
@@ -802,7 +843,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
802
  # fn=lambda: gr.update(open=True, visible=True),
803
  # outputs=[download_metrics]
804
  # )
805
- reset_data = gr.Button("Back to upload data section", interactive=False)
 
 
 
 
806
 
807
  metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
808
 
@@ -812,6 +857,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
812
  fn=lambda: gr.update(visible=False),
813
  outputs=[download_metrics]
814
  )
 
815
 
816
  reset_data.click(
817
  fn=enable_disable,
@@ -1958,7 +2004,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1958
  model_multiselect_bar = gr.CheckboxGroup(
1959
  choices=models,
1960
  label="Select one or more models:",
1961
- value=models
 
1962
  )
1963
 
1964
  group_radio = gr.Radio(
@@ -1991,7 +2038,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1991
  model_multiselect_radar = gr.CheckboxGroup(
1992
  choices=models,
1993
  label="Select one or more models:",
1994
- value=models
 
1995
  )
1996
 
1997
  with gr.Row():
@@ -2022,11 +2070,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
2022
  label="Select the metrics group that you want to use:",
2023
  value="Qatch"
2024
  )
 
 
 
 
 
 
 
2025
 
2026
  model_radio_ranking = gr.Radio(
2027
- choices=list(all_model_as_dic.keys()),
2028
  label="Select the model that you want to use:",
2029
- value="All"
2030
  )
2031
 
2032
  category_radio_ranking = gr.Radio(
@@ -2071,7 +2126,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
2071
  model_multiselect_rate = gr.CheckboxGroup(
2072
  choices=models,
2073
  label="Select one or more models:",
2074
- value=models
 
2075
  )
2076
 
2077
 
 
40
  }
41
  }
42
  """
43
+ reset_flag=False
44
 
45
  with open('style.css', 'r') as file:
46
  css = file.read()
 
81
  try:
82
  input_data["input_method"] = 'uploaded_file'
83
  input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
84
+ #TODO if not sqlite
85
  #input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
86
+ input_data["data_path"] = file #f"{input_data['db_name']}.sqlite"
87
  input_data["data"] = us.load_data(file, input_data["db_name"])
88
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
89
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
 
297
  ]
298
  )
299
 
300
+ ######################################
301
  # TABLE SELECTION PART #
302
  ######################################
303
  with select_table_acc:
304
+ previous_selection = gr.State([])
305
  table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
306
  table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
307
  selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
 
313
  """Dynamically updates the list of available tables."""
314
  if isinstance(data, dict) and data:
315
  table_names = []
316
+
317
+ if input_data['input_method'] == "default":
318
+ table_names.append("All")
319
+
320
+ elif len(data) < 6:
321
+ table_names.append("All") # In caso ci siano poche tabelle, ha senso mantenere "All"
322
+
323
+ table_names.extend(data.keys())
324
  return gr.update(choices=table_names, value=[]) # Reset selections
 
325
 
326
+ return gr.update(choices=[], value=[])
327
+
328
  def show_selected_tables(data, selected_tables):
 
329
  updates = []
330
+ available_tables = list(data.keys()) if isinstance(data, dict) and data else []
331
+ input_method = input_data['input_method']
332
+
333
+ allow_all = input_method == "default" or len(available_tables) < 6
334
+ selected_set = set(selected_tables)
335
+ tables_set = set(available_tables)
336
+
337
+ # ▶️
338
+ if allow_all:
339
+ if "All" in selected_set:
340
+ selected_tables = ["All"] + available_tables
341
+ elif selected_set == tables_set:
342
+ selected_tables = []
343
  else:
344
+ #
345
+ selected_tables = [t for t in selected_tables if t in available_tables]
346
+ else:
347
+ #
348
+ selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
 
349
 
350
+ #
351
+ tables = {name: data[name] for name in selected_tables if name in data}
352
+
353
+ for i, (name, df) in enumerate(tables.items()):
354
+ updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
355
+ for _ in range(len(tables), 50):
356
+ updates.append(gr.update(visible=False))
357
+
358
+ # ✅ Bottone abilitato solo se c'è almeno una tabella valida
359
+ updates.append(gr.update(interactive=bool(tables)))
360
+
361
+ # 🔄 Aggiorna la CheckboxGroup con logica coerente
362
+ if allow_all:
363
+ updates.insert(0, gr.update(
364
+ choices=["All"] + available_tables,
365
+ value=selected_tables
366
+ ))
367
  else:
368
+ if len(selected_tables) >= 5:
369
+ updates.insert(0, gr.update(
370
+ choices=selected_tables,
371
+ value=selected_tables
372
+ ))
373
+ else:
374
+ updates.insert(0, gr.update(
375
+ choices=available_tables,
376
+ value=selected_tables
377
+ ))
378
 
 
 
 
 
379
  return updates
380
 
381
  def show_selected_table_names(data, selected_tables):
 
392
  data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
393
 
394
  # Updates the visible tables and the button state based on user selections
395
+ #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
396
+ table_selector.change(
397
+ fn=show_selected_tables,
398
+ inputs=[data_state, table_selector],
399
+ outputs=[table_selector] + table_outputs + [open_model_selection]
400
+ )
401
  # Shows the list of selected tables when "Choose your models" is clicked
402
  open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
403
  open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
 
607
 
608
  def qatch_flow():
609
  #caching
610
+ global reset_flag
611
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
612
  metrics_conc = pd.DataFrame()
613
  columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
 
617
  target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
618
  target_df = target_df[target_df["model"].isin(input_data['models'])]
619
  predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
620
+ reset_flag = False
621
  for model in input_data['models']:
622
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
623
  yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
624
  count=1
625
  for _, row in predictions_dict[model].iterrows():
626
  #for index, row in target_df.iterrows():
627
+ if (reset_flag == False):
628
+ percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
629
+ count=count+1
630
+ load_text = f"{generate_loading_text(percent_complete)}"
631
+ question = row['question']
632
+
633
+ display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
634
+ <div style='display: flex; align-items: center;'>
635
+ <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
636
+ <div style='font-size: 3rem'>➡️</div>
637
+ </div>
638
+ """
639
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
640
+ #time.sleep(0.02)
641
+ prediction = row['predicted_sql']
642
+
643
+ display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
644
+ <div style='display: flex; align-items: center;'>
645
+ <div style='font-size: 3rem'>➡️</div>
646
+ <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
647
+ </div>
648
+ """
649
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
650
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
651
  metrics_conc = target_df
652
  if 'valid_efficiency_score' not in metrics_conc.columns:
 
663
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
664
 
665
  predictor = ModelPrediction()
666
+ reset_flag = False
667
  for model in input_data["models"]:
668
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
669
  yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
670
  count=0
671
  for index, row in target_df.iterrows():
672
+ if (reset_flag == False):
673
+ percent_complete = round(((index+1) / len(target_df)) * 100, 2)
674
+ load_text = f"{generate_loading_text(percent_complete)}"
675
+
676
+ question = row['question']
677
+ display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
678
+ <div style='display: flex; align-items: center;'>
679
+ <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
680
+ <div style='font-size: 3rem'>➡️</div>
681
+ </div>
682
+ """
683
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
684
+ start_time = time.time()
685
+ samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
686
+
687
+ schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
688
+ db_id = input_data["db_name"],
689
+ base_path = input_data["data_path"],
690
+ normalize=False,
691
+ sql=row["query"]
692
+ )
693
+
694
+ prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
695
+ #PREDICTION SQL
696
+
697
+ response = predictor.make_prediction(question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}")
698
+ prediction = response['response_parsed']
699
+ price = response['cost']
700
+ answer = response['response']
701
+
702
+ end_time = time.time()
703
+ display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>>Predicted SQL:</div>
704
+ <div style='display: flex; align-items: center;'>
705
+ <div style='font-size: 3rem'>➡️</div>
706
+ <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
707
+ </div>
708
+ """
709
+ # Create a new row as dataframe
710
+ new_row = pd.DataFrame([{
711
+ 'id': index,
712
+ 'question': question,
713
+ 'predicted_sql': prediction,
714
+ 'time': end_time - start_time,
715
+ 'query': row["query"],
716
+ 'db_path': input_data["data_path"],
717
+ 'price':price,
718
+ 'answer':answer,
719
+ 'number_question':count,
720
+ 'prompt' : prompt_to_send
721
+ }]).dropna(how="all") # Remove only completely empty rows
722
+ count=count+1
723
+ # TODO: use a for loop
724
+ for col in target_df.columns:
725
+ if col not in new_row.columns:
726
+ new_row[col] = row[col]
727
+
728
+ # Update model's prediction dataframe incrementally
729
+ if not new_row.empty:
730
+ predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
731
+
732
+ # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
733
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
734
 
735
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
736
  # END
 
843
  # fn=lambda: gr.update(open=True, visible=True),
844
  # outputs=[download_metrics]
845
  # )
846
+ def refresh():
847
+ global reset_flag
848
+ reset_flag = True
849
+
850
+ reset_data = gr.Button("Back to upload data section", interactive=True)
851
 
852
  metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
853
 
 
857
  fn=lambda: gr.update(visible=False),
858
  outputs=[download_metrics]
859
  )
860
+ reset_data.click(refresh)
861
 
862
  reset_data.click(
863
  fn=enable_disable,
 
2004
  model_multiselect_bar = gr.CheckboxGroup(
2005
  choices=models,
2006
  label="Select one or more models:",
2007
+ value=models,
2008
+ interactive=len(models) > 1
2009
  )
2010
 
2011
  group_radio = gr.Radio(
 
2038
  model_multiselect_radar = gr.CheckboxGroup(
2039
  choices=models,
2040
  label="Select one or more models:",
2041
+ value=models,
2042
+ interactive=len(models) > 1
2043
  )
2044
 
2045
  with gr.Row():
 
2070
  label="Select the metrics group that you want to use:",
2071
  value="Qatch"
2072
  )
2073
+ model_choices = list(all_model_as_dic.keys())
2074
+
2075
+ if len(model_choices) == 2:
2076
+ model_choices = [model_choices[0]] # supponiamo che il modello sia in prima posizione
2077
+ selected_value = model_choices[0]
2078
+ else:
2079
+ selected_value = "All"
2080
 
2081
  model_radio_ranking = gr.Radio(
2082
+ choices=model_choices,
2083
  label="Select the model that you want to use:",
2084
+ value=selected_value
2085
  )
2086
 
2087
  category_radio_ranking = gr.Radio(
 
2126
  model_multiselect_rate = gr.CheckboxGroup(
2127
  choices=models,
2128
  label="Select one or more models:",
2129
+ value=models,
2130
+ interactive=len(models) > 1
2131
  )
2132
 
2133