diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,2355 +1,2355 @@ -import os -import sys -import time -import re -import csv -import gradio as gr -import pandas as pd -import numpy as np -import plotly.express as px -import plotly.graph_objects as go -import plotly.colors as pc -from qatch.connectors.sqlite_connector import SqliteConnector -from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator -from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator -import qatch.evaluate_dataset.orchestrator_evaluator as eva -from prediction import ModelPrediction -import utils_get_db_tables_info -import utilities as us -# @spaces.GPU -# def model_prediction(): -# pass -# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10 -# if os.environ.get("SPACES_ZERO_GPU") is not None: -# import spaces -# else: -# class spaces: -# @staticmethod -# def GPU(func): -# def wrapper(*args, **kwargs): -# return func(*args, **kwargs) -# return wrapper -#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv") -pnp_path = "concatenated_output.csv" -PATH_PKL_TABLES = 'tables_dict_beaver.pkl' -PNP_TQA_PATH = 'concatenated_output_tqa.csv' -js_func = """ -function refresh() { - const url = new URL(window.location); - - if (url.searchParams.get('__theme') !== 'light') { - url.searchParams.set('__theme', 'light'); - window.location.href = url.href; - } -} -""" -reset_flag = False -flag_TQA = False - -with open('style.css', 'r') as file: - css = file.read() - -# DataFrame di default -df_default = pd.DataFrame({ - 'Name': ['Alice', 'Bob', 'Charlie'], - 'Age': [25, 30, 35], - 'City': ['New York', 'Los Angeles', 'Chicago'] -}) -models_path ="models.csv" - -# Variabile globale per tenere traccia dei dati correnti -df_current = df_default.copy() - -description = """## πŸ“Š Comparison of Proprietary and Non-Proprietary Databases - ### ➀ **Proprietary** : - ###     β‡’ Economic πŸ’°, Medical πŸ₯, Financial πŸ’³, Miscellaneous πŸ“‚ - ###     β‡’ BEAVER (FAC BUILDING ADDRESS 🏒 , TIME QUARTER ⏱️) - ### ➀ **Non-Proprietary** - ###     β‡’ Spider 1.0 πŸ•·οΈ""" -prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n" -prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as .\n Question \n {question}\n Database Schema\n {db_schema}\n" - - -input_data = { - 'input_method': "", - 'data_path': "", - 'db_name': "", - 'data': { - 'data_frames': {}, # dictionary of dataframes - 'db': None, # SQLITE3 database object - 'selected_tables' :[] - }, - 'models': [], - 'prompt': prompt_default -} - -def load_data(file, path, use_default): - """Carica i dati da un file, un percorso o usa il DataFrame di default.""" - global df_current - if file is not None: - try: - input_data["input_method"] = 'uploaded_file' - input_data["db_name"] = os.path.splitext(os.path.basename(file))[0] - if file.endswith('.sqlite'): - #return 'Error: The uploaded file is not a valid SQLite database.' 
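# Note: the SQLite validity check above is commented out, so an uploaded
# .sqlite file is currently used as-is, without any validation.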
- input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite") - else: - #change path - input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite") - input_data["data"] = us.load_data(file, input_data["db_name"]) - - df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame - if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files - table2primary_key = {} - for table_name, df in input_data["data"]['data_frames'].items(): - # Assign primary keys for each table - table2primary_key[table_name] = 'id' - input_data["data"]["db"] = SqliteConnector( - relative_db_path=input_data["data_path"], - db_name=input_data["db_name"], - tables= input_data["data"]['data_frames'], - table2primary_key=table2primary_key - ) - return input_data["data"]['data_frames'] - except Exception as e: - return f'Errore nel caricamento del file: {e}' - if use_default: - if(use_default == 'Custom'): - input_data["input_method"] = 'custom' - #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite") - input_data["data_path"] = os.path.join(".","mytable_0.sqlite") - #if file already exist - while os.path.exists(input_data["data_path"]): - input_data["data_path"] = us.increment_filename(input_data["data_path"]) - input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0] - input_data["data"]['data_frames'] = {'MyTable': df_current} - - if(input_data["data"]['data_frames']): - table2primary_key = {} - for table_name, df in input_data["data"]['data_frames'].items(): - # Assign primary keys for each table - table2primary_key[table_name] = 'id' - input_data["data"]["db"] = SqliteConnector( - relative_db_path=input_data["data_path"], - db_name=input_data["db_name"], - tables= input_data["data"]['data_frames'], - table2primary_key=table2primary_key - ) - df_current = df_default.copy() # Ripristina i dati di default - return input_data["data"]['data_frames'] - - if(use_default == 'Proprietary vs Non-proprietary'): - input_data["input_method"] = 'default' - #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite") - #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite") - #input_data["db_name"] = "default" - #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"]) - input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES) - return input_data["data"]['data_frames'] - - selected_inputs = sum([file is not None, bool(path), use_default]) - if selected_inputs > 1: - return 'Error: Select only one input method at a time.' 
- - return input_data["data"]['data_frames'] - -def preview_default(use_default, file): - if file: - return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## βœ… File successfully uploaded!", visible=True) - else : - if use_default == 'Custom': - return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## πŸ“ Toy Table", visible=True) - else: - return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True) - #return gr.DataFrame(interactive=True, value = df_current) # Mostra il DataFrame corrente, che potrebbe essere stato modificato - -def update_df(new_df): - """Aggiorna il DataFrame corrente.""" - global df_current # Usa la variabile globale per aggiornarla - df_current = new_df - return df_current - -def open_accordion(target): - # Apre uno e chiude l'altro - if target == "reset": - df_current = df_default.copy() - input_data['input_method'] = "" - input_data['data_path'] = "" - input_data['db_name'] = "" - input_data['data']['data_frames'] = {} - input_data['data']['selected_tables'] = [] - input_data['data']['db'] = None - input_data['models'] = [] - return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None) - elif target == "model_selection": - return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False) - -# Interfaccia Gradio -#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: -with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface: - with gr.Row(): - with gr.Column(scale=1): - gr.Image( - value=os.path.join(".", "qatch_logo.png"), - show_label=False, - container=False, - interactive=False, - show_fullscreen_button=False, - show_download_button=False, - show_share_button=False, - height=150, # in pixel - width=300 - ) - with gr.Column(scale=1): - pass - data_state = gr.State(None) # Memorizza i dati caricati - upload_acc = gr.Accordion("Upload data section", open=True, visible=True) - select_table_acc = gr.Accordion("Select tables section", open=False, visible=False) - select_model_acc = gr.Accordion("Select models section", open=False, visible=False) - qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False) - metrics_acc = gr.Accordion("Metrics section", open=False, visible=False) - - ################################# - # DATABASE INSERTION # - ################################# - with upload_acc: - gr.Markdown("## πŸ“₯Choose data input method") - with gr.Row(): - default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary') - #default_checkbox = gr.Checkbox(label="Use default DataFrame" - - table_default = gr.Markdown(description, visible=True) - preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default) - - gr.Markdown("## πŸ“‚ Or upload your data") - file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"]) - submit_button = gr.Button("Load Data") # Disabled by default - output = gr.JSON(visible=False) # Dictionary output - - # Function to enable the 
button if there is data to load
- def enable_submit(file, use_default):
- return gr.update(interactive=bool(file or use_default))
- 
- # Function to uncheck the checkbox if a file is uploaded
- def deselect_default(file):
- if file:
- return gr.update(value='Proprietary vs Non-proprietary')
- return gr.update()
- 
- def enable_disable_first(enable):
- return (
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable)
- )
- 
- # Enable the button when inputs are provided
- #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
- #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
- 
- # Show preview of the default DataFrame when checkbox is selected
- default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
- file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
- preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
- 
- # Uncheck the checkbox when a file is uploaded
- file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
- 
- def handle_output(file, use_default):
- """Handles the output when the 'Load Data' button is pressed."""
- result = load_data(file, None, use_default)
- 
- if isinstance(result, dict): # If result is a dictionary of DataFrames
- if len(result) == 1: # If there's only one table
- input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
- return (
- gr.update(visible=False), # Hide JSON output
- result, # Save the data state
- gr.update(visible=False), # Hide table selection
- result, # Maintain the data state
- gr.update(interactive=False), # Disable the submit button
- gr.update(visible=True, open=True), # Proceed to select_model_acc
- gr.update(visible=True, open=False)
- )
- else:
- return (
- gr.update(visible=False),
- result,
- gr.update(open=True, visible=True),
- result,
- gr.update(interactive=False),
- gr.update(visible=False), # Keep current behavior
- gr.update(visible=True, open=False)
- )
- else:
- return (
- gr.update(visible=False),
- None,
- gr.update(open=False, visible=True),
- None,
- gr.update(interactive=True),
- gr.update(visible=False),
- gr.update(visible=True, open=False)
- )
- 
- submit_button.click(
- fn=handle_output,
- inputs=[file_input, default_checkbox],
- outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
- )
- 
- submit_button.click(
- fn=enable_disable_first,
- inputs=[gr.State(False)],
- outputs=[
- preview_output,
- submit_button,
- file_input,
- default_checkbox
- ]
- )
- 
- ######################################
- # TABLE SELECTION PART #
- ######################################
- with select_table_acc:
- previous_selection = gr.State([])
- table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the chosen database", value=[])
- excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
- table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
- selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
- 
- # Model selection button (initially disabled)
- open_model_selection = gr.Button("Choose your models", interactive=False)
- def update_table_list(data):
"""Dynamically updates the list of available tables and excluded ones.""" - if isinstance(data, dict) and data: - table_names = [] - excluded_tables = [] - - data_frames = input_data['data'].get('data_frames', {}) - - available_tables = [] - for name, df in data.items(): - df_real = data_frames.get(name, None) - if input_data['input_method'] != "default": - if df_real is not None and df_real.shape[1] > 15: - excluded_tables.append(name) - else: - available_tables.append(name) - else: - available_tables.append(name) - - - if input_data['input_method'] == "default": - table_names.append("All") - excluded_tables = [] - elif len(available_tables) < 6: - table_names.append("All") - - table_names.extend(available_tables) - if excluded_tables and input_data['input_method'] != "default" : - excluded_text = "⚠️ The following tables have more than 15 columns and cannot be selected:
" + "
".join(f"- {t}" for t in excluded_tables) - excluded_visible = True - else: - excluded_text = "" - excluded_visible = False - - return [ - gr.update(choices=table_names, value=[]), # CheckboxGroup update - gr.update(value=excluded_text, visible=excluded_visible) # HTML display update - ] - - return [ - gr.update(choices=[], value=[]), - gr.update(value="", visible=False) - ] - - def show_selected_tables(data, selected_tables): - updates = [] - data_frames = input_data['data'].get('data_frames', {}) - - available_tables = [] - for name, df in data.items(): - df_real = data_frames.get(name) - if input_data['input_method'] != "default" : - if df_real is not None and df_real.shape[1] <= 15: - available_tables.append(name) - else: - available_tables.append(name) - - input_method = input_data['input_method'] - allow_all = input_method == "default" or len(available_tables) < 6 - - selected_set = set(selected_tables) - tables_set = set(available_tables) - - if allow_all: - if "All" in selected_set: - selected_tables = ["All"] + available_tables - elif selected_set == tables_set: - selected_tables = [] - else: - selected_tables = [t for t in selected_tables if t in available_tables] - else: - selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5] - - tables = {name: data[name] for name in selected_tables if name in data} - - for i, (name, df) in enumerate(tables.items()): - updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False)) - - for _ in range(len(tables), 50): - updates.append(gr.update(visible=False)) - - updates.append(gr.update(interactive=bool(tables))) - - if allow_all: - updates.insert(0, gr.update( - choices=["All"] + available_tables, - value=selected_tables - )) - else: - if len(selected_tables) >= 5: - updates.insert(0, gr.update( - choices=selected_tables, - value=selected_tables - )) - else: - updates.insert(0, gr.update( - choices=available_tables, - value=selected_tables - )) - - return updates - - def show_selected_table_names(data, selected_tables): - """Displays the names of the selected tables when the button is pressed.""" - if selected_tables: - available_tables = list(data.keys()) # Actually available names - if "All" in selected_tables: - selected_tables = available_tables - if (input_data['input_method'] != "default") : selected_tables = [t for t in selected_tables if len(data[t].columns) <= 15] - - input_data['data']['selected_tables'] = selected_tables - return gr.update(value=", ".join(selected_tables), visible=False) - return gr.update(value="", visible=False) - - # Automatically updates the checkbox list when `data_state` changes - data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info]) - - # Updates the visible tables and the button state based on user selections - #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection]) - table_selector.change( - fn=show_selected_tables, - inputs=[data_state, table_selector], - outputs=[table_selector] + table_outputs + [open_model_selection] - ) - # Shows the list of selected tables when "Choose your models" is clicked - open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names]) - open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc]) - - reset_data = 
gr.Button("Back to upload data section") - - reset_data.click( - fn=enable_disable_first, - inputs=[gr.State(True)], - outputs=[ - preview_output, - submit_button, - file_input, - default_checkbox - ] - ) - reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) - - #################################### - # MODEL SELECTION PART # - #################################### - with select_model_acc: - gr.Markdown("# Model Selection") - - # Assume that `us.read_models_csv` also returns the image path - model_list_dict = us.read_models_csv(models_path) - model_list = [model["code"] for model in model_list_dict] - model_images = [model["image_path"] for model in model_list_dict] - model_names = [model["name"] for model in model_list_dict] - # Create a mapping between model_list and model_images_names - model_mapping = dict(zip(model_list, model_names)) - model_mapping_reverse = dict(zip(model_names, model_list)) - - model_checkboxes = [] - rows = [] - - # Dynamically create checkboxes with images (3 per row) - for i in range(0, len(model_list), 3): - with gr.Row(): - cols = [] - for j in range(3): - if i + j < len(model_list): - model = model_list[i + j] - image_path = model_images[i + j] - with gr.Column(): - gr.Image(image_path, - show_label=False, - container=False, - interactive=False, - show_fullscreen_button=False, - show_download_button=False, - show_share_button=False) - checkbox = gr.Checkbox(label=model_mapping[model], value=False) - model_checkboxes.append(checkbox) - cols.append(checkbox) - rows.append(cols) - - selected_models_output = gr.JSON(visible=False) - - # Function to get selected models - def get_selected_models(*model_selections): - selected_models = [model for model, selected in zip(model_list, model_selections) if selected] - input_data['models'] = selected_models - button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) - return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state) - - # Add the Textbox to the interface - with gr.Row(): - button_prompt_nlsql = gr.Button("Choose NL2SQL task") - button_prompt_tqa = gr.Button("Choose TQA task") - - prompt = gr.TextArea( - label="Customise the prompt for selected models here or leave the default one.", - placeholder=prompt_default, - elem_id="custom-textarea" - ) - - warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False) - - # Submit button (initially disabled) - with gr.Row(): - submit_models_button = gr.Button("Submit Models", interactive=False) - - def check_prompt(prompt): - #TODO - missing_elements = [] - if(prompt==""): - global flag_TQA - if not flag_TQA: - input_data["prompt"] = prompt_default - else: - input_data["prompt"] = prompt_default_tqa - button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) - else: - input_data["prompt"] = prompt - if "{db_schema}" not in prompt: - missing_elements.append("{db_schema}") - if "{question}" not in prompt: - missing_elements.append("{question}") - button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) - if missing_elements: - return gr.update( - value=f"
" - f"❌ Missing {', '.join(missing_elements)} in the prompt ❌
", - visible=True - ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"]) - return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"]) - - prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button]) - # Link checkboxes to selection events - for checkbox in model_checkboxes: - checkbox.change( - fn=get_selected_models, - inputs=model_checkboxes, - outputs=[selected_models_output, select_model_acc, submit_models_button] - ) - prompt.change( - fn=get_selected_models, - inputs=model_checkboxes, - outputs=[selected_models_output, select_model_acc, submit_models_button] - ) - - submit_models_button.click( - fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)), - inputs=model_checkboxes, - outputs=[selected_models_output, select_model_acc, qatch_acc] - ) - - def change_flag(): - global flag_TQA - flag_TQA = True - - def dis_flag(): - global flag_TQA - flag_TQA = False - - button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[]) - - button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[]) - - button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) - - button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) - - - def enable_disable(enable): - return ( - *[gr.update(interactive=enable) for _ in model_checkboxes], - gr.update(interactive=enable), - gr.update(interactive=enable), - gr.update(interactive=enable), - gr.update(interactive=enable), - gr.update(interactive=enable), - gr.update(interactive=enable), - *[gr.update(interactive=enable) for _ in table_outputs], - gr.update(interactive=enable) - ) - - reset_data = gr.Button("Back to upload data section") - - submit_models_button.click( - fn=enable_disable, - inputs=[gr.State(False)], - outputs=[ - *model_checkboxes, - submit_models_button, - preview_output, - submit_button, - file_input, - default_checkbox, - table_selector, - *table_outputs, - open_model_selection - ] - ) - - reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) - - reset_data.click( - fn=enable_disable, - inputs=[gr.State(True)], - outputs=[ - *model_checkboxes, - submit_models_button, - preview_output, - submit_button, - file_input, - default_checkbox, - table_selector, - *table_outputs, - open_model_selection - ] - ) - - ############################# - # QATCH EXECUTION # - ############################# - with qatch_acc: - def change_text(text): - return text - - loading_symbols= {1:"π“†Ÿ", - 2: "π“†ž π“†Ÿ", - 3: "𓆛 π“†ž π“†Ÿ", - 4: "π“†ž 𓆛 π“†ž π“†Ÿ", - 5: "π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - 6: "π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - 7: "π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - 8: "π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - 9: "π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - 10:"π“†ž π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", - } - - def generate_loading_text(percent): - num_symbols = (round(percent) % 11) + 1 - symbols = loading_symbols.get(num_symbols, "π“†Ÿ") - mirrored_symbols = f'{symbols.strip()}' - css_symbols = f'{symbols.strip()}' - return f""" -
- {css_symbols} - - Generation {percent}% - - {mirrored_symbols} -
- """ - - def generate_eval_text(text): - symbols = "𓆑 " - mirrored_symbols = f'{symbols.strip()}' - css_symbols = f'{symbols.strip()}' - return f""" -
- {css_symbols} - - {text} - - {mirrored_symbols} -
- """ - - def qatch_flow_nl_sql(): - global reset_flag - global flag_TQA - predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list} - metrics_conc = pd.DataFrame() - columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"] - if (input_data['input_method']=="default"): - #target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv") - target_df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH) - #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list} - target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])] - target_df = target_df[target_df["model"].isin(input_data['models'])] - predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list} - reset_flag = False - for model in input_data['models']: - model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) - yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] - count=1 - for _, row in predictions_dict[model].iterrows(): - #for index, row in target_df.iterrows(): - if (reset_flag == False): - percent_complete = round(count / len(predictions_dict[model]) * 100, 2) - count=count+1 - load_text = f"{generate_loading_text(percent_complete)}" - question = row['question'] - - display_question = f"""
Natural Language:
-
-
{question}
-
➑️
-
- """ - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] - - prediction = row['predicted_sql'] - - display_prediction = f"""
Predicted SQL:
-
-
➑️
-
{prediction}
-
- """ - - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] - metrics_conc = target_df - if 'valid_efficency_score' not in metrics_conc.columns: - metrics_conc['valid_efficency_score'] = metrics_conc['VES'] - if 'VES' not in metrics_conc.columns: - metrics_conc['VES'] = metrics_conc['valid_efficency_score'] - eval_text = generate_eval_text("End evaluation") - yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] - - else: - global flag_TQA - orchestrator_generator = OrchestratorGenerator() - target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables']) - - #create target_df[target_answer] - if flag_TQA : - # if (input_data["prompt"] == prompt_default): - # input_data["prompt"] = prompt_default_tqa - - target_df['db_schema'] = target_df.apply( - lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string( - db_id=input_data["db_name"], - base_path=input_data["data_path"], - normalize=False, - sql=row["query"], - get_insert_into=True, - model=None, - prompt=input_data["prompt"].format(question=row["question"], db_schema="") - ), - axis=1 - ) - - target_df = us.extract_answer(target_df) - - predictor = ModelPrediction() - reset_flag = False - for model in input_data["models"]: - model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) - yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - count=0 - for index, row in target_df.iterrows(): - if (reset_flag == False): - percent_complete = round(((index+1) / len(target_df)) * 100, 2) - load_text = f"{generate_loading_text(percent_complete)}" - - question = row['question'] - display_question = f"""
Natural Language:
-
-
{question}
-
➑️
-
- """ - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"]) - model_to_send = None if not flag_TQA else model - - db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( - db_id = input_data["db_name"], - base_path = input_data["data_path"], - normalize=False, - sql=row["query"], - get_insert_into=True, - model = model_to_send, - prompt = input_data["prompt"].format(question=question, db_schema=""), - ) - - #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples) - prompt_to_send = input_data["prompt"] - #PREDICTION SQL - - # TODO add button for QA or SP and pass to .make_prediction parameter TASK - if flag_TQA: task="QA" - else: task="SP" - start_time = time.time() - response = predictor.make_prediction( - question=question, - db_schema=db_schema_text, - model_name=model, - prompt=f"{prompt_to_send}", - task=task - ) - #if flag_TQA: response = {'response_parsed': "[['Alice'],['Bob'],['Charlie']]", 'cost': 0, 'response': "[['Alice'],['Bob'],['Charlie']]"} # TODO remove this line - #else : response = {'response_parsed': "SELECT * FROM 'MyTable'", 'cost': 0, 'response': "SQL_QUERY"} - end_time = time.time() - prediction = response['response_parsed'] - price = response['cost'] - answer = response['response'] - - if flag_TQA: - task_string = "Answer" - else: - task_string = "SQL" - - display_prediction = f"""
Predicted {task_string}:
-
-
➑️
-
{prediction}
-
- """ - # Create a new row as dataframe - new_row = pd.DataFrame([{ - 'id': index, - 'question': question, - 'predicted_sql': prediction, - 'time': end_time - start_time, - 'query': row["query"], - 'db_path': input_data["data_path"], - 'price':price, - 'answer': answer, - 'number_question':count, - 'target_answer' : row["target_answer"] if flag_TQA else None, - - }]).dropna(how="all") # Remove only completely empty rows - count=count+1 - # TODO: use a for loop - if (flag_TQA) : - new_row['predicted_answer'] = prediction - for col in target_df.columns: - if col not in new_row.columns: - new_row[col] = row[col] - # Update model's prediction dataframe incrementally - if not new_row.empty: - predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) - - # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] - yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] - # END - eval_text = generate_eval_text("Evaluation") - yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - - evaluator = OrchestratorEvaluator() - - for model in input_data["models"]: - if not flag_TQA: - metrics_df_model = evaluator.evaluate_df( - df=predictions_dict[model], - target_col_name="query", - prediction_col_name="predicted_sql", - db_path_name="db_path" - ) - else: - metrics_df_model = us.evaluate_answer(predictions_dict[model]) - metrics_df_model['model'] = model - metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) - - if 'VES' not in metrics_conc.columns and 'valid_efficency_score' not in metrics_conc.columns: - metrics_conc['VES'] = 0 - metrics_conc['valid_efficency_score'] = 0 - - if 'valid_efficency_score' not in metrics_conc.columns: - metrics_conc['valid_efficency_score'] = metrics_conc['VES'] - - if 'VES' not in metrics_conc.columns: - metrics_conc['VES'] = metrics_conc['valid_efficency_score'] - - eval_text = generate_eval_text("End evaluation") - yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - - # Loading Bar - with gr.Row(): - # progress = gr.Progress() - variable = gr.Markdown() - - # NL -> MODEL -> Generated Query - with gr.Row(): - with gr.Column(): - with gr.Column(): - question_display = gr.Markdown() - with gr.Column(): - model_logo = gr.Image(visible=True, - show_label=False, - container=False, - interactive=False, - show_fullscreen_button=False, - show_download_button=False, - show_share_button=False) - with gr.Column(): - with gr.Column(): - prediction_display = gr.Markdown() - - dataframe_per_model = {} - - with gr.Tabs() as model_tabs: - tab_dict = {} - - for model, model_name in zip(model_list, model_names): - with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab: - gr.Markdown(f"**Results for {model}**") - tab_dict[model] = tab - dataframe_per_model[model] = gr.DataFrame() - #TODO download metrics per model - # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False) - - evaluation_loading = gr.Markdown() - - 
def change_tab(): - return [gr.update(visible=(model in input_data["models"])) for model in model_list] - - submit_models_button.click( - change_tab, - inputs=[], - outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility - ) - - selected_models_display = gr.JSON(label="Final input data", visible=False) - metrics_df = gr.DataFrame(visible=False) - metrics_df_out = gr.DataFrame(visible=False) - - submit_models_button.click( - fn=qatch_flow_nl_sql, - inputs=[], - outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values()) - ) - - submit_models_button.click( - fn=lambda: gr.update(value=input_data), - outputs=[selected_models_display] - ) - - # Works for METRICS - metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out]) - - proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False) - proceed_to_metrics_button.click( - fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)), - outputs=[qatch_acc, metrics_acc] - ) - - def allow_download(metrics_df_out): - #path = os.path.join(".", "data", "data_results", "results.csv") - path = os.path.join(".", "results.csv") - metrics_df_out.to_csv(path, index=False) - return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True) - - download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False) - - submit_models_button.click( - fn=lambda: gr.update(visible=False), - outputs=[download_metrics] - ) - - def refresh(): - global reset_flag - global flag_TQA - reset_flag = True - flag_TQA = False - - reset_data = gr.Button("Back to upload data section", interactive=True) - - metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data]) - - reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) - #WHY NOT WORKING? 
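# A plausible (unconfirmed) answer to the question above: reset_data carries
# several independent .click() handlers with no guaranteed ordering, and
# metrics_df_out.change re-runs allow_download, which makes download_metrics
# visible again and can overwrite the gr.update(visible=False) issued below.
# Chaining the handlers with .then() would enforce the intended order.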
- reset_data.click( - fn=lambda: gr.update(visible=False), - outputs=[download_metrics] - ) - reset_data.click(refresh) - - reset_data.click( - fn=enable_disable, - inputs=[gr.State(True)], - outputs=[ - *model_checkboxes, - submit_models_button, - preview_output, - submit_button, - file_input, - default_checkbox, - table_selector, - *table_outputs, - open_model_selection - ] - ) - - ########################################## - # METRICS VISUALIZATION SECTION # - ########################################## - with metrics_acc: - #data_path = 'test_results_metrics1.csv' - - @gr.render(inputs=metrics_df_out) - def function_metrics(metrics_df_out): - - #################################### - # UTILS FUNCTIONS SECTION # - #################################### - - def load_data_csv_es(): - - if input_data["input_method"]=="default": - global flag_TQA - #df = pd.read_csv(pnp_path) - df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH) - df = df[df['model'].isin(input_data["models"])] - df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])] - - df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B') - df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5') - df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini') - df['model'] = df['model'].replace('llama-70', 'Llama-70B') - df['model'] = df['model'].replace('llama-8', 'Llama-8B') - df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY') - #if (flag_TQA) : flag_TQA = False #TODO delete after make pred - return df - return metrics_df_out - - def calculate_average_metrics(df, selected_metrics): - # Exclude the 'tuple_order' column from the selected metrics - - #TODO tuple_order has NULL VALUE - selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order'] - #print(df[selected_metrics]) - df['avg_metric'] = df[selected_metrics].mean(axis=1) - return df - - def generate_model_colors(): - """Generates a unique color map for models in the dataset.""" - df = load_data_csv_es() - unique_models = df['model'].unique() # Extract unique models - num_models = len(unique_models) - - # Use the Plotly color scale (you can change it if needed) - color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC'] - #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...] 
- - # If there are more models than colors, cycle through them - colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)} - - return colors - - MODEL_COLORS = generate_model_colors() - - def generate_db_category_colors(): - """Assigns 3 distinct colors to db_category groups.""" - return { - "Spider": "#1f77b4", # blu - "Beaver": "#ff7f0e", # arancione - "Economic": "#2ca02c", # tutti gli altri verdi - "Financial": "#2ca02c", - "Medical": "#2ca02c", - "Miscellaneous": "#2ca02c" - } - - DB_CATEGORY_COLORS = generate_db_category_colors() - - def normalize_valid_efficency_score(df): - df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0) - df['valid_efficency_score'] = df['valid_efficency_score'].astype(int) - min_val = df['valid_efficency_score'].min() - max_val = df['valid_efficency_score'].max() - - if min_val == max_val : - # All values are equal, so for avoid division by zero, we set the score to 1/0 - if min_val == None: - df['valid_efficency_score'] = 0 - else: - df['valid_efficency_score'] = 1.0 - else: - df['valid_efficency_score'] = ( - df['valid_efficency_score'] - min_val - ) / (max_val - min_val) - - return df - - - #################################### - # GRAPH FUNCTIONS SECTION # - #################################### - - # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION - def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models): - df = df[df['model'].isin(selected_models)] - df = normalize_valid_efficency_score(df) - - # Mappatura nomi leggibili -> tecnici - qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] - external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] - - selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal - - df = calculate_average_metrics(df, selected_metrics) - - if group_by == ["model"]: - # Bar plot per "model" - avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index() - avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') - - fig = px.bar( - avg_metrics, - x="model", - y="avg_metric", - color="model", - color_discrete_map=MODEL_COLORS, - title='Average metrics per Model 🧠', - labels={"model": "Model", "avg_metric": "Average Metrics"}, - template='simple_white', - #template='plotly_dark', - text='text_label' - ) - else: - if group_by != ["tbl_name", "model"]: - group_by = ["tbl_name", "model"] - - avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index() - avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') - - fig = px.bar( - avg_metrics, - x=group_by[0], - y='avg_metric', - color='model', - color_discrete_map=MODEL_COLORS, - barmode='group', - title=f'Average metrics per {group_by[0]} πŸ“Š', - labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'}, - template='simple_white', - #template='plotly_dark', - text='text_label' - ) - - fig.update_traces(textposition='outside', textfont_size=10) - - # Applica font Inter a tutto il layout - fig.update_layout( - margin=dict(t=80), - title=dict( - font=dict( - family="Inter, sans-serif", - size=22, - #color="white" - ), - x=0.5 - ), - xaxis=dict( - title=dict( - font=dict( - family="Inter, sans-serif", - size=18, - #color="white" - ) - ), - tickfont=dict( - family="Inter, sans-serif", - #color="white" - size=16 - ) - ), - yaxis=dict( - title=dict( - font=dict( 
- family="Inter, sans-serif", - size=18, - #color="white" - ) - ), - tickfont=dict( - family="Inter, sans-serif", - #color="white" - ) - ), - legend=dict( - title=dict( - font=dict( - family="Inter, sans-serif", - size=16, - #color="white" - ) - ), - font=dict( - family="Inter, sans-serif", - #color="white" - ) - ) - ) - - return gr.Plot(fig, visible=True) - - def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models): - df = load_data_csv_es() - return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models) - - # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION - def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): - if selected_models == "All": - selected_models = models - else: - selected_models = [selected_models] - - df = df[df['model'].isin(selected_models)] - df = normalize_valid_efficency_score(df) - - # Converti nomi leggibili -> tecnici - qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] - external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] - - selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal - - df = calculate_average_metrics(df, selected_metrics) - - avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() - avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') - fig = px.bar( - avg_metrics, - x='db_category', - y='avg_metric', - color='model', - color_discrete_map=MODEL_COLORS, - barmode='group', - title='Average metrics per database types πŸ“Š', - labels={'db_path': 'DB Path', 'avg_metric': 'Average Metrics'}, - template='simple_white', - text='text_label' - ) - - fig.update_traces(textposition='outside', textfont_size=14) - - # Aggiorna layout con font Inter - fig.update_layout( - margin=dict(t=80), - title=dict( - font=dict( - family="Inter, sans-serif", - size=24, - color="black" - ), - x=0.5 - ), - xaxis=dict( - title=dict( - text='Database Category', - font=dict( - family='Inter, sans-serif', - size=22, - color='black' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - color='black', - size=20 - ) - ), - yaxis=dict( - title=dict( - text='Average Metrics', - font=dict( - family='Inter, sans-serif', - size=22, - color='black' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - color='black' - ) - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=20, - color='black' - ) - ), - font=dict( - family='Inter, sans-serif', - color='black', - size=18 - ) - ) - ) - - return gr.Plot(fig, visible=True) - - def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): - df = load_data_csv_es() - return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models) - - # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION - - def lollipop_propietary(selected_models): - df = load_data_csv_es() - - # Filtra solo le categorie rilevanti - target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"] - df = df[df['db_category'].isin(target_cats)] - df = df[df['model'].isin(selected_models)] - - df = normalize_valid_efficency_score(df) - df = calculate_average_metrics(df, qatch_metrics) - - # Calcola la media per 
db_category e modello - avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() - - # Separa Spider e le altre 4 categorie - spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"] - other_df = avg_metrics[avg_metrics["db_category"] != "Spider"] - - # Calcola media delle altre categorie per ciascun modello - other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index() - other_mean_df["db_category"] = "Others" - - # Rinominare per chiarezza e uniformitΓ  - spider_df = spider_df.rename(columns={"avg_metric": "Spider"}) - other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"}) - - # Unione dei due dataset - merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model") - - # Ordina per modello o per valore se vuoi - merged_df = merged_df.sort_values(by="model") - - fig = go.Figure() - - # Aggiungi linee orizzontali tra Spider e Others - for _, row in merged_df.iterrows(): - fig.add_trace(go.Scatter( - x=[row["Spider"], row["Others"]], - y=[row["model"]] * 2, - mode='lines', - line=dict(color='gray', width=2), - showlegend=False - )) - - # Punto per Spider - fig.add_trace(go.Scatter( - x=merged_df["Spider"], - y=merged_df["model"], - mode='markers', - name='Non-Proprietary (Spider)', - marker=dict(size=10, color='#C84630') - )) - - # Punto per Others (media delle altre 4 categorie) - fig.add_trace(go.Scatter( - x=merged_df["Others"], - y=merged_df["model"], - mode='markers', - name='Proprietary Databases', - marker=dict(size=10, color='#0077B6') - )) - - fig.update_layout( - xaxis_title='Average Metrics', - yaxis_title='Models', - template='simple_white', - #template='plotly_dark', - margin=dict(t=80), - title=dict( - font=dict( - family="Inter, sans-serif", - size=22, - color="black" - ), - x=0.5, - text='Dumbbell graph: Non-Proprietary (Spider πŸ•·οΈ) vs Proprietary Databases πŸ“Š' - ), - legend_title='Type of Databases:', - height=600, - xaxis=dict( - title=dict( - text='DB Category', - font=dict( - family='Inter, sans-serif', - size=18, - color='black' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - color='black' - ) - ), - yaxis=dict( - title=dict( - text='Average Metrics', - font=dict( - family='Inter, sans-serif', - size=18, - color='black' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - color='black' - ) - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=18, - color='black' - ) - ), - font=dict( - family='Inter, sans-serif', - color='black', - size=14 - ) - ) - ) - - return gr.Plot(fig, visible=True) - - - # RADAR OR BAR CHART BASED ON CATEGORY COUNT - def plot_radar(df, selected_models, selected_metrics, selected_categories): - if "External" in selected_metrics: - selected_metrics = ["execution_accuracy", "valid_efficency_score"] - else: - selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] - - # Filtro modelli e normalizzazione - df = df[df['model'].isin(selected_models)] - df = normalize_valid_efficency_score(df) - df = calculate_average_metrics(df, selected_metrics) - - avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index() - - if avg_metrics.empty: - print("Error: No data available to compute averages.") - return go.Figure() - - categories = selected_categories - - if len(categories) < 3: - # πŸ”„ BAR PLOT - fig = go.Figure() - for model in selected_models: - model_data = avg_metrics[avg_metrics['model'] == model] - 
values = [ - model_data[model_data['test_category'] == cat]['avg_metric'].values[0] - if cat in model_data['test_category'].values else 0 - for cat in categories - ] - fig.add_trace(go.Bar( - x=categories, - y=values, - name=model, - marker=dict(color=MODEL_COLORS.get(model, "gray")) - )) - - fig.update_layout( - barmode='group', - title=dict( - text='πŸ“Š Bar Plot of Metrics per Model (Few Categories)', - font=dict( - family='Inter, sans-serif', - size=22, - #color='white' - ), - x=0.5 - ), - template='simple_white', - #template='plotly_dark', - xaxis=dict( - title=dict( - text='Test Category', - font=dict( - family='Inter, sans-serif', - size=18, - #color='white' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - size=16 - #color='white' - ) - ), - yaxis=dict( - title=dict( - text='Average Metrics', - font=dict( - family='Inter, sans-serif', - size=18, - #color='white' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - #color='white' - ) - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=16, - #color='white' - ) - ), - font=dict( - family='Inter, sans-serif', - #color='white' - ) - ) - ) - else: - # 🧭 RADAR PLOT - fig = go.Figure() - for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): - model_data = avg_metrics[avg_metrics['model'] == model] - values = [ - model_data[model_data['test_category'] == cat]['avg_metric'].values[0] - if cat in model_data['test_category'].values else 0 - for cat in categories - ] - fig.add_trace(go.Scatterpolar( - r=values, - theta=categories, - fill='toself', - name=model, - line=dict(color=MODEL_COLORS.get(model, "gray")) - )) - - fig.update_layout( - polar=dict( - radialaxis=dict( - visible=True, - range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], - tickfont=dict( - family='Inter, sans-serif', - #color='white' - ) - ), - angularaxis=dict( - tickfont=dict( - family='Inter, sans-serif', - size=16 - #color='white' - ) - ) - ), - title=dict( - text='❇️ Radar Plot of Metrics per Model (Average per SQL Category)', - font=dict( - family='Inter, sans-serif', - size=22, - #color='white' - ), - x=0.5 - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=18, - #color='white' - ) - ), - font=dict( - family='Inter, sans-serif', - size=16 - #color='white' - ) - ), - template='simple_white' - #template='plotly_dark' - ) - - return fig - - def update_radar(selected_models, selected_metrics, selected_categories): - df = load_data_csv_es() - return plot_radar(df, selected_models, selected_metrics, selected_categories) - - # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT - def plot_radar_sub(df, selected_models, selected_metrics, selected_category): - if "External" in selected_metrics: - selected_metrics = ["execution_accuracy", "valid_efficency_score"] - else: - selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] - - df = df[df['model'].isin(selected_models)] - df = normalize_valid_efficency_score(df) - df = calculate_average_metrics(df, selected_metrics) - - if isinstance(selected_category, str): - selected_category = [selected_category] - - df = df[df['test_category'].isin(selected_category)] - avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index() - - if avg_metrics.empty: - print("Error: No data available to compute averages.") - return go.Figure() - - categories = 
df['sql_tag'].unique().tolist() - - if len(categories) < 3: - # πŸ”„ BAR PLOT - fig = go.Figure() - for model in selected_models: - model_data = avg_metrics[avg_metrics['model'] == model] - values = [ - model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] - if cat in model_data['sql_tag'].values else 0 - for cat in categories - ] - fig.add_trace(go.Bar( - x=categories, - y=values, - name=model, - marker=dict(color=MODEL_COLORS.get(model, "gray")) - )) - - fig.update_layout( - barmode='group', - title=dict( - text='πŸ“Š Bar Plot of Metrics per Model (Few Sub-Categories)', - font=dict( - family='Inter, sans-serif', - size=22, - #color='white' - ), - x=0.5 - ), - template='simple_white', - #template='plotly_dark', - xaxis=dict( - title=dict( - text='SQL Tag (Sub Category)', - font=dict( - family='Inter, sans-serif', - size=18, - #color='white' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - #color='white' - ) - ), - yaxis=dict( - title=dict( - text='Average Metrics', - font=dict( - family='Inter, sans-serif', - size=18, - #color='white' - ) - ), - tickfont=dict( - family='Inter, sans-serif', - #color='white' - ) - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=16, - #color='white' - ) - ), - font=dict( - family='Inter, sans-serif', - size=14 - #color='white' - ) - ) - ) - else: - # 🧭 RADAR PLOT - fig = go.Figure() - - for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): - model_data = avg_metrics[avg_metrics['model'] == model] - values = [ - model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] - if cat in model_data['sql_tag'].values else 0 - for cat in categories - ] - - fig.add_trace(go.Scatterpolar( - r=values, - theta=categories, - fill='toself', - name=model, - line=dict(color=MODEL_COLORS.get(model, "gray")) - )) - - fig.update_layout( - polar=dict( - radialaxis=dict( - visible=True, - range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], - tickfont=dict( - family='Inter, sans-serif', - #color='white' - ) - ), - angularaxis=dict( - tickfont=dict( - family='Inter, sans-serif', - size=16 - #color='white' - ) - ) - ), - title=dict( - text='❇️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)', - font=dict( - family='Inter, sans-serif', - size=22, - #color='white' - ), - x=0.5 - ), - legend=dict( - title=dict( - text='Models', - font=dict( - family='Inter, sans-serif', - size=16, - #color='white' - ) - ), - font=dict( - family='Inter, sans-serif', - size=14, - #color='white' - ) - ), - template='simple_white' - #template='plotly_dark' - ) - - return fig - - def update_radar_sub(selected_models, selected_metrics, selected_category): - df = load_data_csv_es() - return plot_radar_sub(df, selected_models, selected_metrics, selected_category) - - # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION - def worst_cases_text(df, selected_models, selected_metrics, selected_categories): - global flag_TQA - if selected_models == "All": - selected_models = models - else: - selected_models = [selected_models] - - if selected_categories == "All": - selected_categories = principal_categories - else: - selected_categories = [selected_categories] - - df = df[df['model'].isin(selected_models)] - df = df[df['test_category'].isin(selected_categories)] - - if "external" in selected_metrics: - selected_metrics = ["execution_accuracy", "valid_efficency_score"] - else: - selected_metrics = ["cell_precision", "cell_recall", "tuple_order", 
"tuple_cardinality", "tuple_constraint"] - - df = normalize_valid_efficency_score(df) - df = calculate_average_metrics(df, selected_metrics) - - if flag_TQA: - df["target_answer"] = df["target_answer"] = df["target_answer"].apply(lambda x: "[" + ", ".join(map(str, x)) + "]") - - worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() - else: - worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() - - worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True) - - worst_cases_top_3 = worst_cases_df.head(3) - - worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2) - - worst_str = [] - answer_str = [] - - medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"] - - for i, row in worst_cases_top_3.iterrows(): - if flag_TQA: - entry = ( - f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" - f"- Question: {row['question']} \n" - f"- Original Answer: `{row['target_answer']}` \n" - f"- Predicted Answer: `{eval(row['predicted_answer'])}` \n\n" - ) - - worst_str.append(entry) - else: - entry = ( - f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" - f"- Question: {row['question']} \n" - f"- Original Query: `{row['query']}` \n" - f"- Predicted SQL: `{row['predicted_sql']}` \n\n" - ) - - worst_str.append(entry) - - raw_answer = ( - f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" - f"- Raw Answer:
`{row['answer']}`
\n" - ) - - answer_str.append(raw_answer) - - return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2] - - def update_worst_cases_text(selected_models, selected_metrics, selected_categories): - df = load_data_csv_es() - return worst_cases_text(df, selected_models, selected_metrics, selected_categories) - - # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION - def plot_cumulative_flow(df, selected_models, max_points): - df = df[df['model'].isin(selected_models)] - df = normalize_valid_efficency_score(df) - - fig = go.Figure() - - for model in selected_models: - model_df = df[df['model'] == model].copy() - - # Limita il numero di punti se richiesto - if max_points is not None: - model_df = model_df.head(max_points + 1) - - # Tooltip personalizzato - model_df['hover_info'] = model_df.apply( - lambda row: - f"Id question: {row['number_question']}
" - f"Question: {row['question']}
" - f"Target: {row['query']}
" - f"Prediction: {row['predicted_sql']}
" - f"Category: {row['test_category']}", - axis=1 - ) - - # Calcoli cumulativi - model_df['cumulative_time'] = model_df['time'].cumsum() - model_df['cumulative_price'] = model_df['price'].cumsum() - - # Colore del modello - color = MODEL_COLORS.get(model, "gray") - - fig.add_trace(go.Scatter( - x=model_df['cumulative_time'], - y=model_df['cumulative_price'], - mode='lines+markers', - name=model, - line=dict(width=2, color=color), - customdata=model_df['hover_info'], - hovertemplate= - "Model: " + model + "
" + - "Cumulative Time: %{x}s
" + - "Cumulative Price: $%{y:.2f}
" + - "
Details:
%{customdata}" - )) - - # Layout con font elegante - fig.update_layout( - title=dict( - text="Cumulative Price Flow Chart πŸ’°", - font=dict( - family="Inter, sans-serif", - size=24, - #color="white" - ), - x=0.5 - ), - xaxis=dict( - title=dict( - text="Cumulative Time (s)", - font=dict( - family="Inter, sans-serif", - size=20, - #color="white" - ) - ), - tickfont=dict( - family="Inter, sans-serif", - size=18 - #color="white" - ) - ), - yaxis=dict( - title=dict( - text="Cumulative Price ($)", - font=dict( - family="Inter, sans-serif", - size=20, - #color="white" - ) - ), - tickfont=dict( - family="Inter, sans-serif", - size=18 - #color="white" - ) - ), - legend=dict( - title=dict( - text="Models", - font=dict( - family="Inter, sans-serif", - size=18, - #color="white" - ) - ), - font=dict( - family="Inter, sans-serif", - size=16, - #color="white" - ) - ), - template='simple_white', - #template="plotly_dark" - ) - - return fig - - def update_query_rate(selected_models, max_points): - df = load_data_csv_es() - return plot_cumulative_flow(df, selected_models, max_points) - - - - - ####################### - # PARAMETER SECTION # - ####################### - qatch_metrics_dict = { - "Cell Precision": "cell_precision", - "Cell Recall": "cell_recall", - "Tuple Order": "tuple_order", - "Tuple Cardinality": "tuple_cardinality", - "Tuple Constraint": "tuple_constraint" - } - - qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] - last_valid_qatch_metrics_selection = qatch_metrics.copy() # Per salvare l’ultima selezione valida - def enforce_qatch_metrics_selection(selected): - global last_valid_qatch_metrics_selection - if not selected: # Se nessuna metrica Γ¨ selezionata - return gr.update(value=last_valid_qatch_metrics_selection) - last_valid_qatch_metrics_selection = selected # Altrimenti aggiorna la selezione valida - return gr.update(value=selected) - - external_metrics_dict = { - "Execution Accuracy": "execution_accuracy", - "Valid Efficency Score": "valid_efficency_score" - } - - external_metric = ["execution_accuracy", "valid_efficency_score"] - last_valid_external_metric_selection = external_metric.copy() - def enforce_external_metric_selection(selected): - global last_valid_external_metric_selection - if not selected: # Se nessuna metrica Γ¨ selezionata - return gr.update(value=last_valid_external_metric_selection) - last_valid_external_metric_selection = selected # Altrimenti aggiorna la selezione valida - return gr.update(value=selected) - - all_metrics = { - "Qatch": ["qatch"], - "External": ["external"] - } - - group_options = { - "Table": ["tbl_name", "model"], - "Model": ["model"] - } - - df_initial = load_data_csv_es() - models = models = df_initial['model'].unique().tolist() - last_valid_model_selection = models.copy() # Per salvare l’ultima selezione valida - def enforce_model_selection(selected): - global last_valid_model_selection - if not selected: # Se nessuna metrica Γ¨ selezionata - return gr.update(value=last_valid_model_selection) - last_valid_model_selection = selected # Altrimenti aggiorna la selezione valida - return gr.update(value=selected) - - all_categories = df_initial['sql_tag'].unique().tolist() - - principal_categories = df_initial['test_category'].unique().tolist() - last_valid_category_selection = principal_categories.copy() # Per salvare l’ultima selezione valida - def enforce_category_selection(selected): - global last_valid_category_selection - if not selected: # Se nessuna metrica Γ¨ selezionata - return 
gr.update(value=last_valid_category_selection) - last_valid_category_selection = selected # Altrimenti aggiorna la selezione valida - return gr.update(value=selected) - - all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories} - - all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories} - all_categories_as_dic_ranking["All"] = principal_categories - - all_model_as_dic = {cat: [f"{cat}"] for cat in models} - all_model_as_dic["All"] = models - - ########################### - # VISUALIZATION SECTION # - ########################### - gr.Markdown("""# Model Performance Analysis""") - - #FOR BAR - gr.Markdown("""## Section 1: Model - Data""") - - with gr.Row(): - with gr.Column(scale=1): - with gr.Row(): - choose_metrics_bar = gr.Radio( - choices=list(all_metrics.keys()), - label="Select the metrics group that you want to use:", - value="Qatch" - ) - - with gr.Row(): - qatch_info = gr.HTML(""" -
- Qatch metric info ℹ️ -
- """, visible=True) - - external_info = gr.HTML(""" -
- External metric info ℹ️ -
- """, visible=False) - - qatch_metric_multiselect_bar = gr.CheckboxGroup( - choices=list(qatch_metrics_dict.keys()), - label="Select one or mode Qatch metrics:", - value=list(qatch_metrics_dict.keys()), - visible=True - ) - - external_metric_select_bar = gr.CheckboxGroup( - choices=list(external_metrics_dict.keys()), - label="Select one or more External metrics:", - visible=False - ) - - if(input_data['input_method'] == 'default'): - model_radio_bar = gr.Radio( - choices=list(all_model_as_dic.keys()), - label="Select the model that you want to use:", - value="All" - ) - else: - model_multiselect_bar = gr.CheckboxGroup( - choices=models, - label="Select one or more models:", - value=models, - interactive=len(models) > 1 - ) - - group_radio = gr.Radio( - choices=list(group_options.keys()), - label="Select the grouping view:", - value="Table" - ) - - def toggle_metric_selector(selected_type): - if selected_type == "Qatch": - return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[]) - else: - return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys())) - - output_plot = gr.Plot(visible=False) - - if(input_data['input_method'] == 'default'): - with gr.Row(): - lollipop_propietary(models) - - #FOR RADAR - gr.Markdown("""## Section 2: Model - Category""") - with gr.Row(): - all_metrics_radar = gr.Radio( - choices=list(all_metrics.keys()), - label="Select the metrics group that you want to use:", - value="Qatch" - ) - - model_multiselect_radar = gr.CheckboxGroup( - choices=models, - label="Select one or more models:", - value=models, - interactive=len(models) > 1 - ) - - with gr.Row(): - with gr.Column(scale=1): - category_multiselect_radar = gr.CheckboxGroup( - choices=principal_categories, - label="Select one or more categories:", - value=principal_categories - ) - with gr.Column(scale=1): - category_radio_radar = gr.Radio( - choices=list(all_categories_as_dic.keys()), - label="Select the metrics that you want to use:", - value=list(all_categories_as_dic.keys())[0] - ) - - with gr.Row(): - with gr.Column(scale=1): - radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories)) - - with gr.Column(scale=1): - radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0])) - - #FOR RANKING - with gr.Row(): - all_metrics_ranking = gr.Radio( - choices=list(all_metrics.keys()), - label="Select the metrics group that you want to use:", - value="Qatch" - ) - model_choices = list(all_model_as_dic.keys()) - - if len(model_choices) == 2: - model_choices = [model_choices[0]] # supponiamo che il modello sia in prima posizione - selected_value = model_choices[0] - else: - selected_value = "All" - - model_radio_ranking = gr.Radio( - choices=model_choices, - label="Select the model that you want to use:", - value=selected_value - ) - - category_radio_ranking = gr.Radio( - choices=list(all_categories_as_dic_ranking.keys()), - label="Select the category that you want to use", - value="All" - ) - - with gr.Row(): - with gr.Column(scale=1): - gr.Markdown("## ❌ 3 Worst Cases\n") - - worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All") - - with gr.Row(): - first = gr.Markdown(worst_first) - - with gr.Row(): - first_button = gr.Button("Show raw answer for πŸ₯‡") - - with gr.Row(): - 
second = gr.Markdown(worst_second) - - with gr.Row(): - second_button = gr.Button("Show raw answer for πŸ₯ˆ") - - with gr.Row(): - third = gr.Markdown(worst_third) - - with gr.Row(): - third_button = gr.Button("Show raw answer for πŸ₯‰") - - with gr.Column(scale=1): - gr.Markdown("""## Raw Answer""") - row_answer_first = gr.Markdown(value=raw_first, visible=True) - row_answer_second = gr.Markdown(value=raw_second, visible=False) - row_answer_third = gr.Markdown(value=raw_third, visible=False) - - #FOR RATE - gr.Markdown("""## Section 3: Time - Price""") - with gr.Row(): - model_multiselect_rate = gr.CheckboxGroup( - choices=models, - label="Select one or more models:", - value=models, - interactive=len(models) > 1 - ) - - - with gr.Row(): - slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider") - - query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique()))) - - - #FOR RESET - reset_data = gr.Button("Back to upload data section") - - - - - ############################### - # CALLBACK FUNCTION SECTION # - ############################### - - #FOR BAR - def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models): - return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models) - - def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models): - return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models) - - #FOR RADAR - def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories): - return update_radar(selected_models, selected_metrics, selected_categories) - - def on_radar_radio_change(selected_models, selected_metrics, selected_category): - return update_radar_sub(selected_models, selected_metrics, selected_category) - - #FOR RANKING - def on_ranking_change(selected_models, selected_metrics, selected_categories): - return update_worst_cases_text(selected_models, selected_metrics, selected_categories) - - def show_first(): - return ( - gr.update(visible=True), - gr.update(visible=False), - gr.update(visible=False) - ) - - def show_second(): - return ( - gr.update(visible=False), - gr.update(visible=True), - gr.update(visible=False) - ) - - def show_third(): - return ( - gr.update(visible=False), - gr.update(visible=False), - gr.update(visible=True) - ) - - - - - ###################### - # ON CLICK SECTION # - ###################### - - #FOR BAR - if(input_data['input_method'] == 'default'): - proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) - qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) - external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) - model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) - 
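# Note: Gradio allows several listeners on one component event; both the re-plot
# callbacks above and the selection guards bound below fire on the same .change
# events, so a plot refresh and the "keep a valid selection" rollback can both run
# for a single user interaction.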
qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) - choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) - external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) - - else: - proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) - qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) - external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) - group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) - model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) - qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) - model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar) - choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) - external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) - - - #FOR RADAR MULTISELECT - model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) - all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) - category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) - model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar) - category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar) - - #FOR RADAR RADIO - model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) - all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) - category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) - - #FOR RANKING - model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) - 
model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) - all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) - all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) - category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) - category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) - model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking) - category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking) - first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) - second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third]) - third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third]) - - #FOR RATE - model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) - proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) - model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate) - slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) - - #FOR RESET - reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) - reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics]) - reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection]) - - +import os +import sys +import time +import re +import csv +import gradio as gr +import pandas as pd +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +import plotly.colors as pc +from qatch.connectors.sqlite_connector import SqliteConnector +from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator +from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator +import qatch.evaluate_dataset.orchestrator_evaluator as eva +from prediction import ModelPrediction +import utils_get_db_tables_info +import utilities as us +# @spaces.GPU +# def model_prediction(): +# pass +# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10 +# if os.environ.get("SPACES_ZERO_GPU") is not None: +# import spaces +# else: +# class spaces: +# @staticmethod +# def GPU(func): +# def wrapper(*args, **kwargs): +# return func(*args, **kwargs) +# return wrapper +#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv") +pnp_path = "concatenated_output.csv" +PATH_PKL_TABLES = 'tables_dict_beaver.pkl' +PNP_TQA_PATH = 'concatenated_output_tqa.csv' +js_func = """ +function refresh() { + const url = new URL(window.location); + + if 
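# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the "non-empty selection" guard wired up in
# the metrics section above, assuming only that gradio is installed. All names
# here are illustrative rather than part of the app: the idea is to remember the
# last valid value of a CheckboxGroup and roll back whenever the user unchecks
# every option.
import gradio as gr

OPTIONS = ["cell_precision", "cell_recall", "tuple_cardinality"]
last_valid = OPTIONS.copy()

def keep_non_empty(selected):
    global last_valid
    if not selected:          # everything unchecked: restore the last valid value
        return gr.update(value=last_valid)
    last_valid = selected     # otherwise record the new valid selection
    return gr.update(value=selected)

with gr.Blocks() as guard_demo:
    picker = gr.CheckboxGroup(choices=OPTIONS, value=OPTIONS, label="Metrics")
    picker.change(fn=keep_non_empty, inputs=picker, outputs=picker)
# guard_demo.launch()
# ---------------------------------------------------------------------------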
(url.searchParams.get('__theme') !== 'light') { + url.searchParams.set('__theme', 'light'); + window.location.href = url.href; + } +} +""" +reset_flag = False +flag_TQA = False + +with open('style.css', 'r') as file: + css = file.read() + +# DataFrame di default +df_default = pd.DataFrame({ + 'Name': ['Alice', 'Bob', 'Charlie'], + 'Age': [25, 30, 35], + 'City': ['New York', 'Los Angeles', 'Chicago'] +}) +models_path ="models.csv" + +# Variabile globale per tenere traccia dei dati correnti +df_current = df_default.copy() + +description = """## πŸ“Š Comparison of Proprietary and Non-Proprietary Databases + ### ➀ **Proprietary** : + ###     β‡’ Economic πŸ’°, Medical πŸ₯, Financial πŸ’³, Miscellaneous πŸ“‚ + ###     β‡’ BEAVER (FAC BUILDING ADDRESS 🏒 , TIME QUARTER ⏱️) + ### ➀ **Non-Proprietary** + ###     β‡’ Spider 1.0 πŸ•·οΈ""" +prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n" +prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as .\n Question \n {question}\n Database Schema\n {db_schema}\n" + + +input_data = { + 'input_method': "", + 'data_path': "", + 'db_name': "", + 'data': { + 'data_frames': {}, # dictionary of dataframes + 'db': None, # SQLITE3 database object + 'selected_tables' :[] + }, + 'models': [], + 'prompt': prompt_default +} + +def load_data(file, path, use_default): + """Carica i dati da un file, un percorso o usa il DataFrame di default.""" + global df_current + if file is not None: + try: + input_data["input_method"] = 'uploaded_file' + input_data["db_name"] = os.path.splitext(os.path.basename(file))[0] + if file.endswith('.sqlite'): + #return 'Error: The uploaded file is not a valid SQLite database.' 
+ input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite") + else: + #change path + input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite") + input_data["data"] = us.load_data(file, input_data["db_name"]) + + df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame + if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files + table2primary_key = {} + for table_name, df in input_data["data"]['data_frames'].items(): + # Assign primary keys for each table + table2primary_key[table_name] = 'id' + input_data["data"]["db"] = SqliteConnector( + relative_db_path=input_data["data_path"], + db_name=input_data["db_name"], + tables= input_data["data"]['data_frames'], + table2primary_key=table2primary_key + ) + return input_data["data"]['data_frames'] + except Exception as e: + return f'Errore nel caricamento del file: {e}' + if use_default: + if(use_default == 'Custom'): + input_data["input_method"] = 'custom' + #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite") + input_data["data_path"] = os.path.join(".","mytable_0.sqlite") + #if file already exist + while os.path.exists(input_data["data_path"]): + input_data["data_path"] = us.increment_filename(input_data["data_path"]) + input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0] + input_data["data"]['data_frames'] = {'MyTable': df_current} + + if(input_data["data"]['data_frames']): + table2primary_key = {} + for table_name, df in input_data["data"]['data_frames'].items(): + # Assign primary keys for each table + table2primary_key[table_name] = 'id' + input_data["data"]["db"] = SqliteConnector( + relative_db_path=input_data["data_path"], + db_name=input_data["db_name"], + tables= input_data["data"]['data_frames'], + table2primary_key=table2primary_key + ) + df_current = df_default.copy() # Ripristina i dati di default + return input_data["data"]['data_frames'] + + if(use_default == 'Proprietary vs Non-proprietary'): + input_data["input_method"] = 'default' + #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite") + #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite") + #input_data["db_name"] = "default" + #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"]) + input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES) + return input_data["data"]['data_frames'] + + selected_inputs = sum([file is not None, bool(path), use_default]) + if selected_inputs > 1: + return 'Error: Select only one input method at a time.' 
+ + return input_data["data"]['data_frames'] + +def preview_default(use_default, file): + if file: + return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## βœ… File successfully uploaded!", visible=True) + else : + if use_default == 'Custom': + return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## πŸ“ Toy Table", visible=True) + else: + return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True) + #return gr.DataFrame(interactive=True, value = df_current) # Mostra il DataFrame corrente, che potrebbe essere stato modificato + +def update_df(new_df): + """Aggiorna il DataFrame corrente.""" + global df_current # Usa la variabile globale per aggiornarla + df_current = new_df + return df_current + +def open_accordion(target): + # Apre uno e chiude l'altro + if target == "reset": + df_current = df_default.copy() + input_data['input_method'] = "" + input_data['data_path'] = "" + input_data['db_name'] = "" + input_data['data']['data_frames'] = {} + input_data['data']['selected_tables'] = [] + input_data['data']['db'] = None + input_data['models'] = [] + return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None) + elif target == "model_selection": + return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False) + +# Interfaccia Gradio +#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: +with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface: + with gr.Row(): + with gr.Column(scale=1): + gr.Image( + value=os.path.join(".", "qatch_logo.png"), + show_label=False, + container=False, + interactive=False, + show_fullscreen_button=False, + show_download_button=False, + show_share_button=False, + height=150, # in pixel + width=300 + ) + with gr.Column(scale=1): + pass + data_state = gr.State(None) # Memorizza i dati caricati + upload_acc = gr.Accordion("Upload data section", open=True, visible=True) + select_table_acc = gr.Accordion("Select tables section", open=False, visible=False) + select_model_acc = gr.Accordion("Select models section", open=False, visible=False) + qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False) + metrics_acc = gr.Accordion("Metrics section", open=False, visible=False) + + ################################# + # DATABASE INSERTION # + ################################# + with upload_acc: + gr.Markdown("## πŸ“₯Choose data input method") + with gr.Row(): + default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary') + #default_checkbox = gr.Checkbox(label="Use default DataFrame" + + table_default = gr.Markdown(description, visible=True) + preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default) + + gr.Markdown("## πŸ“‚ Or upload your data") + file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"]) + submit_button = gr.Button("Load Data") # Disabled by default + output = gr.JSON(visible=False) # Dictionary output + + # Function to enable the 
button if there is data to load + def enable_submit(file, use_default): + return gr.update(interactive=bool(file or use_default)) + + # Function to uncheck the checkbox if a file is uploaded + def deselect_default(file): + if file: + return gr.update(value='Proprietary vs Non-proprietary') + return gr.update() + + def enable_disable_first(enable): + return ( + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable) + ) + + # Enable the button when inputs are provided + #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) + #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) + + # Show preview of the default DataFrame when checkbox is selected + default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) + file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) + preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output]) + + # Uncheck the checkbox when a file is uploaded + file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox]) + + def handle_output(file, use_default): + """Handles the output when the 'Load Data' button is pressed.""" + result = load_data(file, None, use_default) + + if isinstance(result, dict): # If result is a dictionary of DataFrames + if len(result) == 1: # If there's only one table + input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys()) + return ( + gr.update(visible=False), # Hide JSON output + result, # Save the data state + gr.update(visible=False), # Hide table selection + result, # Maintain the data state + gr.update(interactive=False), # Disable the submit button + gr.update(visible=True, open=True), # Proceed to select_model_acc + gr.update(visible=True, open=False) + ) + else: + return ( + gr.update(visible=False), + result, + gr.update(open=True, visible=True), + result, + gr.update(interactive=False), + gr.update(visible=False), # Keep current behavior + gr.update(visible=True, open=False) + ) + else: + return ( + gr.update(visible=False), + None, + gr.update(open=False, visible=True), + None, + gr.update(interactive=True), + gr.update(visible=False), + gr.update(visible=True, open=False) + ) + + submit_button.click( + fn=handle_output, + inputs=[file_input, default_checkbox], + outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc] + ) + + submit_button.click( + fn=enable_disable_first, + inputs=[gr.State(False)], + outputs=[ + preview_output, + submit_button, + file_input, + default_checkbox + ] + ) + + ###################################### + # TABLE SELECTION PART # + ###################################### + with select_table_acc: + previous_selection = gr.State([]) + table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[]) + excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False) + table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)] + selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False) + + # Model selection button (initially disabled) + open_model_selection = gr.Button("Choose your models", interactive=False) + def update_table_list(data): + 
"""Dynamically updates the list of available tables and excluded ones.""" + if isinstance(data, dict) and data: + table_names = [] + excluded_tables = [] + + data_frames = input_data['data'].get('data_frames', {}) + + available_tables = [] + for name, df in data.items(): + df_real = data_frames.get(name, None) + if input_data['input_method'] != "default": + if df_real is not None and df_real.shape[1] > 15: + excluded_tables.append(name) + else: + available_tables.append(name) + else: + available_tables.append(name) + + + if input_data['input_method'] == "default": + table_names.append("All") + excluded_tables = [] + elif len(available_tables) < 6: + table_names.append("All") + + table_names.extend(available_tables) + if excluded_tables and input_data['input_method'] != "default" : + excluded_text = "⚠️ The following tables have more than 15 columns and cannot be selected:
" + "
".join(f"- {t}" for t in excluded_tables) + excluded_visible = True + else: + excluded_text = "" + excluded_visible = False + + return [ + gr.update(choices=table_names, value=[]), # CheckboxGroup update + gr.update(value=excluded_text, visible=excluded_visible) # HTML display update + ] + + return [ + gr.update(choices=[], value=[]), + gr.update(value="", visible=False) + ] + + def show_selected_tables(data, selected_tables): + updates = [] + data_frames = input_data['data'].get('data_frames', {}) + + available_tables = [] + for name, df in data.items(): + df_real = data_frames.get(name) + if input_data['input_method'] != "default" : + if df_real is not None and df_real.shape[1] <= 15: + available_tables.append(name) + else: + available_tables.append(name) + + input_method = input_data['input_method'] + allow_all = input_method == "default" or len(available_tables) < 6 + + selected_set = set(selected_tables) + tables_set = set(available_tables) + + if allow_all: + if "All" in selected_set: + selected_tables = ["All"] + available_tables + elif selected_set == tables_set: + selected_tables = [] + else: + selected_tables = [t for t in selected_tables if t in available_tables] + else: + selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5] + + tables = {name: data[name] for name in selected_tables if name in data} + + for i, (name, df) in enumerate(tables.items()): + updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False)) + + for _ in range(len(tables), 50): + updates.append(gr.update(visible=False)) + + updates.append(gr.update(interactive=bool(tables))) + + if allow_all: + updates.insert(0, gr.update( + choices=["All"] + available_tables, + value=selected_tables + )) + else: + if len(selected_tables) >= 5: + updates.insert(0, gr.update( + choices=selected_tables, + value=selected_tables + )) + else: + updates.insert(0, gr.update( + choices=available_tables, + value=selected_tables + )) + + return updates + + def show_selected_table_names(data, selected_tables): + """Displays the names of the selected tables when the button is pressed.""" + if selected_tables: + available_tables = list(data.keys()) # Actually available names + if "All" in selected_tables: + selected_tables = available_tables + if (input_data['input_method'] != "default") : selected_tables = [t for t in selected_tables if len(data[t].columns) <= 15] + + input_data['data']['selected_tables'] = selected_tables + return gr.update(value=", ".join(selected_tables), visible=False) + return gr.update(value="", visible=False) + + # Automatically updates the checkbox list when `data_state` changes + data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info]) + + # Updates the visible tables and the button state based on user selections + #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection]) + table_selector.change( + fn=show_selected_tables, + inputs=[data_state, table_selector], + outputs=[table_selector] + table_outputs + [open_model_selection] + ) + # Shows the list of selected tables when "Choose your models" is clicked + open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names]) + open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc]) + + reset_data = 
gr.Button("Back to upload data section") + + reset_data.click( + fn=enable_disable_first, + inputs=[gr.State(True)], + outputs=[ + preview_output, + submit_button, + file_input, + default_checkbox + ] + ) + reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) + + #################################### + # MODEL SELECTION PART # + #################################### + with select_model_acc: + gr.Markdown("# Model Selection") + + # Assume that `us.read_models_csv` also returns the image path + model_list_dict = us.read_models_csv(models_path) + model_list = [model["code"] for model in model_list_dict] + model_images = [model["image_path"] for model in model_list_dict] + model_names = [model["name"] for model in model_list_dict] + # Create a mapping between model_list and model_images_names + model_mapping = dict(zip(model_list, model_names)) + model_mapping_reverse = dict(zip(model_names, model_list)) + + model_checkboxes = [] + rows = [] + + # Dynamically create checkboxes with images (3 per row) + for i in range(0, len(model_list), 3): + with gr.Row(): + cols = [] + for j in range(3): + if i + j < len(model_list): + model = model_list[i + j] + image_path = model_images[i + j] + with gr.Column(): + gr.Image(image_path, + show_label=False, + container=False, + interactive=False, + show_fullscreen_button=False, + show_download_button=False, + show_share_button=False) + checkbox = gr.Checkbox(label=model_mapping[model], value=False) + model_checkboxes.append(checkbox) + cols.append(checkbox) + rows.append(cols) + + selected_models_output = gr.JSON(visible=False) + + # Function to get selected models + def get_selected_models(*model_selections): + selected_models = [model for model, selected in zip(model_list, model_selections) if selected] + input_data['models'] = selected_models + button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) + return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state) + + # Add the Textbox to the interface + with gr.Row(): + button_prompt_nlsql = gr.Button("Choose NL2SQL task") + button_prompt_tqa = gr.Button("Choose TQA task") + + prompt = gr.TextArea( + label="Customise the prompt for selected models here or leave the default one.", + placeholder=prompt_default, + elem_id="custom-textarea" + ) + + warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False) + + # Submit button (initially disabled) + with gr.Row(): + submit_models_button = gr.Button("Submit Models", interactive=False) + + def check_prompt(prompt): + #TODO + missing_elements = [] + if(prompt==""): + global flag_TQA + if not flag_TQA: + input_data["prompt"] = prompt_default + else: + input_data["prompt"] = prompt_default_tqa + button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) + else: + input_data["prompt"] = prompt + if "{db_schema}" not in prompt: + missing_elements.append("{db_schema}") + if "{question}" not in prompt: + missing_elements.append("{question}") + button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) + if missing_elements: + return gr.update( + value=f"
" + f"❌ Missing {', '.join(missing_elements)} in the prompt ❌
", + visible=True + ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"]) + return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"]) + + prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button]) + # Link checkboxes to selection events + for checkbox in model_checkboxes: + checkbox.change( + fn=get_selected_models, + inputs=model_checkboxes, + outputs=[selected_models_output, select_model_acc, submit_models_button] + ) + prompt.change( + fn=get_selected_models, + inputs=model_checkboxes, + outputs=[selected_models_output, select_model_acc, submit_models_button] + ) + + submit_models_button.click( + fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)), + inputs=model_checkboxes, + outputs=[selected_models_output, select_model_acc, qatch_acc] + ) + + def change_flag(): + global flag_TQA + flag_TQA = True + + def dis_flag(): + global flag_TQA + flag_TQA = False + + button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[]) + + button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[]) + + button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) + + button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) + + + def enable_disable(enable): + return ( + *[gr.update(interactive=enable) for _ in model_checkboxes], + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + *[gr.update(interactive=enable) for _ in table_outputs], + gr.update(interactive=enable) + ) + + reset_data = gr.Button("Back to upload data section") + + submit_models_button.click( + fn=enable_disable, + inputs=[gr.State(False)], + outputs=[ + *model_checkboxes, + submit_models_button, + preview_output, + submit_button, + file_input, + default_checkbox, + table_selector, + *table_outputs, + open_model_selection + ] + ) + + reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) + + reset_data.click( + fn=enable_disable, + inputs=[gr.State(True)], + outputs=[ + *model_checkboxes, + submit_models_button, + preview_output, + submit_button, + file_input, + default_checkbox, + table_selector, + *table_outputs, + open_model_selection + ] + ) + + ############################# + # QATCH EXECUTION # + ############################# + with qatch_acc: + def change_text(text): + return text + + loading_symbols= {1:"π“†Ÿ", + 2: "π“†ž π“†Ÿ", + 3: "𓆛 π“†ž π“†Ÿ", + 4: "π“†ž 𓆛 π“†ž π“†Ÿ", + 5: "π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + 6: "π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + 7: "π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + 8: "π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + 9: "π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + 10:"π“†ž π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", + } + + def generate_loading_text(percent): + num_symbols = (round(percent) % 11) + 1 + symbols = loading_symbols.get(num_symbols, "π“†Ÿ") + mirrored_symbols = f'{symbols.strip()}' + css_symbols = f'{symbols.strip()}' + return f""" +
+ {css_symbols} + + Generation {percent}% + + {mirrored_symbols} +
+ """ + + def generate_eval_text(text): + symbols = "𓆑 " + mirrored_symbols = f'{symbols.strip()}' + css_symbols = f'{symbols.strip()}' + return f""" +
+ {css_symbols} + + {text} + + {mirrored_symbols} +
+ """ + + def qatch_flow_nl_sql(): + global reset_flag + global flag_TQA + predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list} + metrics_conc = pd.DataFrame() + columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"] + if (input_data['input_method']=="default"): + #target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv") + target_df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH) + #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list} + target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])] + target_df = target_df[target_df["model"].isin(input_data['models'])] + predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list} + reset_flag = False + for model in input_data['models']: + model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) + yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] + count=1 + for _, row in predictions_dict[model].iterrows(): + #for index, row in target_df.iterrows(): + if (reset_flag == False): + percent_complete = round(count / len(predictions_dict[model]) * 100, 2) + count=count+1 + load_text = f"{generate_loading_text(percent_complete)}" + question = row['question'] + + display_question = f"""
Natural Language:
+
+
{question}
+
➑️
+
+ """ + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] + + prediction = row['predicted_sql'] + + display_prediction = f"""
Predicted SQL:
+
+
➑️
+
{prediction}
+
+ """ + + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] + metrics_conc = target_df + if 'valid_efficency_score' not in metrics_conc.columns: + metrics_conc['valid_efficency_score'] = metrics_conc['VES'] + if 'VES' not in metrics_conc.columns: + metrics_conc['VES'] = metrics_conc['valid_efficency_score'] + eval_text = generate_eval_text("End evaluation") + yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] + + else: + # global flag_TQA + orchestrator_generator = OrchestratorGenerator() + target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables']) + + #create target_df[target_answer] + if flag_TQA : + # if (input_data["prompt"] == prompt_default): + # input_data["prompt"] = prompt_default_tqa + + target_df['db_schema'] = target_df.apply( + lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string( + db_id=input_data["db_name"], + base_path=input_data["data_path"], + normalize=False, + sql=row["query"], + get_insert_into=True, + model=None, + prompt=input_data["prompt"].format(question=row["question"], db_schema="") + ), + axis=1 + ) + + target_df = us.extract_answer(target_df) + + predictor = ModelPrediction() + reset_flag = False + for model in input_data["models"]: + model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) + yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + count=0 + for index, row in target_df.iterrows(): + if (reset_flag == False): + percent_complete = round(((index+1) / len(target_df)) * 100, 2) + load_text = f"{generate_loading_text(percent_complete)}" + + question = row['question'] + display_question = f"""
Natural Language:
+
+
{question}
+
➑️
+
+ """ + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"]) + model_to_send = None if not flag_TQA else model + + db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( + db_id = input_data["db_name"], + base_path = input_data["data_path"], + normalize=False, + sql=row["query"], + get_insert_into=True, + model = model_to_send, + prompt = input_data["prompt"].format(question=question, db_schema=""), + ) + + #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples) + prompt_to_send = input_data["prompt"] + #PREDICTION SQL + + # TODO add button for QA or SP and pass to .make_prediction parameter TASK + if flag_TQA: task="QA" + else: task="SP" + start_time = time.time() + response = predictor.make_prediction( + question=question, + db_schema=db_schema_text, + model_name=model, + prompt=f"{prompt_to_send}", + task=task + ) + #if flag_TQA: response = {'response_parsed': "[['Alice'],['Bob'],['Charlie']]", 'cost': 0, 'response': "[['Alice'],['Bob'],['Charlie']]"} # TODO remove this line + #else : response = {'response_parsed': "SELECT * FROM 'MyTable'", 'cost': 0, 'response': "SQL_QUERY"} + end_time = time.time() + prediction = response['response_parsed'] + price = response['cost'] + answer = response['response'] + + if flag_TQA: + task_string = "Answer" + else: + task_string = "SQL" + + display_prediction = f"""
Predicted {task_string}:
+
+
➑️
+
{prediction}
+
+ """ + # Create a new row as dataframe + new_row = pd.DataFrame([{ + 'id': index, + 'question': question, + 'predicted_sql': prediction, + 'time': end_time - start_time, + 'query': row["query"], + 'db_path': input_data["data_path"], + 'price':price, + 'answer': answer, + 'number_question':count, + 'target_answer' : row["target_answer"] if flag_TQA else None, + + }]).dropna(how="all") # Remove only completely empty rows + count=count+1 + # TODO: use a for loop + if (flag_TQA) : + new_row['predicted_answer'] = prediction + for col in target_df.columns: + if col not in new_row.columns: + new_row[col] = row[col] + # Update model's prediction dataframe incrementally + if not new_row.empty: + predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) + + # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] + yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] + # END + eval_text = generate_eval_text("Evaluation") + yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + + evaluator = OrchestratorEvaluator() + + for model in input_data["models"]: + if not flag_TQA: + metrics_df_model = evaluator.evaluate_df( + df=predictions_dict[model], + target_col_name="query", + prediction_col_name="predicted_sql", + db_path_name="db_path" + ) + else: + metrics_df_model = us.evaluate_answer(predictions_dict[model]) + metrics_df_model['model'] = model + metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) + + if 'VES' not in metrics_conc.columns and 'valid_efficency_score' not in metrics_conc.columns: + metrics_conc['VES'] = 0 + metrics_conc['valid_efficency_score'] = 0 + + if 'valid_efficency_score' not in metrics_conc.columns: + metrics_conc['valid_efficency_score'] = metrics_conc['VES'] + + if 'VES' not in metrics_conc.columns: + metrics_conc['VES'] = metrics_conc['valid_efficency_score'] + + eval_text = generate_eval_text("End evaluation") + yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + + # Loading Bar + with gr.Row(): + # progress = gr.Progress() + variable = gr.Markdown() + + # NL -> MODEL -> Generated Query + with gr.Row(): + with gr.Column(): + with gr.Column(): + question_display = gr.Markdown() + with gr.Column(): + model_logo = gr.Image(visible=True, + show_label=False, + container=False, + interactive=False, + show_fullscreen_button=False, + show_download_button=False, + show_share_button=False) + with gr.Column(): + with gr.Column(): + prediction_display = gr.Markdown() + + dataframe_per_model = {} + + with gr.Tabs() as model_tabs: + tab_dict = {} + + for model, model_name in zip(model_list, model_names): + with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab: + gr.Markdown(f"**Results for {model}**") + tab_dict[model] = tab + dataframe_per_model[model] = gr.DataFrame() + #TODO download metrics per model + # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False) + + evaluation_loading = gr.Markdown() + + 
def change_tab(): + return [gr.update(visible=(model in input_data["models"])) for model in model_list] + + submit_models_button.click( + change_tab, + inputs=[], + outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility + ) + + selected_models_display = gr.JSON(label="Final input data", visible=False) + metrics_df = gr.DataFrame(visible=False) + metrics_df_out = gr.DataFrame(visible=False) + + submit_models_button.click( + fn=qatch_flow_nl_sql, + inputs=[], + outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values()) + ) + + submit_models_button.click( + fn=lambda: gr.update(value=input_data), + outputs=[selected_models_display] + ) + + # Works for METRICS + metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out]) + + proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False) + proceed_to_metrics_button.click( + fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)), + outputs=[qatch_acc, metrics_acc] + ) + + def allow_download(metrics_df_out): + #path = os.path.join(".", "data", "data_results", "results.csv") + path = os.path.join(".", "results.csv") + metrics_df_out.to_csv(path, index=False) + return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True) + + download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False) + + submit_models_button.click( + fn=lambda: gr.update(visible=False), + outputs=[download_metrics] + ) + + def refresh(): + global reset_flag + global flag_TQA + reset_flag = True + flag_TQA = False + + reset_data = gr.Button("Back to upload data section", interactive=True) + + metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data]) + + reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) + #WHY NOT WORKING? 
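# ---------------------------------------------------------------------------
# Sketch of the CSV-export flow implemented by allow_download above, assuming only
# pandas and gradio (the file name is illustrative): persist the metrics DataFrame
# to disk, then point the hidden DownloadButton at the file and reveal it.
import pandas as pd
import gradio as gr

def export_metrics(df: pd.DataFrame):
    path = "results.csv"
    df.to_csv(path, index=False)
    # value=path attaches the file; visible=True reveals the button
    return gr.update(value=path, visible=True)
# ---------------------------------------------------------------------------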
+ reset_data.click( + fn=lambda: gr.update(visible=False), + outputs=[download_metrics] + ) + reset_data.click(refresh) + + reset_data.click( + fn=enable_disable, + inputs=[gr.State(True)], + outputs=[ + *model_checkboxes, + submit_models_button, + preview_output, + submit_button, + file_input, + default_checkbox, + table_selector, + *table_outputs, + open_model_selection + ] + ) + + ########################################## + # METRICS VISUALIZATION SECTION # + ########################################## + with metrics_acc: + #data_path = 'test_results_metrics1.csv' + + @gr.render(inputs=metrics_df_out) + def function_metrics(metrics_df_out): + + #################################### + # UTILS FUNCTIONS SECTION # + #################################### + + def load_data_csv_es(): + + if input_data["input_method"]=="default": + global flag_TQA + #df = pd.read_csv(pnp_path) + df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH) + df = df[df['model'].isin(input_data["models"])] + df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])] + + df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B') + df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5') + df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini') + df['model'] = df['model'].replace('llama-70', 'Llama-70B') + df['model'] = df['model'].replace('llama-8', 'Llama-8B') + df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY') + #if (flag_TQA) : flag_TQA = False #TODO delete after make pred + return df + return metrics_df_out + + def calculate_average_metrics(df, selected_metrics): + # Exclude the 'tuple_order' column from the selected metrics + + #TODO tuple_order has NULL VALUE + selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order'] + #print(df[selected_metrics]) + df['avg_metric'] = df[selected_metrics].mean(axis=1) + return df + + def generate_model_colors(): + """Generates a unique color map for models in the dataset.""" + df = load_data_csv_es() + unique_models = df['model'].unique() # Extract unique models + num_models = len(unique_models) + + # Use the Plotly color scale (you can change it if needed) + color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC'] + #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...] 
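# e.g. with seven models and this five-color palette, i % len(color_palette)
# yields indices 0,1,2,3,4,0,1: every model gets a color, and repeats begin only
# once the palette is exhausted.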
+
+                        # If there are more models than colors, cycle through the palette
+                        colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
+
+                        return colors
+
+                    MODEL_COLORS = generate_model_colors()
+
+                    def generate_db_category_colors():
+                        """Assigns three distinct colors to the db_category groups."""
+                        return {
+                            "Spider": "#1f77b4",        # blue
+                            "Beaver": "#ff7f0e",        # orange
+                            "Economic": "#2ca02c",      # the remaining categories share green
+                            "Financial": "#2ca02c",
+                            "Medical": "#2ca02c",
+                            "Miscellaneous": "#2ca02c"
+                        }
+
+                    DB_CATEGORY_COLORS = generate_db_category_colors()
+
+                    def normalize_valid_efficency_score(df):
+                        df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
+                        df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
+                        min_val = df['valid_efficency_score'].min()
+                        max_val = df['valid_efficency_score'].max()
+
+                        if min_val == max_val:
+                            # All values are equal: avoid a division by zero and assign a
+                            # constant score (0 when the column is empty, 1.0 otherwise).
+                            if pd.isna(min_val):
+                                df['valid_efficency_score'] = 0
+                            else:
+                                df['valid_efficency_score'] = 1.0
+                        else:
+                            # Min-max normalisation to [0, 1]
+                            df['valid_efficency_score'] = (
+                                df['valid_efficency_score'] - min_val
+                            ) / (max_val - min_val)
+
+                        return df
+
+
+                    ####################################
+                    #      GRAPH FUNCTIONS SECTION     #
+                    ####################################
+
+                    # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
+                    def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
+                        df = df[df['model'].isin(selected_models)]
+                        df = normalize_valid_efficency_score(df)
+
+                        # Map human-readable labels -> internal metric names
+                        qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+                        external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+                        selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
+                        df = calculate_average_metrics(df, selected_metrics)
+
+                        if group_by == ["model"]:
+                            # Bar plot grouped by model
+                            avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
+                            avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+                            fig = px.bar(
+                                avg_metrics,
+                                x="model",
+                                y="avg_metric",
+                                color="model",
+                                color_discrete_map=MODEL_COLORS,
+                                title='Average metrics per Model 🧠',
+                                labels={"model": "Model", "avg_metric": "Average Metrics"},
+                                template='simple_white',
+                                #template='plotly_dark',
+                                text='text_label'
+                            )
+                        else:
+                            if group_by != ["tbl_name", "model"]:
+                                group_by = ["tbl_name", "model"]
+
+                            avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
+                            avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+                            fig = px.bar(
+                                avg_metrics,
+                                x=group_by[0],
+                                y='avg_metric',
+                                color='model',
+                                color_discrete_map=MODEL_COLORS,
+                                barmode='group',
+                                title=f'Average metrics per {group_by[0]} πŸ“Š',
+                                labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
+                                template='simple_white',
+                                #template='plotly_dark',
+                                text='text_label'
+                            )
+
+                        fig.update_traces(textposition='outside', textfont_size=10)
+
+                        # Apply the Inter font to the whole layout
+                        fig.update_layout(
+                            margin=dict(t=80),
+                            title=dict(font=inter_font(22), x=0.5),
+                            xaxis=dict(title=dict(font=inter_font(18)), tickfont=inter_font(16)),
+                            yaxis=dict(title=dict(font=inter_font(18)), tickfont=inter_font()),
+                            legend=dict(title=dict(font=inter_font(16)), font=inter_font())
+                        )
+
+                        return gr.Plot(fig, visible=True)
+
+                    def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
+                        df = load_data_csv_es()
+                        return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
+
+                    # BAR CHART FOR THE PROPRIETARY DATASETS WITH AVERAGE METRICS AND UPDATE FUNCTION
+                    def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+                        if selected_models == "All":
+                            selected_models = models
+                        else:
+                            selected_models = [selected_models]
+
+                        df = df[df['model'].isin(selected_models)]
+                        df = normalize_valid_efficency_score(df)
+
+                        # Map human-readable labels -> internal metric names
+                        qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+                        external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+                        selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
+                        df = calculate_average_metrics(df, selected_metrics)
+
+                        avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+                        avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+                        fig = px.bar(
+                            avg_metrics,
+                            x='db_category',
+                            y='avg_metric',
+                            color='model',
+                            color_discrete_map=MODEL_COLORS,
+                            barmode='group',
+                            title='Average metrics per database type πŸ“Š',
+                            labels={'db_category': 'Database Category', 'avg_metric': 'Average Metrics'},
+                            template='simple_white',
+                            text='text_label'
+                        )
+
+                        fig.update_traces(textposition='outside', textfont_size=14)
+
+                        # Apply the Inter font to the whole layout
+                        fig.update_layout(
+                            margin=dict(t=80),
+                            title=dict(font=inter_font(24, 'black'), x=0.5),
+                            xaxis=dict(title=dict(text='Database Category', font=inter_font(22, 'black')), tickfont=inter_font(20, 'black')),
+                            yaxis=dict(title=dict(text='Average Metrics', font=inter_font(22, 'black')), tickfont=inter_font(color='black')),
+                            legend=dict(title=dict(text='Models', font=inter_font(20, 'black')), font=inter_font(18, 'black'))
+                        )
+
+                        return gr.Plot(fig, visible=True)
+
+                    def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+                        df = load_data_csv_es()
+                        return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
+
+                    # DUMBBELL CHART: NON-PROPRIETARY (SPIDER) VS PROPRIETARY DATASETS
+                    def lollipop_propietary(selected_models):
+                        df = load_data_csv_es()
+
+                        # Keep only the relevant categories
+                        target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"]
+                        df = df[df['db_category'].isin(target_cats)]
+                        df = df[df['model'].isin(selected_models)]
+
+                        df = normalize_valid_efficency_score(df)
+                        df = calculate_average_metrics(df, qatch_metrics)
+
+                        # Average per db_category and model
+                        avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+
+                        # Split Spider from the other categories
+                        spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
+                        other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
+
+                        # Mean of the other categories for each model
+                        other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
+                        other_mean_df["db_category"] = "Others"
+
+                        # Rename for clarity and consistency
+                        spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
+                        other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
+
+                        # Merge the two datasets
+                        merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
+
+                        # Sort by model (sort by value if preferred)
+                        merged_df = merged_df.sort_values(by="model")
+
+                        fig = go.Figure()
+
+                        # Horizontal connector between the Spider and Others points
+                        for _, row in merged_df.iterrows():
+                            fig.add_trace(go.Scatter(
+                                x=[row["Spider"], row["Others"]],
+                                y=[row["model"]] * 2,
+                                mode='lines',
+                                line=dict(color='gray', width=2),
+                                showlegend=False
+                            ))
+
+                        # Marker for Spider
+                        fig.add_trace(go.Scatter(
+                            x=merged_df["Spider"],
+                            y=merged_df["model"],
+                            mode='markers',
+                            name='Non-Proprietary (Spider)',
+                            marker=dict(size=10, color='#C84630')
+                        ))
+
+                        # Marker for Others (mean of the other proprietary categories)
+                        fig.add_trace(go.Scatter(
+                            x=merged_df["Others"],
+                            y=merged_df["model"],
+                            mode='markers',
+                            name='Proprietary Databases',
+                            marker=dict(size=10, color='#0077B6')
+                        ))
+
+                        fig.update_layout(
+                            template='simple_white',
+                            #template='plotly_dark',
+                            margin=dict(t=80),
+                            height=600,
+                            title=dict(
+                                text='Dumbbell graph: Non-Proprietary (Spider πŸ•·οΈ) vs Proprietary Databases πŸ“Š',
+                                font=inter_font(22, 'black'),
+                                x=0.5
+                            ),
+                            xaxis=dict(title=dict(text='Average Metrics', font=inter_font(18, 'black')), tickfont=inter_font(color='black')),
+                            yaxis=dict(title=dict(text='Models', font=inter_font(18, 'black')), tickfont=inter_font(color='black')),
+                            legend=dict(title=dict(text='Type of Databases:', font=inter_font(18, 'black')), font=inter_font(14, 'black'))
+                        )
+
+                        return gr.Plot(fig, visible=True)
+
+
+                    # RADAR OR BAR CHART BASED ON CATEGORY COUNT
+                    def plot_radar(df, selected_models, selected_metrics, selected_categories):
+                        if "External" in selected_metrics:
+                            selected_metrics = ["execution_accuracy", "valid_efficency_score"]
+                        else:
+                            selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+                        # Filter the models and normalise
+                        df = df[df['model'].isin(selected_models)]
+                        df = normalize_valid_efficency_score(df)
+                        df = calculate_average_metrics(df, selected_metrics)
+
+                        avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
+
+                        if avg_metrics.empty:
+                            print("Error: No data available to compute averages.")
+                            return go.Figure()
+
+                        categories = selected_categories
+
+                        if len(categories) < 3:
+                            # πŸ”„ BAR PLOT
+                            fig = go.Figure()
+                            for model in selected_models:
+                                model_data = avg_metrics[avg_metrics['model'] == model]
+                                # Per-category averages for this model (0 when a category has no rows)
values = [ + model_data[model_data['test_category'] == cat]['avg_metric'].values[0] + if cat in model_data['test_category'].values else 0 + for cat in categories + ] + fig.add_trace(go.Bar( + x=categories, + y=values, + name=model, + marker=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + barmode='group', + title=dict( + text='πŸ“Š Bar Plot of Metrics per Model (Few Categories)', + font=dict( + family='Inter, sans-serif', + size=22, + #color='white' + ), + x=0.5 + ), + template='simple_white', + #template='plotly_dark', + xaxis=dict( + title=dict( + text='Test Category', + font=dict( + family='Inter, sans-serif', + size=18, + #color='white' + ) + ), + tickfont=dict( + family='Inter, sans-serif', + size=16 + #color='white' + ) + ), + yaxis=dict( + title=dict( + text='Average Metrics', + font=dict( + family='Inter, sans-serif', + size=18, + #color='white' + ) + ), + tickfont=dict( + family='Inter, sans-serif', + #color='white' + ) + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Inter, sans-serif', + size=16, + #color='white' + ) + ), + font=dict( + family='Inter, sans-serif', + #color='white' + ) + ) + ) + else: + # 🧭 RADAR PLOT + fig = go.Figure() + for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['test_category'] == cat]['avg_metric'].values[0] + if cat in model_data['test_category'].values else 0 + for cat in categories + ] + fig.add_trace(go.Scatterpolar( + r=values, + theta=categories, + fill='toself', + name=model, + line=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + polar=dict( + radialaxis=dict( + visible=True, + range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], + tickfont=dict( + family='Inter, sans-serif', + #color='white' + ) + ), + angularaxis=dict( + tickfont=dict( + family='Inter, sans-serif', + size=16 + #color='white' + ) + ) + ), + title=dict( + text='❇️ Radar Plot of Metrics per Model (Average per SQL Category)', + font=dict( + family='Inter, sans-serif', + size=22, + #color='white' + ), + x=0.5 + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Inter, sans-serif', + size=18, + #color='white' + ) + ), + font=dict( + family='Inter, sans-serif', + size=16 + #color='white' + ) + ), + template='simple_white' + #template='plotly_dark' + ) + + return fig + + def update_radar(selected_models, selected_metrics, selected_categories): + df = load_data_csv_es() + return plot_radar(df, selected_models, selected_metrics, selected_categories) + + # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT + def plot_radar_sub(df, selected_models, selected_metrics, selected_category): + if "External" in selected_metrics: + selected_metrics = ["execution_accuracy", "valid_efficency_score"] + else: + selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] + + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficency_score(df) + df = calculate_average_metrics(df, selected_metrics) + + if isinstance(selected_category, str): + selected_category = [selected_category] + + df = df[df['test_category'].isin(selected_category)] + avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index() + + if avg_metrics.empty: + print("Error: No data available to compute averages.") + return go.Figure() + + categories = 
df['sql_tag'].unique().tolist() + + if len(categories) < 3: + # πŸ”„ BAR PLOT + fig = go.Figure() + for model in selected_models: + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] + if cat in model_data['sql_tag'].values else 0 + for cat in categories + ] + fig.add_trace(go.Bar( + x=categories, + y=values, + name=model, + marker=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + barmode='group', + title=dict( + text='πŸ“Š Bar Plot of Metrics per Model (Few Sub-Categories)', + font=dict( + family='Inter, sans-serif', + size=22, + #color='white' + ), + x=0.5 + ), + template='simple_white', + #template='plotly_dark', + xaxis=dict( + title=dict( + text='SQL Tag (Sub Category)', + font=dict( + family='Inter, sans-serif', + size=18, + #color='white' + ) + ), + tickfont=dict( + family='Inter, sans-serif', + #color='white' + ) + ), + yaxis=dict( + title=dict( + text='Average Metrics', + font=dict( + family='Inter, sans-serif', + size=18, + #color='white' + ) + ), + tickfont=dict( + family='Inter, sans-serif', + #color='white' + ) + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Inter, sans-serif', + size=16, + #color='white' + ) + ), + font=dict( + family='Inter, sans-serif', + size=14 + #color='white' + ) + ) + ) + else: + # 🧭 RADAR PLOT + fig = go.Figure() + + for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] + if cat in model_data['sql_tag'].values else 0 + for cat in categories + ] + + fig.add_trace(go.Scatterpolar( + r=values, + theta=categories, + fill='toself', + name=model, + line=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + polar=dict( + radialaxis=dict( + visible=True, + range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], + tickfont=dict( + family='Inter, sans-serif', + #color='white' + ) + ), + angularaxis=dict( + tickfont=dict( + family='Inter, sans-serif', + size=16 + #color='white' + ) + ) + ), + title=dict( + text='❇️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)', + font=dict( + family='Inter, sans-serif', + size=22, + #color='white' + ), + x=0.5 + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Inter, sans-serif', + size=16, + #color='white' + ) + ), + font=dict( + family='Inter, sans-serif', + size=14, + #color='white' + ) + ), + template='simple_white' + #template='plotly_dark' + ) + + return fig + + def update_radar_sub(selected_models, selected_metrics, selected_category): + df = load_data_csv_es() + return plot_radar_sub(df, selected_models, selected_metrics, selected_category) + + # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION + def worst_cases_text(df, selected_models, selected_metrics, selected_categories): + global flag_TQA + if selected_models == "All": + selected_models = models + else: + selected_models = [selected_models] + + if selected_categories == "All": + selected_categories = principal_categories + else: + selected_categories = [selected_categories] + + df = df[df['model'].isin(selected_models)] + df = df[df['test_category'].isin(selected_categories)] + + if "external" in selected_metrics: + selected_metrics = ["execution_accuracy", "valid_efficency_score"] + else: + selected_metrics = ["cell_precision", "cell_recall", "tuple_order", 
"tuple_cardinality", "tuple_constraint"] + + df = normalize_valid_efficency_score(df) + df = calculate_average_metrics(df, selected_metrics) + + if flag_TQA: + df["target_answer"] = df["target_answer"] = df["target_answer"].apply(lambda x: "[" + ", ".join(map(str, x)) + "]") + + worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() + else: + worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() + + worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True) + + worst_cases_top_3 = worst_cases_df.head(3) + + worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2) + + worst_str = [] + answer_str = [] + + medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"] + + for i, row in worst_cases_top_3.iterrows(): + if flag_TQA: + entry = ( + f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" + f"- Question: {row['question']} \n" + f"- Original Answer: `{row['target_answer']}` \n" + f"- Predicted Answer: `{eval(row['predicted_answer'])}` \n\n" + ) + + worst_str.append(entry) + else: + entry = ( + f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" + f"- Question: {row['question']} \n" + f"- Original Query: `{row['query']}` \n" + f"- Predicted SQL: `{row['predicted_sql']}` \n\n" + ) + + worst_str.append(entry) + + raw_answer = ( + f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" + f"- Raw Answer:
`{row['answer']}`
\n" + ) + + answer_str.append(raw_answer) + + return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2] + + def update_worst_cases_text(selected_models, selected_metrics, selected_categories): + df = load_data_csv_es() + return worst_cases_text(df, selected_models, selected_metrics, selected_categories) + + # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION + def plot_cumulative_flow(df, selected_models, max_points): + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficency_score(df) + + fig = go.Figure() + + for model in selected_models: + model_df = df[df['model'] == model].copy() + + # Limita il numero di punti se richiesto + if max_points is not None: + model_df = model_df.head(max_points + 1) + + # Tooltip personalizzato + model_df['hover_info'] = model_df.apply( + lambda row: + f"Id question: {row['number_question']}
" + f"Question: {row['question']}
" + f"Target: {row['query']}
" + f"Prediction: {row['predicted_sql']}
" + f"Category: {row['test_category']}", + axis=1 + ) + + # Calcoli cumulativi + model_df['cumulative_time'] = model_df['time'].cumsum() + model_df['cumulative_price'] = model_df['price'].cumsum() + + # Colore del modello + color = MODEL_COLORS.get(model, "gray") + + fig.add_trace(go.Scatter( + x=model_df['cumulative_time'], + y=model_df['cumulative_price'], + mode='lines+markers', + name=model, + line=dict(width=2, color=color), + customdata=model_df['hover_info'], + hovertemplate= + "Model: " + model + "
" + + "Cumulative Time: %{x}s
" + + "Cumulative Price: $%{y:.2f}
" + + "
Details:
%{customdata}" + )) + + # Layout con font elegante + fig.update_layout( + title=dict( + text="Cumulative Price Flow Chart πŸ’°", + font=dict( + family="Inter, sans-serif", + size=24, + #color="white" + ), + x=0.5 + ), + xaxis=dict( + title=dict( + text="Cumulative Time (s)", + font=dict( + family="Inter, sans-serif", + size=20, + #color="white" + ) + ), + tickfont=dict( + family="Inter, sans-serif", + size=18 + #color="white" + ) + ), + yaxis=dict( + title=dict( + text="Cumulative Price ($)", + font=dict( + family="Inter, sans-serif", + size=20, + #color="white" + ) + ), + tickfont=dict( + family="Inter, sans-serif", + size=18 + #color="white" + ) + ), + legend=dict( + title=dict( + text="Models", + font=dict( + family="Inter, sans-serif", + size=18, + #color="white" + ) + ), + font=dict( + family="Inter, sans-serif", + size=16, + #color="white" + ) + ), + template='simple_white', + #template="plotly_dark" + ) + + return fig + + def update_query_rate(selected_models, max_points): + df = load_data_csv_es() + return plot_cumulative_flow(df, selected_models, max_points) + + + + + ####################### + # PARAMETER SECTION # + ####################### + qatch_metrics_dict = { + "Cell Precision": "cell_precision", + "Cell Recall": "cell_recall", + "Tuple Order": "tuple_order", + "Tuple Cardinality": "tuple_cardinality", + "Tuple Constraint": "tuple_constraint" + } + + qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] + last_valid_qatch_metrics_selection = qatch_metrics.copy() # Per salvare l’ultima selezione valida + def enforce_qatch_metrics_selection(selected): + global last_valid_qatch_metrics_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_qatch_metrics_selection) + last_valid_qatch_metrics_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + external_metrics_dict = { + "Execution Accuracy": "execution_accuracy", + "Valid Efficency Score": "valid_efficency_score" + } + + external_metric = ["execution_accuracy", "valid_efficency_score"] + last_valid_external_metric_selection = external_metric.copy() + def enforce_external_metric_selection(selected): + global last_valid_external_metric_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_external_metric_selection) + last_valid_external_metric_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + all_metrics = { + "Qatch": ["qatch"], + "External": ["external"] + } + + group_options = { + "Table": ["tbl_name", "model"], + "Model": ["model"] + } + + df_initial = load_data_csv_es() + models = models = df_initial['model'].unique().tolist() + last_valid_model_selection = models.copy() # Per salvare l’ultima selezione valida + def enforce_model_selection(selected): + global last_valid_model_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_model_selection) + last_valid_model_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + all_categories = df_initial['sql_tag'].unique().tolist() + + principal_categories = df_initial['test_category'].unique().tolist() + last_valid_category_selection = principal_categories.copy() # Per salvare l’ultima selezione valida + def enforce_category_selection(selected): + global last_valid_category_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return 
+                    all_categories_as_dic = {cat: [cat] for cat in principal_categories}
+
+                    all_categories_as_dic_ranking = {cat: [cat] for cat in principal_categories}
+                    all_categories_as_dic_ranking["All"] = principal_categories
+
+                    all_model_as_dic = {cat: [cat] for cat in models}
+                    all_model_as_dic["All"] = models
+
+                    ###########################
+                    #  VISUALIZATION SECTION  #
+                    ###########################
+                    gr.Markdown("""# Model Performance Analysis""")
+
+                    #FOR BAR
+                    gr.Markdown("""## Section 1: Model - Data""")
+
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            with gr.Row():
+                                choose_metrics_bar = gr.Radio(
+                                    choices=list(all_metrics.keys()),
+                                    label="Select the metrics group that you want to use:",
+                                    value="Qatch"
+                                )
+
+                            with gr.Row():
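+                                # The two info boxes below are shown/hidden together with the
+                                # matching metric checkboxes by toggle_metric_selector().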
+                                qatch_info = gr.HTML("""
+                                    <div>Qatch metric info ℹ️</div>
+                                """, visible=True)
+
+                                external_info = gr.HTML("""
+                                    <div>External metric info ℹ️</div>
+                                """, visible=False)
+
+                            qatch_metric_multiselect_bar = gr.CheckboxGroup(
+                                choices=list(qatch_metrics_dict.keys()),
+                                label="Select one or more Qatch metrics:",
+                                value=list(qatch_metrics_dict.keys()),
+                                visible=True
+                            )
+
+                            external_metric_select_bar = gr.CheckboxGroup(
+                                choices=list(external_metrics_dict.keys()),
+                                label="Select one or more External metrics:",
+                                visible=False
+                            )
+
+                            if input_data['input_method'] == 'default':
+                                model_radio_bar = gr.Radio(
+                                    choices=list(all_model_as_dic.keys()),
+                                    label="Select the model that you want to use:",
+                                    value="All"
+                                )
+                            else:
+                                model_multiselect_bar = gr.CheckboxGroup(
+                                    choices=models,
+                                    label="Select one or more models:",
+                                    value=models,
+                                    interactive=len(models) > 1
+                                )
+
+                            group_radio = gr.Radio(
+                                choices=list(group_options.keys()),
+                                label="Select the grouping view:",
+                                value="Table"
+                            )
+
+                        def toggle_metric_selector(selected_type):
+                            if selected_type == "Qatch":
+                                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
+                            else:
+                                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
+
+                        output_plot = gr.Plot(visible=False)
+
+                        if input_data['input_method'] == 'default':
+                            with gr.Row():
+                                lollipop_propietary(models)
+
+                    #FOR RADAR
+                    gr.Markdown("""## Section 2: Model - Category""")
+                    with gr.Row():
+                        all_metrics_radar = gr.Radio(
+                            choices=list(all_metrics.keys()),
+                            label="Select the metrics group that you want to use:",
+                            value="Qatch"
+                        )
+
+                        model_multiselect_radar = gr.CheckboxGroup(
+                            choices=models,
+                            label="Select one or more models:",
+                            value=models,
+                            interactive=len(models) > 1
+                        )
+
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            category_multiselect_radar = gr.CheckboxGroup(
+                                choices=principal_categories,
+                                label="Select one or more categories:",
+                                value=principal_categories
+                            )
+                        with gr.Column(scale=1):
+                            category_radio_radar = gr.Radio(
+                                choices=list(all_categories_as_dic.keys()),
+                                label="Select the category that you want to use:",
+                                value=list(all_categories_as_dic.keys())[0]
+                            )
+
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
+
+                        with gr.Column(scale=1):
+                            radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
+
+                    #FOR RANKING
+                    with gr.Row():
+                        all_metrics_ranking = gr.Radio(
+                            choices=list(all_metrics.keys()),
+                            label="Select the metrics group that you want to use:",
+                            value="Qatch"
+                        )
+                        model_choices = list(all_model_as_dic.keys())
+
+                        if len(model_choices) == 2:
+                            model_choices = [model_choices[0]]  # assume the single model sits in the first position
+                            selected_value = model_choices[0]
+                        else:
+                            selected_value = "All"
+
+                        model_radio_ranking = gr.Radio(
+                            choices=model_choices,
+                            label="Select the model that you want to use:",
+                            value=selected_value
+                        )
+
+                        category_radio_ranking = gr.Radio(
+                            choices=list(all_categories_as_dic_ranking.keys()),
+                            label="Select the category that you want to use:",
+                            value="All"
+                        )
+
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            gr.Markdown("## ❌ 3 Worst Cases\n")
+
+                            worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
+
+                            with gr.Row():
+                                first = gr.Markdown(worst_first)
+
+                            with gr.Row():
+                                first_button = gr.Button("Show raw answer for πŸ₯‡")
+
+                            with gr.Row():
second = gr.Markdown(worst_second) + + with gr.Row(): + second_button = gr.Button("Show raw answer for πŸ₯ˆ") + + with gr.Row(): + third = gr.Markdown(worst_third) + + with gr.Row(): + third_button = gr.Button("Show raw answer for πŸ₯‰") + + with gr.Column(scale=1): + gr.Markdown("""## Raw Answer""") + row_answer_first = gr.Markdown(value=raw_first, visible=True) + row_answer_second = gr.Markdown(value=raw_second, visible=False) + row_answer_third = gr.Markdown(value=raw_third, visible=False) + + #FOR RATE + gr.Markdown("""## Section 3: Time - Price""") + with gr.Row(): + model_multiselect_rate = gr.CheckboxGroup( + choices=models, + label="Select one or more models:", + value=models, + interactive=len(models) > 1 + ) + + + with gr.Row(): + slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider") + + query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique()))) + + + #FOR RESET + reset_data = gr.Button("Back to upload data section") + + + + + ############################### + # CALLBACK FUNCTION SECTION # + ############################### + + #FOR BAR + def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models): + return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models) + + def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models): + return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models) + + #FOR RADAR + def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories): + return update_radar(selected_models, selected_metrics, selected_categories) + + def on_radar_radio_change(selected_models, selected_metrics, selected_category): + return update_radar_sub(selected_models, selected_metrics, selected_category) + + #FOR RANKING + def on_ranking_change(selected_models, selected_metrics, selected_categories): + return update_worst_cases_text(selected_models, selected_metrics, selected_categories) + + def show_first(): + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False) + ) + + def show_second(): + return ( + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=False) + ) + + def show_third(): + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=True) + ) + + + + + ###################### + # ON CLICK SECTION # + ###################### + + #FOR BAR + if(input_data['input_method'] == 'default'): + proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) + qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) + external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) + model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) + 
qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) + choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) + external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) + + else: + proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) + qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) + external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) + group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) + model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) + qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) + model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar) + choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) + external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) + + + #FOR RADAR MULTISELECT + model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) + all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) + category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) + model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar) + category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar) + + #FOR RADAR RADIO + model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) + all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) + category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) + + #FOR RANKING + model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) + 
model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) + all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) + all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) + category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) + category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) + model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking) + category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking) + first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) + second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third]) + third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third]) + + #FOR RATE + model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) + proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) + model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate) + slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) + + #FOR RESET + reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) + reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics]) + reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection]) + + interface.launch(share = True) \ No newline at end of file