Spaces:
Sleeping
Sleeping
Table select, bug upload and metrics_bug
Browse files
app.py
CHANGED
|
@@ -40,6 +40,7 @@ function refresh() {
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
"""
|
|
|
|
| 43 |
|
| 44 |
with open('style.css', 'r') as file:
|
| 45 |
css = file.read()
|
|
@@ -80,8 +81,9 @@ def load_data(file, path, use_default):
|
|
| 80 |
try:
|
| 81 |
input_data["input_method"] = 'uploaded_file'
|
| 82 |
input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
|
|
|
|
| 83 |
#input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
|
| 84 |
-
input_data["data_path"] = f"{input_data['db_name']}.sqlite"
|
| 85 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
| 86 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
| 87 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
|
@@ -295,10 +297,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 295 |
]
|
| 296 |
)
|
| 297 |
|
| 298 |
-
|
| 299 |
# TABLE SELECTION PART #
|
| 300 |
######################################
|
| 301 |
with select_table_acc:
|
|
|
|
| 302 |
table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
|
| 303 |
table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
|
| 304 |
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
|
|
@@ -310,37 +313,69 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 310 |
"""Dynamically updates the list of available tables."""
|
| 311 |
if isinstance(data, dict) and data:
|
| 312 |
table_names = []
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
return gr.update(choices=table_names, value=[]) # Reset selections
|
| 316 |
-
return gr.update(choices=[], value=[])
|
| 317 |
|
|
|
|
|
|
|
| 318 |
def show_selected_tables(data, selected_tables):
|
| 319 |
-
"""Displays only the tables selected by the user and enables the button."""
|
| 320 |
updates = []
|
| 321 |
-
if isinstance(data, dict) and data
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
else:
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
for
|
| 332 |
-
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
else:
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
-
# Enable/disable the button based on selections
|
| 341 |
-
button_state = bool(selected_tables) # True if at least one table is selected, False otherwise
|
| 342 |
-
updates.append(gr.update(interactive=button_state)) # Update button state
|
| 343 |
-
|
| 344 |
return updates
|
| 345 |
|
| 346 |
def show_selected_table_names(data, selected_tables):
|
|
@@ -357,8 +392,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 357 |
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
|
| 358 |
|
| 359 |
# Updates the visible tables and the button state based on user selections
|
| 360 |
-
table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
# Shows the list of selected tables when "Choose your models" is clicked
|
| 363 |
open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
|
| 364 |
open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
|
|
@@ -568,6 +607,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 568 |
|
| 569 |
def qatch_flow():
|
| 570 |
#caching
|
|
|
|
| 571 |
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
|
| 572 |
metrics_conc = pd.DataFrame()
|
| 573 |
columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
|
|
@@ -577,35 +617,36 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 577 |
target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
|
| 578 |
target_df = target_df[target_df["model"].isin(input_data['models'])]
|
| 579 |
predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
|
| 580 |
-
|
| 581 |
for model in input_data['models']:
|
| 582 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
| 583 |
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 584 |
count=1
|
| 585 |
for _, row in predictions_dict[model].iterrows():
|
| 586 |
#for index, row in target_df.iterrows():
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
<div
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
<div style='
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
|
|
|
| 609 |
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 610 |
metrics_conc = target_df
|
| 611 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
|
@@ -622,74 +663,74 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 622 |
#target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
|
| 623 |
|
| 624 |
predictor = ModelPrediction()
|
| 625 |
-
|
| 626 |
for model in input_data["models"]:
|
| 627 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
| 628 |
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
| 629 |
count=0
|
| 630 |
for index, row in target_df.iterrows():
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
|
| 694 |
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
| 695 |
# END
|
|
@@ -802,7 +843,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 802 |
# fn=lambda: gr.update(open=True, visible=True),
|
| 803 |
# outputs=[download_metrics]
|
| 804 |
# )
|
| 805 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
|
| 807 |
metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
|
| 808 |
|
|
@@ -812,6 +857,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 812 |
fn=lambda: gr.update(visible=False),
|
| 813 |
outputs=[download_metrics]
|
| 814 |
)
|
|
|
|
| 815 |
|
| 816 |
reset_data.click(
|
| 817 |
fn=enable_disable,
|
|
@@ -1958,7 +2004,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 1958 |
model_multiselect_bar = gr.CheckboxGroup(
|
| 1959 |
choices=models,
|
| 1960 |
label="Select one or more models:",
|
| 1961 |
-
value=models
|
|
|
|
| 1962 |
)
|
| 1963 |
|
| 1964 |
group_radio = gr.Radio(
|
|
@@ -1991,7 +2038,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 1991 |
model_multiselect_radar = gr.CheckboxGroup(
|
| 1992 |
choices=models,
|
| 1993 |
label="Select one or more models:",
|
| 1994 |
-
value=models
|
|
|
|
| 1995 |
)
|
| 1996 |
|
| 1997 |
with gr.Row():
|
|
@@ -2022,11 +2070,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 2022 |
label="Select the metrics group that you want to use:",
|
| 2023 |
value="Qatch"
|
| 2024 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2025 |
|
| 2026 |
model_radio_ranking = gr.Radio(
|
| 2027 |
-
choices=
|
| 2028 |
label="Select the model that you want to use:",
|
| 2029 |
-
value=
|
| 2030 |
)
|
| 2031 |
|
| 2032 |
category_radio_ranking = gr.Radio(
|
|
@@ -2071,7 +2126,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
| 2071 |
model_multiselect_rate = gr.CheckboxGroup(
|
| 2072 |
choices=models,
|
| 2073 |
label="Select one or more models:",
|
| 2074 |
-
value=models
|
|
|
|
| 2075 |
)
|
| 2076 |
|
| 2077 |
|
|
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
"""
|
| 43 |
+
reset_flag=False
|
| 44 |
|
| 45 |
with open('style.css', 'r') as file:
|
| 46 |
css = file.read()
|
|
|
|
| 81 |
try:
|
| 82 |
input_data["input_method"] = 'uploaded_file'
|
| 83 |
input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
|
| 84 |
+
#TODO if not sqlite
|
| 85 |
#input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
|
| 86 |
+
input_data["data_path"] = file #f"{input_data['db_name']}.sqlite"
|
| 87 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
| 88 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
| 89 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
|
|
|
| 297 |
]
|
| 298 |
)
|
| 299 |
|
| 300 |
+
######################################
|
| 301 |
# TABLE SELECTION PART #
|
| 302 |
######################################
|
| 303 |
with select_table_acc:
|
| 304 |
+
previous_selection = gr.State([])
|
| 305 |
table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
|
| 306 |
table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
|
| 307 |
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
|
|
|
|
| 313 |
"""Dynamically updates the list of available tables."""
|
| 314 |
if isinstance(data, dict) and data:
|
| 315 |
table_names = []
|
| 316 |
+
|
| 317 |
+
if input_data['input_method'] == "default":
|
| 318 |
+
table_names.append("All")
|
| 319 |
+
|
| 320 |
+
elif len(data) < 6:
|
| 321 |
+
table_names.append("All") # In caso ci siano poche tabelle, ha senso mantenere "All"
|
| 322 |
+
|
| 323 |
+
table_names.extend(data.keys())
|
| 324 |
return gr.update(choices=table_names, value=[]) # Reset selections
|
|
|
|
| 325 |
|
| 326 |
+
return gr.update(choices=[], value=[])
|
| 327 |
+
|
| 328 |
def show_selected_tables(data, selected_tables):
|
|
|
|
| 329 |
updates = []
|
| 330 |
+
available_tables = list(data.keys()) if isinstance(data, dict) and data else []
|
| 331 |
+
input_method = input_data['input_method']
|
| 332 |
+
|
| 333 |
+
allow_all = input_method == "default" or len(available_tables) < 6
|
| 334 |
+
selected_set = set(selected_tables)
|
| 335 |
+
tables_set = set(available_tables)
|
| 336 |
+
|
| 337 |
+
# ▶️
|
| 338 |
+
if allow_all:
|
| 339 |
+
if "All" in selected_set:
|
| 340 |
+
selected_tables = ["All"] + available_tables
|
| 341 |
+
elif selected_set == tables_set:
|
| 342 |
+
selected_tables = []
|
| 343 |
else:
|
| 344 |
+
#
|
| 345 |
+
selected_tables = [t for t in selected_tables if t in available_tables]
|
| 346 |
+
else:
|
| 347 |
+
#
|
| 348 |
+
selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
|
|
|
|
| 349 |
|
| 350 |
+
#
|
| 351 |
+
tables = {name: data[name] for name in selected_tables if name in data}
|
| 352 |
+
|
| 353 |
+
for i, (name, df) in enumerate(tables.items()):
|
| 354 |
+
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
|
| 355 |
+
for _ in range(len(tables), 50):
|
| 356 |
+
updates.append(gr.update(visible=False))
|
| 357 |
+
|
| 358 |
+
# ✅ Bottone abilitato solo se c'è almeno una tabella valida
|
| 359 |
+
updates.append(gr.update(interactive=bool(tables)))
|
| 360 |
+
|
| 361 |
+
# 🔄 Aggiorna la CheckboxGroup con logica coerente
|
| 362 |
+
if allow_all:
|
| 363 |
+
updates.insert(0, gr.update(
|
| 364 |
+
choices=["All"] + available_tables,
|
| 365 |
+
value=selected_tables
|
| 366 |
+
))
|
| 367 |
else:
|
| 368 |
+
if len(selected_tables) >= 5:
|
| 369 |
+
updates.insert(0, gr.update(
|
| 370 |
+
choices=selected_tables,
|
| 371 |
+
value=selected_tables
|
| 372 |
+
))
|
| 373 |
+
else:
|
| 374 |
+
updates.insert(0, gr.update(
|
| 375 |
+
choices=available_tables,
|
| 376 |
+
value=selected_tables
|
| 377 |
+
))
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
return updates
|
| 380 |
|
| 381 |
def show_selected_table_names(data, selected_tables):
|
|
|
|
| 392 |
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
|
| 393 |
|
| 394 |
# Updates the visible tables and the button state based on user selections
|
| 395 |
+
#table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
|
| 396 |
+
table_selector.change(
|
| 397 |
+
fn=show_selected_tables,
|
| 398 |
+
inputs=[data_state, table_selector],
|
| 399 |
+
outputs=[table_selector] + table_outputs + [open_model_selection]
|
| 400 |
+
)
|
| 401 |
# Shows the list of selected tables when "Choose your models" is clicked
|
| 402 |
open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
|
| 403 |
open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
|
|
|
|
| 607 |
|
| 608 |
def qatch_flow():
|
| 609 |
#caching
|
| 610 |
+
global reset_flag
|
| 611 |
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
|
| 612 |
metrics_conc = pd.DataFrame()
|
| 613 |
columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
|
|
|
|
| 617 |
target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
|
| 618 |
target_df = target_df[target_df["model"].isin(input_data['models'])]
|
| 619 |
predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
|
| 620 |
+
reset_flag = False
|
| 621 |
for model in input_data['models']:
|
| 622 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
| 623 |
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 624 |
count=1
|
| 625 |
for _, row in predictions_dict[model].iterrows():
|
| 626 |
#for index, row in target_df.iterrows():
|
| 627 |
+
if (reset_flag == False):
|
| 628 |
+
percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
|
| 629 |
+
count=count+1
|
| 630 |
+
load_text = f"{generate_loading_text(percent_complete)}"
|
| 631 |
+
question = row['question']
|
| 632 |
+
|
| 633 |
+
display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
|
| 634 |
+
<div style='display: flex; align-items: center;'>
|
| 635 |
+
<div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
|
| 636 |
+
<div style='font-size: 3rem'>➡️</div>
|
| 637 |
+
</div>
|
| 638 |
+
"""
|
| 639 |
+
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 640 |
+
#time.sleep(0.02)
|
| 641 |
+
prediction = row['predicted_sql']
|
| 642 |
+
|
| 643 |
+
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
|
| 644 |
+
<div style='display: flex; align-items: center;'>
|
| 645 |
+
<div style='font-size: 3rem'>➡️</div>
|
| 646 |
+
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
| 647 |
+
</div>
|
| 648 |
+
"""
|
| 649 |
+
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 650 |
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
| 651 |
metrics_conc = target_df
|
| 652 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
|
|
|
| 663 |
#target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
|
| 664 |
|
| 665 |
predictor = ModelPrediction()
|
| 666 |
+
reset_flag = False
|
| 667 |
for model in input_data["models"]:
|
| 668 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
| 669 |
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
| 670 |
count=0
|
| 671 |
for index, row in target_df.iterrows():
|
| 672 |
+
if (reset_flag == False):
|
| 673 |
+
percent_complete = round(((index+1) / len(target_df)) * 100, 2)
|
| 674 |
+
load_text = f"{generate_loading_text(percent_complete)}"
|
| 675 |
+
|
| 676 |
+
question = row['question']
|
| 677 |
+
display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
|
| 678 |
+
<div style='display: flex; align-items: center;'>
|
| 679 |
+
<div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
|
| 680 |
+
<div style='font-size: 3rem'>➡️</div>
|
| 681 |
+
</div>
|
| 682 |
+
"""
|
| 683 |
+
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
|
| 684 |
+
start_time = time.time()
|
| 685 |
+
samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
|
| 686 |
+
|
| 687 |
+
schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
| 688 |
+
db_id = input_data["db_name"],
|
| 689 |
+
base_path = input_data["data_path"],
|
| 690 |
+
normalize=False,
|
| 691 |
+
sql=row["query"]
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
|
| 695 |
+
#PREDICTION SQL
|
| 696 |
+
|
| 697 |
+
response = predictor.make_prediction(question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}")
|
| 698 |
+
prediction = response['response_parsed']
|
| 699 |
+
price = response['cost']
|
| 700 |
+
answer = response['response']
|
| 701 |
+
|
| 702 |
+
end_time = time.time()
|
| 703 |
+
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>>Predicted SQL:</div>
|
| 704 |
+
<div style='display: flex; align-items: center;'>
|
| 705 |
+
<div style='font-size: 3rem'>➡️</div>
|
| 706 |
+
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
| 707 |
+
</div>
|
| 708 |
+
"""
|
| 709 |
+
# Create a new row as dataframe
|
| 710 |
+
new_row = pd.DataFrame([{
|
| 711 |
+
'id': index,
|
| 712 |
+
'question': question,
|
| 713 |
+
'predicted_sql': prediction,
|
| 714 |
+
'time': end_time - start_time,
|
| 715 |
+
'query': row["query"],
|
| 716 |
+
'db_path': input_data["data_path"],
|
| 717 |
+
'price':price,
|
| 718 |
+
'answer':answer,
|
| 719 |
+
'number_question':count,
|
| 720 |
+
'prompt' : prompt_to_send
|
| 721 |
+
}]).dropna(how="all") # Remove only completely empty rows
|
| 722 |
+
count=count+1
|
| 723 |
+
# TODO: use a for loop
|
| 724 |
+
for col in target_df.columns:
|
| 725 |
+
if col not in new_row.columns:
|
| 726 |
+
new_row[col] = row[col]
|
| 727 |
+
|
| 728 |
+
# Update model's prediction dataframe incrementally
|
| 729 |
+
if not new_row.empty:
|
| 730 |
+
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
|
| 731 |
+
|
| 732 |
+
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
|
| 733 |
+
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
|
| 734 |
|
| 735 |
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
| 736 |
# END
|
|
|
|
| 843 |
# fn=lambda: gr.update(open=True, visible=True),
|
| 844 |
# outputs=[download_metrics]
|
| 845 |
# )
|
| 846 |
+
def refresh():
|
| 847 |
+
global reset_flag
|
| 848 |
+
reset_flag = True
|
| 849 |
+
|
| 850 |
+
reset_data = gr.Button("Back to upload data section", interactive=True)
|
| 851 |
|
| 852 |
metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
|
| 853 |
|
|
|
|
| 857 |
fn=lambda: gr.update(visible=False),
|
| 858 |
outputs=[download_metrics]
|
| 859 |
)
|
| 860 |
+
reset_data.click(refresh)
|
| 861 |
|
| 862 |
reset_data.click(
|
| 863 |
fn=enable_disable,
|
|
|
|
| 2004 |
model_multiselect_bar = gr.CheckboxGroup(
|
| 2005 |
choices=models,
|
| 2006 |
label="Select one or more models:",
|
| 2007 |
+
value=models,
|
| 2008 |
+
interactive=len(models) > 1
|
| 2009 |
)
|
| 2010 |
|
| 2011 |
group_radio = gr.Radio(
|
|
|
|
| 2038 |
model_multiselect_radar = gr.CheckboxGroup(
|
| 2039 |
choices=models,
|
| 2040 |
label="Select one or more models:",
|
| 2041 |
+
value=models,
|
| 2042 |
+
interactive=len(models) > 1
|
| 2043 |
)
|
| 2044 |
|
| 2045 |
with gr.Row():
|
|
|
|
| 2070 |
label="Select the metrics group that you want to use:",
|
| 2071 |
value="Qatch"
|
| 2072 |
)
|
| 2073 |
+
model_choices = list(all_model_as_dic.keys())
|
| 2074 |
+
|
| 2075 |
+
if len(model_choices) == 2:
|
| 2076 |
+
model_choices = [model_choices[0]] # supponiamo che il modello sia in prima posizione
|
| 2077 |
+
selected_value = model_choices[0]
|
| 2078 |
+
else:
|
| 2079 |
+
selected_value = "All"
|
| 2080 |
|
| 2081 |
model_radio_ranking = gr.Radio(
|
| 2082 |
+
choices=model_choices,
|
| 2083 |
label="Select the model that you want to use:",
|
| 2084 |
+
value=selected_value
|
| 2085 |
)
|
| 2086 |
|
| 2087 |
category_radio_ranking = gr.Radio(
|
|
|
|
| 2126 |
model_multiselect_rate = gr.CheckboxGroup(
|
| 2127 |
choices=models,
|
| 2128 |
label="Select one or more models:",
|
| 2129 |
+
value=models,
|
| 2130 |
+
interactive=len(models) > 1
|
| 2131 |
)
|
| 2132 |
|
| 2133 |
|