Add TQA task (#21)
- Add TQA task (989a40afe7aa0221b4790e76f837e1eecf82614a)
Co-authored-by: Francesco Giannuzzo <[email protected]>
- app.py +196 -146
- concatenated_output.csv +1 -1
- utilities.py +97 -1
- utils_get_db_tables_info.py +12 -6
app.py
CHANGED

@@ -12,6 +12,7 @@ import plotly.colors as pc
 from qatch.connectors.sqlite_connector import SqliteConnector
 from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
 from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
+import qatch.evaluate_dataset.orchestrator_evaluator as eva
 from prediction import ModelPrediction
 import utils_get_db_tables_info
 import utilities as us
@@ -31,7 +32,6 @@ import utilities as us
 #pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
 pnp_path = "concatenated_output.csv"
 PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
-
 js_func = """
 function refresh() {
     const url = new URL(window.location);
@@ -42,7 +42,8 @@ function refresh() {
     }
 }
 """
-reset_flag=False
+reset_flag = False
+flag_TQA = False
 
 with open('style.css', 'r') as file:
     css = file.read()
@@ -65,6 +66,8 @@ description = """## 📊 Comparison of Proprietary and Non-Proprietary Databases
 ### ➤ **Non-Proprietary**
 ###     ⇒ Spider 1.0 🕷️"""
 prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
+prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as <answer> </answer>.\n Question \n {question}\n Database Schema\n {db_schema}\n"
+
 
 input_data = {
     'input_method': "",
@@ -93,6 +96,7 @@ def load_data(file, path, use_default):
         #change path
         input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
         input_data["data"] = us.load_data(file, input_data["db_name"])
+
         df_current = input_data["data"]['data_frames'].get('MyTable', df_default)  # Carica il DataFrame
         if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None):  #for csv and xlsx files
             table2primary_key = {}
@@ -317,7 +321,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
     # Model selection button (initially disabled)
     open_model_selection = gr.Button("Choose your models", interactive=False)
-
    def update_table_list(data):
        """Dynamically updates the list of available tables and excluded ones."""
        if isinstance(data, dict) and data:
@@ -458,9 +461,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox
        ]
    )
-
    reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
-
 
    ####################################
    #       MODEL SELECTION PART       #
@@ -506,10 +507,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
    # Function to get selected models
    def get_selected_models(*model_selections):
        selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
-
        input_data['models'] = selected_models
        button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
-        return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
+        return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state), gr.update(interactive=button_state)
 
    # Add the Textbox to the interface
    prompt = gr.TextArea(
@@ -517,17 +517,19 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        placeholder=prompt_default,
        elem_id="custom-textarea"
    )
+
    warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
 
    # Submit button (initially disabled)
-
-
+    with gr.Row():
+        submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
+        submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
 
    def check_prompt(prompt):
        #TODO
        missing_elements = []
        if(prompt==""):
-            input_data["prompt"]=prompt_default
+            input_data["prompt"] = prompt_default
            button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
        else:
            input_data["prompt"]=prompt
@@ -544,18 +546,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            ), gr.update(interactive=button_state)
        return gr.update(visible=False), gr.update(interactive=button_state)
 
-    prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
+    prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
    # Link checkboxes to selection events
    for checkbox in model_checkboxes:
        checkbox.change(
            fn=get_selected_models,
            inputs=model_checkboxes,
-            outputs=[selected_models_output, select_model_acc, submit_models_button]
+            outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
        )
    prompt.change(
        fn=get_selected_models,
        inputs=model_checkboxes,
-        outputs=[selected_models_output, select_model_acc, submit_models_button]
+        outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
    )
 
    submit_models_button.click(
@@ -564,6 +566,17 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        outputs=[selected_models_output, select_model_acc, qatch_acc]
    )
 
+    submit_models_button_tqa.click(
+        fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
+        inputs=model_checkboxes,
+        outputs=[selected_models_output, select_model_acc, qatch_acc]
+    )
+    def change_flag():
+        global flag_TQA
+        flag_TQA = True
+
+    submit_models_button_tqa.click(fn = change_flag, inputs=[], outputs=[])
+
    def enable_disable(enable):
        return (
            *[gr.update(interactive=enable) for _ in model_checkboxes],
@@ -574,6 +587,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            gr.update(interactive=enable),
            gr.update(interactive=enable),
            *[gr.update(interactive=enable) for _ in table_outputs],
+            gr.update(interactive=enable),
            gr.update(interactive=enable)
        )
 
@@ -591,7 +605,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection
+            open_model_selection,
+            submit_models_button_tqa
+        ]
+    )
+    submit_models_button_tqa.click(
+        fn=enable_disable,
+        inputs=[gr.State(False)],
+        outputs=[
+            *model_checkboxes,
+            submit_models_button,
+            preview_output,
+            submit_button,
+            file_input,
+            default_checkbox,
+            table_selector,
+            *table_outputs,
+            open_model_selection,
+            submit_models_button_tqa
        ]
    )
 
@@ -609,7 +640,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection
+            open_model_selection,
+            submit_models_button_tqa
        ]
    )
 
@@ -660,9 +692,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        {mirrored_symbols}
        </div>
        """
-
-
+
+    def qatch_flow_nl_sql():
        global reset_flag
+        global flag_TQA
        predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
        metrics_conc = pd.DataFrame()
        columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
@@ -692,7 +725,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                </div>
                """
                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
-
+
                prediction = row['predicted_sql']
 
                display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
@@ -700,22 +733,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                <div style='font-size: 3rem'>➡️</div>
                <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
                </div>
-                """
+                """
+
                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
            metrics_conc = target_df
-            if '
-            metrics_conc['
+            if 'valid_efficency_score' not in metrics_conc.columns:
+                metrics_conc['valid_efficency_score'] = metrics_conc['VES']
            eval_text = generate_eval_text("End evaluation")
            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+
        else:
-
            orchestrator_generator = OrchestratorGenerator()
-
-
-            #
-
-
+            target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables'])
+
+            #create target_df[target_answer]
+            if flag_TQA :
+                if (input_data["prompt"] == prompt_default):
+                    input_data["prompt"] = prompt_default_tqa
+                target_df = us.extract_answer(target_df)
 
            predictor = ModelPrediction()
            reset_flag = False
@@ -736,15 +772,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                </div>
                """
                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
-
-
-
-
-
-
-
-
-
+                #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
+                model_to_send = None if not flag_TQA else model
+
+
+                db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
+                    db_id = input_data["db_name"],
+                    base_path = input_data["data_path"],
+                    normalize=False,
+                    sql=row["query"],
+                    get_insert_into=True,
+                    model = model_to_send,
+                    prompt = input_data["prompt"].format(question=question, db_schema=""),
                )
 
                #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
@@ -752,19 +791,27 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                #PREDICTION SQL
 
                # TODO add button for QA or SP and pass to .make_prediction parameter TASK
+                if flag_TQA: task="QA"
+                else: task="SP"
+                start_time = time.time()
                response = predictor.make_prediction(
                    question=question,
-                    db_schema=
+                    db_schema=db_schema_text,
                    model_name=model,
                    prompt=f"{prompt_to_send}",
-                    task=
+                    task=task
                )
                prediction = response['response_parsed']
                price = response['cost']
                answer = response['response']
 
                end_time = time.time()
-
+                if flag_TQA:
+                    task_string = "Answer"
+                else:
+                    task_string = "SQL"
+
+                display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted {task_string}:</div>
                <div style='display: flex; align-items: center;'>
                <div style='font-size: 3rem'>➡️</div>
                <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
@@ -779,40 +826,47 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                    'query': row["query"],
                    'db_path': input_data["data_path"],
                    'price':price,
-                    'answer':answer,
+                    'answer': answer,
                    'number_question':count,
-                    '
+                    'target_answer' : row["target_answer"] if flag_TQA else None,
+
                }]).dropna(how="all")  # Remove only completely empty rows
                count=count+1
                # TODO: use a for loop
+                if (flag_TQA) :
+                    new_row['predicted_answer'] = prediction
                for col in target_df.columns:
                    if col not in new_row.columns:
                        new_row[col] = row[col]
-
                # Update model's prediction dataframe incrementally
                if not new_row.empty:
                    predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
 
                # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
-
            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
            # END
            eval_text = generate_eval_text("Evaluation")
            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+
            evaluator = OrchestratorEvaluator()
+
            for model in input_data["models"]:
-
-
-
-
-
-
+                if not flag_TQA:
+                    metrics_df_model = evaluator.evaluate_df(
+                        df=predictions_dict[model],
+                        target_col_name="query",
+                        prediction_col_name="predicted_sql",
+                        db_path_name="db_path"
+                    )
+                else:
+                    metrics_df_model = us.evaluate_answer(predictions_dict[model])
                metrics_df_model['model'] = model
                metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
-
-            if '
-            metrics_conc['
+
+            if 'valid_efficency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
+                metrics_conc['valid_efficency_score'] = metrics_conc['VES']
+
            eval_text = generate_eval_text("End evaluation")
            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
@@ -848,6 +902,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                gr.Markdown(f"**Results for {model}**")
                tab_dict[model] = tab
                dataframe_per_model[model] = gr.DataFrame()
+                #TODO download metrics per model
                # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
 
    evaluation_loading = gr.Markdown()
@@ -860,13 +915,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        inputs=[],
        outputs=[tab_dict[model] for model in model_list]  # Update TabItem visibility
    )
+    submit_models_button_tqa.click(
+        change_tab,
+        inputs=[],
+        outputs=[tab_dict[model] for model in model_list]  # Update TabItem visibility
+    )
 
    selected_models_display = gr.JSON(label="Final input data", visible=False)
    metrics_df = gr.DataFrame(visible=False)
    metrics_df_out = gr.DataFrame(visible=False)
 
    submit_models_button.click(
-        fn=
+        fn=qatch_flow_nl_sql,
+        inputs=[],
+        outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
+    )
+
+    submit_models_button_tqa.click(
+        fn=qatch_flow_nl_sql,
        inputs=[],
        outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
    )
@@ -875,6 +941,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        fn=lambda: gr.update(value=input_data),
        outputs=[selected_models_display]
    )
+    submit_models_button_tqa.click(
+        fn=lambda: gr.update(value=input_data),
+        outputs=[selected_models_display]
+    )
 
    # Works for METRICS
    metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
@@ -897,10 +967,16 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        fn=lambda: gr.update(visible=False),
        outputs=[download_metrics]
    )
+    submit_models_button_tqa.click(
+        fn=lambda: gr.update(visible=False),
+        outputs=[download_metrics]
+    )
 
    def refresh():
        global reset_flag
+        global flag_TQA
        reset_flag = True
+        flag_TQA = False
 
    reset_data = gr.Button("Back to upload data section", interactive=True)
 
@@ -926,10 +1002,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection
+            open_model_selection,
+            submit_models_button_tqa
        ]
    )
-
+
+
    ##########################################
    #     METRICS VISUALIZATION SECTION      #
    ##########################################
@@ -944,8 +1022,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
    ####################################
 
    def load_data_csv_es():
-
+
        if input_data["input_method"]=="default":
+            global flag_TQA
            df = pd.read_csv(pnp_path)
            df = df[df['model'].isin(input_data["models"])]
            df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
@@ -956,6 +1035,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            df['model'] = df['model'].replace('llama-70', 'Llama-70B')
            df['model'] = df['model'].replace('llama-8', 'Llama-8B')
            df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
+            if (flag_TQA) : flag_TQA = False  #TODO delete after make pred
            return df
        return metrics_df_out
 
@@ -998,20 +1078,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    DB_CATEGORY_COLORS = generate_db_category_colors()
 
-    def
-
-
-
-
-        min_val = df['valid_efficiency_score'].min()
-        max_val = df['valid_efficiency_score'].max()
+    def normalize_valid_efficency_score(df):
+        df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
+        df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
+        min_val = df['valid_efficency_score'].min()
+        max_val = df['valid_efficency_score'].max()
 
-        if min_val == max_val:
-
-
+        if min_val == max_val :
+            # All values are equal, so for avoid division by zero, we set the score to 1/0
+            if min_val == None:
+                df['valid_efficency_score'] = 0
+            else:
+                df['valid_efficency_score'] = 1.0
        else:
-            df['
-            df['
+            df['valid_efficency_score'] = (
+                df['valid_efficency_score'] - min_val
            ) / (max_val - min_val)
 
        return df
@@ -1024,7 +1105,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
    # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
    def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
        df = df[df['model'].isin(selected_models)]
-        df =
+        df = normalize_valid_efficency_score(df)
 
        # Mappatura nomi leggibili -> tecnici
        qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
@@ -1141,7 +1222,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            selected_models = [selected_models]
 
        df = df[df['model'].isin(selected_models)]
-        df =
+        df = normalize_valid_efficency_score(df)
 
        # Converti nomi leggibili -> tecnici
        qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
@@ -1226,54 +1307,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        )
 
        return gr.Plot(fig, visible=True)
-
-    """
-    def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
-        if selected_models == "All":
-            selected_models = models
-        else:
-            selected_models = [selected_models]
-
-        df = df[df['model'].isin(selected_models)]
-        df = normalize_valid_efficiency_score(df)
-
-        if radio_metric == "Qatch":
-            selected_metrics = qatch_selected_metrics
-        else:
-            selected_metrics = external_selected_metric
-
-        df = calculate_average_metrics(df, selected_metrics)
-
-        # Raggruppamento per modello e categoria
-        avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index()
-        avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
-
-        # Plot orizzontale con modello sull'asse Y
-        fig = px.bar(
-            avg_metrics,
-            x='avg_metric',
-            y='model',
-            color='db_category',  # categoria come colore
-            text='text_label',
-            barmode='group',
-            orientation='h',
-            color_discrete_map=DB_CATEGORY_COLORS,  # devi avere questo dict come MODEL_COLORS
-            title='Average metric per model and db_category 📊',
-            labels={'avg_metric': 'AVG Metric', 'model': 'Model'},
-            template='plotly_dark'
-        )
-
-        fig.update_traces(textposition='outside', textfont_size=10)
-        fig.update_layout(
-            margin=dict(t=80),
-            yaxis=dict(title=''),
-            xaxis=dict(title='AVG Metrics'),
-            legend_title='DB Name',
-            height=600  # puoi aumentare se ci sono tanti modelli
-        )
-
-        return gr.Plot(fig, visible=True)
-    """
 
    def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
        df = load_data_csv_es()
@@ -1289,7 +1322,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        df = df[df['db_category'].isin(target_cats)]
        df = df[df['model'].isin(selected_models)]
 
-        df =
+        df = normalize_valid_efficency_score(df)
        df = calculate_average_metrics(df, qatch_metrics)
 
        # Calcola la media per db_category e modello
@@ -1410,14 +1443,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    # RADAR OR BAR CHART BASED ON CATEGORY COUNT
    def plot_radar(df, selected_models, selected_metrics, selected_categories):
-        if "
-            selected_metrics = ["execution_accuracy", "
+        if "External" in selected_metrics:
+            selected_metrics = ["execution_accuracy", "valid_efficency_score"]
        else:
            selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
 
        # Filtro modelli e normalizzazione
        df = df[df['model'].isin(selected_models)]
-        df =
+        df = normalize_valid_efficency_score(df)
        df = calculate_average_metrics(df, selected_metrics)
 
        avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
@@ -1574,13 +1607,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
    def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
-        if "
-            selected_metrics = ["execution_accuracy", "
+        if "External" in selected_metrics:
+            selected_metrics = ["execution_accuracy", "valid_efficency_score"]
        else:
            selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
 
        df = df[df['model'].isin(selected_models)]
-        df =
+        df = normalize_valid_efficency_score(df)
        df = calculate_average_metrics(df, selected_metrics)
 
        if isinstance(selected_category, str):
@@ -1743,6 +1776,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
    def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
+        global flag_TQA
        if selected_models == "All":
            selected_models = models
        else:
@@ -1757,15 +1791,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        df = df[df['test_category'].isin(selected_categories)]
 
        if "external" in selected_metrics:
-            selected_metrics = ["execution_accuracy", "
+            selected_metrics = ["execution_accuracy", "valid_efficency_score"]
        else:
            selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
 
-        df =
+        df = normalize_valid_efficency_score(df)
        df = calculate_average_metrics(df, selected_metrics)
-
-        worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
 
+        if flag_TQA:
+            df["target_answer"] = df["target_answer"].apply(
+                lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
+            )
+            df["predicted_answer"] = df["predicted_answer"].apply(
+                lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
+            )
+
+            worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
+        else:
+            worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
+
        worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
 
        worst_cases_top_3 = worst_cases_df.head(3)
@@ -1778,14 +1822,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        medals = ["🥇", "🥈", "🥉"]
 
        for i, row in worst_cases_top_3.iterrows():
-
-
-
-
-
-
+            if flag_TQA:
+                entry = (
+                    f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Original Answer:</b> `{row['target_answer']}`</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Predicted Answer:</b> `{row['predicted_answer']}`</span>  \n\n"
+                )
+
+                worst_str.append(entry)
+            else:
+                entry = (
+                    f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span>  \n"
+                    f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span>  \n\n"
+                )
 
-
+                worst_str.append(entry)
 
            raw_answer = (
                f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span>  \n"
@@ -1793,7 +1847,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            )
 
            answer_str.append(raw_answer)
-
+
        return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
 
    def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
@@ -1803,7 +1857,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
    # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
    def plot_cumulative_flow(df, selected_models, max_points):
        df = df[df['model'].isin(selected_models)]
-        df =
+        df = normalize_valid_efficency_score(df)
 
        fig = go.Figure()
 
@@ -1937,10 +1991,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    external_metrics_dict = {
        "Execution Accuracy": "execution_accuracy",
-        "Valid
+        "Valid Efficency Score": "valid_efficency_score"
    }
 
-    external_metric = ["execution_accuracy", "
+    external_metric = ["execution_accuracy", "valid_efficency_score"]
    last_valid_external_metric_selection = external_metric.copy()
    def enforce_external_metric_selection(selected):
        global last_valid_external_metric_selection
@@ -1987,10 +2041,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
 
    all_model_as_dic = {cat: [f"{cat}"] for cat in models}
    all_model_as_dic["All"] = models
-
-    #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
-
-
 
    ###########################
    #  VISUALIZATION SECTION  #
@@ -2029,7 +2079,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                <span
                    title="External metric info:
 Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
-Valid
+Valid Efficency Score: Evaluates the efficency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
                    style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
                >External metric info ℹ️</span>
            </div>
@@ -2304,6 +2354,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
    reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
    reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
    reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
-
+
 
 interface.launch(share = True)
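
For orientation, the control flow this commit adds around predictor.make_prediction boils down to the following standalone sketch. It is a simplified reading of the diff above, not code from the repository: pick_task_and_prompt is a hypothetical helper name, while flag_TQA, the prompt templates, and the "QA"/"SP" task codes are taken directly from the commit.

def pick_task_and_prompt(flag_TQA, user_prompt, prompt_default, prompt_default_tqa):
    # The "Submit Models for TQA task" button sets the module-level flag_TQA;
    # the flow then swaps the default prompt template and the task code
    # that qatch_flow_nl_sql passes to make_prediction.
    if flag_TQA:
        prompt = prompt_default_tqa if user_prompt == prompt_default else user_prompt
        return "QA", prompt   # question answering: the model returns answer tuples
    return "SP", user_prompt  # semantic parsing: the model returns a SQL query

Because both buttons reuse the same qatch_flow_nl_sql generator, the rest of the pipeline (prediction loop, incremental yields, metric concatenation) stays shared between the two tasks.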
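The new normalize_valid_efficency_score helper min-max scales the VES column before plotting. A condensed equivalent, assuming only a pandas DataFrame with that column, looks like this; note it is a sketch, not the committed code, which additionally casts the column to int and special-cases min_val == None:

import numpy as np
import pandas as pd

def minmax_ves(df: pd.DataFrame) -> pd.DataFrame:
    # Missing values become 0, then the column is scaled to [0, 1];
    # a constant column would make the denominator zero, so the diff pins it to 1.0.
    s = df['valid_efficency_score'].replace([np.nan, ''], 0).astype(float)
    lo, hi = s.min(), s.max()
    df['valid_efficency_score'] = 1.0 if lo == hi else (s - lo) / (hi - lo)
    return df
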
concatenated_output.csv
CHANGED

@@ -1,4 +1,4 @@
-cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,
+cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,valid_efficency_score,tbl_name,tuple_order,time,price,question,model,cell_recall,db_path,execution_accuracy,test_category,query
 1.0,DISTINCT-SINGLE,1.0,"```sql
 SELECT DISTINCT WAREHOUSE_LOAD_DATE
 FROM FAC_BUILDING_ADDRESS;
utilities.py
CHANGED

@@ -6,6 +6,11 @@ import sqlite3
 import gradio as gr
 import os
 from qatch.connectors.sqlite_connector import SqliteConnector
+from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
+import qatch.evaluate_dataset.orchestrator_evaluator as eva
+#import tiktoken
+from transformers import AutoTokenizer
+
 def extract_tables(file_path):
     conn = sqlite3.connect(file_path)
     cursor = conn.cursor()
@@ -26,7 +31,7 @@ def extract_dataframes(file_path):
     return dfs
 
 def carica_sqlite(file_path, db_id):
-    data_output = {'data_frames': extract_dataframes(file_path),'db':SqliteConnector(relative_db_path=file_path, db_name=db_id)}
+    data_output = {'data_frames': extract_dataframes(file_path),'db': SqliteConnector(relative_db_path=file_path, db_name=db_id)}
     return data_output
 
 # Funzione per leggere un file CSV
@@ -113,3 +118,94 @@ def generate_some_samples(file_path, tbl_name):
 def load_tables_dict_from_pkl(file_path):
     with open(file_path, 'rb') as f:
         return pickle.load(f)
+
+def extract_tables_dict(pnp_path):
+    return load_tables_dict_from_pkl('tables_dict_beaver.pkl')
+    tables_dict = {}
+    with open(pnp_path, mode='r', encoding='utf-8') as file:
+        reader = csv.DictReader(file)
+        tbl_db_pairs = set()  # Use a set to avoid duplicates
+        for row in reader:
+            tbl_name = row.get("tbl_name")
+            db_path = row.get("db_path")
+            if tbl_name and db_path:
+                tbl_db_pairs.add((tbl_name, db_path))  # Add the pair to the set
+    for tbl_name, db_path in list(tbl_db_pairs):
+        if tbl_name and db_path:
+            connector = sqlite3.connect(db_path)
+            query = f"SELECT * FROM {tbl_name} LIMIT 5"
+            try:
+                df = pd.read_sql_query(query, connector)
+                tables_dict[tbl_name] = df
+            except Exception as e:
+                tables_dict[tbl_name] = pd.DataFrame({"Error": [str(e)]})  # DataFrame con messaggio di errore
+    #with open('tables_dict_beaver.pkl', 'wb') as f:
+    #    pickle.dump(tables_dict, f)
+    return tables_dict
+
+
+def extract_answer(df):
+    if "query" not in df.columns or "db_path" not in df.columns:
+        raise ValueError("The DataFrame must contain 'query' and 'data_path' columns.")
+
+    answers = []
+    for _, row in df.iterrows():
+        query = row["query"]
+        db_path = row["db_path"]
+        try:
+            conn = SqliteConnector(relative_db_path = db_path , db_name= "db")
+            answer = eva._utils_run_query_if_str(query, conn)
+            answers.append(answer)
+        except Exception as e:
+            answers.append(f"Error: {e}")
+
+    df["target_answer"] = answers
+    return df
+
+evaluator = {
+    "cell_precision": CellPrecision(),
+    "cell_recall": CellRecall(),
+    "tuple_cardinality": TupleCardinality(),
+    "tuple_order": TupleOrder(),
+    "tuple_constraint": TupleConstraint(),
+    "execution_accuracy": ExecutionAccuracy(),
+    "valid_efficency_score": ValidEfficiencyScore()
+}
+
+def evaluate_answer(df):
+    for metric_name, metric in evaluator.items():
+        results = []
+        for _, row in df.iterrows():
+            target = row["target_answer"]
+            predicted = row["predicted_answer"]
+            try:
+                result = metric.run_metric(target = target, prediction = predicted)
+            except Exception as e:
+                result = None
+            results.append(result)
+        df[metric_name] = results
+    return df
+
+models = [
+    "gpt-4o-mini",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+]
+
+def crop_entries_per_token(entries_list, model, prompt: str | None = None):
+    #open_ai_models = ["gpt-3.5", "gpt-4o-mini"]
+    dimension = 2048
+    #enties_string = [", ".join(map(str, entry)) for entry in entries_list]
+    if prompt:
+        entries_string = prompt.join(entries_list)
+    else:
+        entries_string = " ".join(entries_list)
+    #if model in ["deepseek-ai/DeepSeek-R1-Distill-Llama-70B" ,"gpt-4o-mini" ] :
+    #tokenizer = tiktoken.encoding_for_model("gpt-4o-mini")
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
|
| 205 |
+
|
| 206 |
+
tokens = tokenizer.encode(entries_string)
|
| 207 |
+
number_of_tokens = len(tokens)
|
| 208 |
+
if number_of_tokens > dimension and len(entries_list) > 4:
|
| 209 |
+
entries_list = entries_list[:round(len(entries_list)/2)]
|
| 210 |
+
entries_list = crop_entries_per_token(entries_list, model)
|
| 211 |
+
return entries_list
|
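Two remarks on the additions to utilities.py. First, extract_tables_dict returns the pickled dictionary on its first line, so the CSV-scanning code below the return is currently dead, apparently kept only to regenerate tables_dict_beaver.pkl. Second, crop_entries_per_token implements a halving search: it tokenizes the joined entries with the DeepSeek tokenizer and recursively drops the second half of the list until the text fits the 2048-token budget or only four entries remain. A usage sketch under those assumptions (downloading the tokenizer requires Hugging Face access; the entries below are hypothetical stand-ins for real schema entries):

```python
# Hypothetical INSERT statements standing in for real schema entries.
entries = [f"INSERT INTO t (id) VALUES ({i});" for i in range(500)]

kept = crop_entries_per_token(entries, model="deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
print(f"kept {len(kept)} of {len(entries)} entries")  # halved until under ~2048 tokens
```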
utils_get_db_tables_info.py
CHANGED
@@ -1,10 +1,10 @@
 import os
 import sqlite3
 import re
-
+import utilities as us
 
 def utils_extract_db_schema_as_string(
-    db_id, base_path, normalize=False, sql: str | None = None, get_insert_into: bool = False
+    db_id, base_path, model: str | None = None, normalize=False, sql: str | None = None, get_insert_into: bool = False, prompt: str | None = None
 ):
     """
     Extracts the full schema of an SQLite database into a single string.
@@ -19,7 +19,7 @@ def utils_extract_db_schema_as_string(
     cursor = connection.cursor()
 
     # Get the schema entries based on the provided SQL query
-    schema_entries = _get_schema_entries(cursor, sql, get_insert_into)
+    schema_entries = _get_schema_entries(cursor, sql, get_insert_into, model, prompt)
 
     # Combine all schema definitions into a single string
     schema_string = _combine_schema_entries(schema_entries, normalize)
@@ -28,7 +28,7 @@ def utils_extract_db_schema_as_string(
 
 
 
-def _get_schema_entries(cursor, sql=None, get_insert_into=False):
+def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None, prompt: str | None = None):
     """
     Retrieves schema entries and optionally data entries from the SQLite database.
 
@@ -62,11 +62,17 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False):
     column_names = [description[0] for description in cursor.description]
 
     # Generate INSERT INTO statements for each row
-
-
+    if model is None:
+        max_len = 3
+    else:
+        max_len = len(rows)
+
+    for row in rows[:max_len]:
         values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
         insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
         entries.append(insert_stmt)
+
+    if model is not None: entries = us.crop_entries_per_token(entries, model, prompt)
 
     return entries
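One caveat in the INSERT generation above: wrapping values as '{str(value)}' produces malformed SQL whenever a string contains a single quote, and None is rendered as the bare literal None. A small hardening sketch, assuming the same row and column_names locals (this helper is illustrative, not part of the commit):

```python
def sql_literal(value):
    # Render one Python value as a SQL literal, escaping embedded quotes.
    if value is None:
        return "NULL"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return str(value)

values = ', '.join(sql_literal(v) for v in row)
```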