diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,2355 +1,2355 @@
-import os
-import sys
-import time
-import re
-import csv
-import gradio as gr
-import pandas as pd
-import numpy as np
-import plotly.express as px
-import plotly.graph_objects as go
-import plotly.colors as pc
-from qatch.connectors.sqlite_connector import SqliteConnector
-from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
-from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
-import qatch.evaluate_dataset.orchestrator_evaluator as eva
-from prediction import ModelPrediction
-import utils_get_db_tables_info
-import utilities as us
-# @spaces.GPU
-# def model_prediction():
-# pass
-# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
-# if os.environ.get("SPACES_ZERO_GPU") is not None:
-# import spaces
-# else:
-# class spaces:
-# @staticmethod
-# def GPU(func):
-# def wrapper(*args, **kwargs):
-# return func(*args, **kwargs)
-# return wrapper
-#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
-pnp_path = "concatenated_output.csv"
-PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
-PNP_TQA_PATH = 'concatenated_output_tqa.csv'
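-# Small JS hook that forces the interface to reload with Gradio's light theme (?__theme=light).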
-js_func = """
-function refresh() {
- const url = new URL(window.location);
-
- if (url.searchParams.get('__theme') !== 'light') {
- url.searchParams.set('__theme', 'light');
- window.location.href = url.href;
- }
-}
-"""
-reset_flag = False
-flag_TQA = False
-
-with open('style.css', 'r') as file:
- css = file.read()
-
-# Default DataFrame
-df_default = pd.DataFrame({
- 'Name': ['Alice', 'Bob', 'Charlie'],
- 'Age': [25, 30, 35],
- 'City': ['New York', 'Los Angeles', 'Chicago']
-})
-models_path ="models.csv"
-
-# Global variable tracking the currently loaded data
-df_current = df_default.copy()
-
-description = """## π Comparison of Proprietary and Non-Proprietary Databases
- ### β€ **Proprietary** :
- ### β Economic π°, Medical π₯, Financial π³, Miscellaneous π
- ### β BEAVER (FAC BUILDING ADDRESS π’ , TIME QUARTER β±οΈ)
- ### β€ **Non-Proprietary**
- ### β Spider 1.0 π·οΈ"""
-prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
-prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as .\n Question \n {question}\n Database Schema\n {db_schema}\n"
-
-
-input_data = {
- 'input_method': "",
- 'data_path': "",
- 'db_name': "",
- 'data': {
- 'data_frames': {}, # dictionary of dataframes
- 'db': None, # SQLITE3 database object
- 'selected_tables' :[]
- },
- 'models': [],
- 'prompt': prompt_default
-}
-
-def load_data(file, path, use_default):
-    """Load data from an uploaded file, a path, or the default DataFrame."""
- global df_current
- if file is not None:
- try:
- input_data["input_method"] = 'uploaded_file'
- input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
- if file.endswith('.sqlite'):
- #return 'Error: The uploaded file is not a valid SQLite database.'
- input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
- else:
- #change path
- input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
- input_data["data"] = us.load_data(file, input_data["db_name"])
-
-            df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
- if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
- table2primary_key = {}
- for table_name, df in input_data["data"]['data_frames'].items():
- # Assign primary keys for each table
- table2primary_key[table_name] = 'id'
- input_data["data"]["db"] = SqliteConnector(
- relative_db_path=input_data["data_path"],
- db_name=input_data["db_name"],
- tables= input_data["data"]['data_frames'],
- table2primary_key=table2primary_key
- )
- return input_data["data"]['data_frames']
- except Exception as e:
-            return f'Error loading the file: {e}'
- if use_default:
- if(use_default == 'Custom'):
- input_data["input_method"] = 'custom'
- #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
- input_data["data_path"] = os.path.join(".","mytable_0.sqlite")
- #if file already exist
- while os.path.exists(input_data["data_path"]):
- input_data["data_path"] = us.increment_filename(input_data["data_path"])
- input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
- input_data["data"]['data_frames'] = {'MyTable': df_current}
-
- if(input_data["data"]['data_frames']):
- table2primary_key = {}
- for table_name, df in input_data["data"]['data_frames'].items():
- # Assign primary keys for each table
- table2primary_key[table_name] = 'id'
- input_data["data"]["db"] = SqliteConnector(
- relative_db_path=input_data["data_path"],
- db_name=input_data["db_name"],
- tables= input_data["data"]['data_frames'],
- table2primary_key=table2primary_key
- )
-        df_current = df_default.copy() # Restore the default data
- return input_data["data"]['data_frames']
-
- if(use_default == 'Proprietary vs Non-proprietary'):
- input_data["input_method"] = 'default'
- #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
- #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
- #input_data["db_name"] = "default"
- #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
- input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES)
- return input_data["data"]['data_frames']
-
- selected_inputs = sum([file is not None, bool(path), use_default])
- if selected_inputs > 1:
- return 'Error: Select only one input method at a time.'
-
- return input_data["data"]['data_frames']
-
-def preview_default(use_default, file):
- if file:
-        return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## ✅ File successfully uploaded!", visible=True)
- else :
- if use_default == 'Custom':
- return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## π Toy Table", visible=True)
- else:
- return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True)
- #return gr.DataFrame(interactive=True, value = df_current) # Mostra il DataFrame corrente, che potrebbe essere stato modificato
-
-def update_df(new_df):
-    """Update the current DataFrame."""
-    global df_current  # Use the global variable so the update is visible everywhere
- df_current = new_df
- return df_current
-
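-# open_accordion drives the accordion workflow: "reset" clears the shared input_data state and reopens the
-# upload section, while "model_selection" jumps to the model-selection step. The number and order of the
-# returned gr.update objects must match the outputs list of whichever .click() handler calls it.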
-def open_accordion(target):
-    # Open one section and close the others
-    global df_current
-    if target == "reset":
-        df_current = df_default.copy()
- input_data['input_method'] = ""
- input_data['data_path'] = ""
- input_data['db_name'] = ""
- input_data['data']['data_frames'] = {}
- input_data['data']['selected_tables'] = []
- input_data['data']['db'] = None
- input_data['models'] = []
- return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
- elif target == "model_selection":
- return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
-
-# Gradio interface
-#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
-with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface:
- with gr.Row():
- with gr.Column(scale=1):
- gr.Image(
- value=os.path.join(".", "qatch_logo.png"),
- show_label=False,
- container=False,
- interactive=False,
- show_fullscreen_button=False,
- show_download_button=False,
- show_share_button=False,
- height=150, # in pixel
- width=300
- )
- with gr.Column(scale=1):
- pass
- data_state = gr.State(None) # Memorizza i dati caricati
- upload_acc = gr.Accordion("Upload data section", open=True, visible=True)
- select_table_acc = gr.Accordion("Select tables section", open=False, visible=False)
- select_model_acc = gr.Accordion("Select models section", open=False, visible=False)
- qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False)
- metrics_acc = gr.Accordion("Metrics section", open=False, visible=False)
-
- #################################
- # DATABASE INSERTION #
- #################################
- with upload_acc:
- gr.Markdown("## π₯Choose data input method")
- with gr.Row():
- default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
- #default_checkbox = gr.Checkbox(label="Use default DataFrame"
-
- table_default = gr.Markdown(description, visible=True)
- preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
-
- gr.Markdown("## π Or upload your data")
- file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
- submit_button = gr.Button("Load Data") # Disabled by default
- output = gr.JSON(visible=False) # Dictionary output
-
- # Function to enable the button if there is data to load
- def enable_submit(file, use_default):
- return gr.update(interactive=bool(file or use_default))
-
- # Function to uncheck the checkbox if a file is uploaded
- def deselect_default(file):
- if file:
- return gr.update(value='Proprietary vs Non-proprietary')
- return gr.update()
-
- def enable_disable_first(enable):
- return (
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable)
- )
-
- # Enable the button when inputs are provided
- #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
- #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
-
- # Show preview of the default DataFrame when checkbox is selected
- default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
- file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
- preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
-
- # Uncheck the checkbox when a file is uploaded
- file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
-
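-        # handle_output returns one update per component in the outputs list of the click handler below: it
-        # hides the JSON preview, stores the loaded tables in data_state and, depending on how many tables
-        # were loaded, either jumps straight to model selection (single table) or opens the table selection.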
- def handle_output(file, use_default):
- """Handles the output when the 'Load Data' button is pressed."""
- result = load_data(file, None, use_default)
-
- if isinstance(result, dict): # If result is a dictionary of DataFrames
- if len(result) == 1: # If there's only one table
- input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
- return (
- gr.update(visible=False), # Hide JSON output
- result, # Save the data state
- gr.update(visible=False), # Hide table selection
- result, # Maintain the data state
- gr.update(interactive=False), # Disable the submit button
- gr.update(visible=True, open=True), # Proceed to select_model_acc
- gr.update(visible=True, open=False)
- )
- else:
- return (
- gr.update(visible=False),
- result,
- gr.update(open=True, visible=True),
- result,
- gr.update(interactive=False),
- gr.update(visible=False), # Keep current behavior
- gr.update(visible=True, open=False)
- )
- else:
- return (
- gr.update(visible=False),
- None,
- gr.update(open=False, visible=True),
- None,
- gr.update(interactive=True),
- gr.update(visible=False),
- gr.update(visible=True, open=False)
- )
-
- submit_button.click(
- fn=handle_output,
- inputs=[file_input, default_checkbox],
- outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
- )
-
- submit_button.click(
- fn=enable_disable_first,
- inputs=[gr.State(False)],
- outputs=[
- preview_output,
- submit_button,
- file_input,
- default_checkbox
- ]
- )
-
- ######################################
- # TABLE SELECTION PART #
- ######################################
- with select_table_acc:
- previous_selection = gr.State([])
-        table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the chosen database", value=[])
- excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
- table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
- selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
-
- # Model selection button (initially disabled)
- open_model_selection = gr.Button("Choose your models", interactive=False)
- def update_table_list(data):
- """Dynamically updates the list of available tables and excluded ones."""
- if isinstance(data, dict) and data:
- table_names = []
- excluded_tables = []
-
- data_frames = input_data['data'].get('data_frames', {})
-
- available_tables = []
- for name, df in data.items():
- df_real = data_frames.get(name, None)
- if input_data['input_method'] != "default":
- if df_real is not None and df_real.shape[1] > 15:
- excluded_tables.append(name)
- else:
- available_tables.append(name)
- else:
- available_tables.append(name)
-
-
- if input_data['input_method'] == "default":
- table_names.append("All")
- excluded_tables = []
- elif len(available_tables) < 6:
- table_names.append("All")
-
- table_names.extend(available_tables)
- if excluded_tables and input_data['input_method'] != "default" :
-                excluded_text = "⚠️ The following tables have more than 15 columns and cannot be selected:<br>" + "<br>".join(f"- {t}" for t in excluded_tables)
- excluded_visible = True
- else:
- excluded_text = ""
- excluded_visible = False
-
- return [
- gr.update(choices=table_names, value=[]), # CheckboxGroup update
- gr.update(value=excluded_text, visible=excluded_visible) # HTML display update
- ]
-
- return [
- gr.update(choices=[], value=[]),
- gr.update(value="", visible=False)
- ]
-
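-        # show_selected_tables applies the selection rules: the "default" databases (and small uploads with
-        # fewer than 6 tables) expose an "All" shortcut, while uploaded/custom tables with more than 15
-        # columns are excluded and at most 5 tables can be selected at once.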
- def show_selected_tables(data, selected_tables):
- updates = []
- data_frames = input_data['data'].get('data_frames', {})
-
- available_tables = []
- for name, df in data.items():
- df_real = data_frames.get(name)
- if input_data['input_method'] != "default" :
- if df_real is not None and df_real.shape[1] <= 15:
- available_tables.append(name)
- else:
- available_tables.append(name)
-
- input_method = input_data['input_method']
- allow_all = input_method == "default" or len(available_tables) < 6
-
- selected_set = set(selected_tables)
- tables_set = set(available_tables)
-
- if allow_all:
- if "All" in selected_set:
- selected_tables = ["All"] + available_tables
- elif selected_set == tables_set:
- selected_tables = []
- else:
- selected_tables = [t for t in selected_tables if t in available_tables]
- else:
- selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
-
- tables = {name: data[name] for name in selected_tables if name in data}
-
- for i, (name, df) in enumerate(tables.items()):
- updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
-
- for _ in range(len(tables), 50):
- updates.append(gr.update(visible=False))
-
- updates.append(gr.update(interactive=bool(tables)))
-
- if allow_all:
- updates.insert(0, gr.update(
- choices=["All"] + available_tables,
- value=selected_tables
- ))
- else:
- if len(selected_tables) >= 5:
- updates.insert(0, gr.update(
- choices=selected_tables,
- value=selected_tables
- ))
- else:
- updates.insert(0, gr.update(
- choices=available_tables,
- value=selected_tables
- ))
-
- return updates
-
- def show_selected_table_names(data, selected_tables):
- """Displays the names of the selected tables when the button is pressed."""
- if selected_tables:
- available_tables = list(data.keys()) # Actually available names
- if "All" in selected_tables:
- selected_tables = available_tables
-            if input_data['input_method'] != "default":
-                selected_tables = [t for t in selected_tables if len(data[t].columns) <= 15]
-
- input_data['data']['selected_tables'] = selected_tables
- return gr.update(value=", ".join(selected_tables), visible=False)
- return gr.update(value="", visible=False)
-
- # Automatically updates the checkbox list when `data_state` changes
- data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])
-
- # Updates the visible tables and the button state based on user selections
- #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
- table_selector.change(
- fn=show_selected_tables,
- inputs=[data_state, table_selector],
- outputs=[table_selector] + table_outputs + [open_model_selection]
- )
- # Shows the list of selected tables when "Choose your models" is clicked
- open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
- open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
-
- reset_data = gr.Button("Back to upload data section")
-
- reset_data.click(
- fn=enable_disable_first,
- inputs=[gr.State(True)],
- outputs=[
- preview_output,
- submit_button,
- file_input,
- default_checkbox
- ]
- )
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
-
- ####################################
- # MODEL SELECTION PART #
- ####################################
- with select_model_acc:
- gr.Markdown("# Model Selection")
-
- # Assume that `us.read_models_csv` also returns the image path
- model_list_dict = us.read_models_csv(models_path)
- model_list = [model["code"] for model in model_list_dict]
- model_images = [model["image_path"] for model in model_list_dict]
- model_names = [model["name"] for model in model_list_dict]
- # Create a mapping between model_list and model_images_names
- model_mapping = dict(zip(model_list, model_names))
- model_mapping_reverse = dict(zip(model_names, model_list))
-
- model_checkboxes = []
- rows = []
-
- # Dynamically create checkboxes with images (3 per row)
- for i in range(0, len(model_list), 3):
- with gr.Row():
- cols = []
- for j in range(3):
- if i + j < len(model_list):
- model = model_list[i + j]
- image_path = model_images[i + j]
- with gr.Column():
- gr.Image(image_path,
- show_label=False,
- container=False,
- interactive=False,
- show_fullscreen_button=False,
- show_download_button=False,
- show_share_button=False)
- checkbox = gr.Checkbox(label=model_mapping[model], value=False)
- model_checkboxes.append(checkbox)
- cols.append(checkbox)
- rows.append(cols)
-
- selected_models_output = gr.JSON(visible=False)
-
- # Function to get selected models
- def get_selected_models(*model_selections):
- selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
- input_data['models'] = selected_models
- button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
- return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
-
- # Add the Textbox to the interface
- with gr.Row():
- button_prompt_nlsql = gr.Button("Choose NL2SQL task")
- button_prompt_tqa = gr.Button("Choose TQA task")
-
- prompt = gr.TextArea(
- label="Customise the prompt for selected models here or leave the default one.",
- placeholder=prompt_default,
- elem_id="custom-textarea"
- )
-
- warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
-
- # Submit button (initially disabled)
- with gr.Row():
- submit_models_button = gr.Button("Submit Models", interactive=False)
-
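-        # check_prompt validates the custom prompt: an empty prompt falls back to the NL2SQL or TQA default
-        # (depending on flag_TQA); otherwise both {db_schema} and {question} must be present before the
-        # "Submit Models" button is enabled.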
- def check_prompt(prompt):
- #TODO
- missing_elements = []
- if(prompt==""):
- global flag_TQA
- if not flag_TQA:
- input_data["prompt"] = prompt_default
- else:
- input_data["prompt"] = prompt_default_tqa
- button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
- else:
- input_data["prompt"] = prompt
- if "{db_schema}" not in prompt:
- missing_elements.append("{db_schema}")
- if "{question}" not in prompt:
- missing_elements.append("{question}")
- button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
-            if missing_elements:
-                return gr.update(
-                    value=f"❌ Missing {', '.join(missing_elements)} in the prompt ❌",
-                    visible=True
-                ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
-            return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
-
-        prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
- # Link checkboxes to selection events
- for checkbox in model_checkboxes:
- checkbox.change(
- fn=get_selected_models,
- inputs=model_checkboxes,
- outputs=[selected_models_output, select_model_acc, submit_models_button]
- )
- prompt.change(
- fn=get_selected_models,
- inputs=model_checkboxes,
- outputs=[selected_models_output, select_model_acc, submit_models_button]
- )
-
- submit_models_button.click(
- fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
- inputs=model_checkboxes,
- outputs=[selected_models_output, select_model_acc, qatch_acc]
- )
-
- def change_flag():
- global flag_TQA
- flag_TQA = True
-
- def dis_flag():
- global flag_TQA
- flag_TQA = False
-
- button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[])
-
- button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[])
-
- button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
-
- button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
-
-
- def enable_disable(enable):
- return (
- *[gr.update(interactive=enable) for _ in model_checkboxes],
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- gr.update(interactive=enable),
- *[gr.update(interactive=enable) for _ in table_outputs],
- gr.update(interactive=enable)
- )
-
- reset_data = gr.Button("Back to upload data section")
-
- submit_models_button.click(
- fn=enable_disable,
- inputs=[gr.State(False)],
- outputs=[
- *model_checkboxes,
- submit_models_button,
- preview_output,
- submit_button,
- file_input,
- default_checkbox,
- table_selector,
- *table_outputs,
- open_model_selection
- ]
- )
-
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
-
- reset_data.click(
- fn=enable_disable,
- inputs=[gr.State(True)],
- outputs=[
- *model_checkboxes,
- submit_models_button,
- preview_output,
- submit_button,
- file_input,
- default_checkbox,
- table_selector,
- *table_outputs,
- open_model_selection
- ]
- )
-
- #############################
- # QATCH EXECUTION #
- #############################
- with qatch_acc:
- def change_text(text):
- return text
-
- loading_symbols= {1:"π",
- 2: "π π",
- 3: "π π π",
- 4: "π π π π",
- 5: "π π π π π",
- 6: "π π π π π π",
- 7: "π π π π π π π",
- 8: "π π π π π π π π",
- 9: "π π π π π π π π π",
- 10:"π π π π π π π π π π",
- }
-
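-        # generate_loading_text builds the HTML/Markdown banner shown while predictions stream in; the
-        # number of emoji symbols cycles with the completion percentage.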
- def generate_loading_text(percent):
- num_symbols = (round(percent) % 11) + 1
- symbols = loading_symbols.get(num_symbols, "π")
- mirrored_symbols = f'{symbols.strip()}'
- css_symbols = f'{symbols.strip()}'
- return f"""
-
- {css_symbols}
-
- Generation {percent}%
-
- {mirrored_symbols}
-
- """
-
- def generate_eval_text(text):
- symbols = "π‘ "
- mirrored_symbols = f'{symbols.strip()}'
- css_symbols = f'{symbols.strip()}'
- return f"""
-
- {css_symbols}
-
- {text}
-
- {mirrored_symbols}
-
- """
-
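-        # qatch_flow_nl_sql is a generator: every yield streams a partial UI update (loading banner, model
-        # logo, current question, current prediction, metrics dataframe, one dataframe per model). With the
-        # "default" input method it replays precomputed predictions from CSV; otherwise it generates tests
-        # with QATCH's OrchestratorGenerator, queries each selected model and then evaluates the predictions.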
- def qatch_flow_nl_sql():
- global reset_flag
- global flag_TQA
- predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
- metrics_conc = pd.DataFrame()
- columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
- if (input_data['input_method']=="default"):
- #target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
- target_df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH)
- #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
- target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
- target_df = target_df[target_df["model"].isin(input_data['models'])]
- predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
- reset_flag = False
- for model in input_data['models']:
- model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
- yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
- count=1
- for _, row in predictions_dict[model].iterrows():
- #for index, row in target_df.iterrows():
- if (reset_flag == False):
- percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
- count=count+1
- load_text = f"{generate_loading_text(percent_complete)}"
- question = row['question']
-
-                    display_question = f"""Natural Language:
-                    {question}
-                    """
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
-
- prediction = row['predicted_sql']
-
-                    display_prediction = f"""Predicted SQL:
-                    ➡️ {prediction}
-                    """
-
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
- metrics_conc = target_df
- if 'valid_efficency_score' not in metrics_conc.columns:
- metrics_conc['valid_efficency_score'] = metrics_conc['VES']
- if 'VES' not in metrics_conc.columns:
- metrics_conc['VES'] = metrics_conc['valid_efficency_score']
- eval_text = generate_eval_text("End evaluation")
- yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
-
- else:
- orchestrator_generator = OrchestratorGenerator()
- target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
-
- #create target_df[target_answer]
- if flag_TQA :
- # if (input_data["prompt"] == prompt_default):
- # input_data["prompt"] = prompt_default_tqa
-
- target_df['db_schema'] = target_df.apply(
- lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string(
- db_id=input_data["db_name"],
- base_path=input_data["data_path"],
- normalize=False,
- sql=row["query"],
- get_insert_into=True,
- model=None,
- prompt=input_data["prompt"].format(question=row["question"], db_schema="")
- ),
- axis=1
- )
-
- target_df = us.extract_answer(target_df)
-
- predictor = ModelPrediction()
- reset_flag = False
- for model in input_data["models"]:
- model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
- yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
- count=0
- for index, row in target_df.iterrows():
- if (reset_flag == False):
- percent_complete = round(((index+1) / len(target_df)) * 100, 2)
- load_text = f"{generate_loading_text(percent_complete)}"
-
- question = row['question']
-                    display_question = f"""Natural Language:
-                    {question}
-                    """
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
- #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
- model_to_send = None if not flag_TQA else model
-
- db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
- db_id = input_data["db_name"],
- base_path = input_data["data_path"],
- normalize=False,
- sql=row["query"],
- get_insert_into=True,
- model = model_to_send,
- prompt = input_data["prompt"].format(question=question, db_schema=""),
- )
-
- #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
- prompt_to_send = input_data["prompt"]
- #PREDICTION SQL
-
- # TODO add button for QA or SP and pass to .make_prediction parameter TASK
-                    task = "QA" if flag_TQA else "SP"
- start_time = time.time()
- response = predictor.make_prediction(
- question=question,
- db_schema=db_schema_text,
- model_name=model,
- prompt=f"{prompt_to_send}",
- task=task
- )
- #if flag_TQA: response = {'response_parsed': "[['Alice'],['Bob'],['Charlie']]", 'cost': 0, 'response': "[['Alice'],['Bob'],['Charlie']]"} # TODO remove this line
- #else : response = {'response_parsed': "SELECT * FROM 'MyTable'", 'cost': 0, 'response': "SQL_QUERY"}
- end_time = time.time()
- prediction = response['response_parsed']
- price = response['cost']
- answer = response['response']
-
- if flag_TQA:
- task_string = "Answer"
- else:
- task_string = "SQL"
-
-                    display_prediction = f"""Predicted {task_string}:
-                    ➡️ {prediction}
-                    """
- # Create a new row as dataframe
- new_row = pd.DataFrame([{
- 'id': index,
- 'question': question,
- 'predicted_sql': prediction,
- 'time': end_time - start_time,
- 'query': row["query"],
- 'db_path': input_data["data_path"],
- 'price':price,
- 'answer': answer,
- 'number_question':count,
- 'target_answer' : row["target_answer"] if flag_TQA else None,
-
- }]).dropna(how="all") # Remove only completely empty rows
- count=count+1
- # TODO: use a for loop
- if (flag_TQA) :
- new_row['predicted_answer'] = prediction
- for col in target_df.columns:
- if col not in new_row.columns:
- new_row[col] = row[col]
- # Update model's prediction dataframe incrementally
- if not new_row.empty:
- predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
-
- # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
- yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
- # END
- eval_text = generate_eval_text("Evaluation")
- yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
-
- evaluator = OrchestratorEvaluator()
-
- for model in input_data["models"]:
- if not flag_TQA:
- metrics_df_model = evaluator.evaluate_df(
- df=predictions_dict[model],
- target_col_name="query",
- prediction_col_name="predicted_sql",
- db_path_name="db_path"
- )
- else:
- metrics_df_model = us.evaluate_answer(predictions_dict[model])
- metrics_df_model['model'] = model
- metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
-
- if 'VES' not in metrics_conc.columns and 'valid_efficency_score' not in metrics_conc.columns:
- metrics_conc['VES'] = 0
- metrics_conc['valid_efficency_score'] = 0
-
- if 'valid_efficency_score' not in metrics_conc.columns:
- metrics_conc['valid_efficency_score'] = metrics_conc['VES']
-
- if 'VES' not in metrics_conc.columns:
- metrics_conc['VES'] = metrics_conc['valid_efficency_score']
-
- eval_text = generate_eval_text("End evaluation")
- yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
-
- # Loading Bar
- with gr.Row():
- # progress = gr.Progress()
- variable = gr.Markdown()
-
- # NL -> MODEL -> Generated Query
- with gr.Row():
- with gr.Column():
- with gr.Column():
- question_display = gr.Markdown()
- with gr.Column():
- model_logo = gr.Image(visible=True,
- show_label=False,
- container=False,
- interactive=False,
- show_fullscreen_button=False,
- show_download_button=False,
- show_share_button=False)
- with gr.Column():
- with gr.Column():
- prediction_display = gr.Markdown()
-
- dataframe_per_model = {}
-
- with gr.Tabs() as model_tabs:
- tab_dict = {}
-
- for model, model_name in zip(model_list, model_names):
- with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
- gr.Markdown(f"**Results for {model}**")
- tab_dict[model] = tab
- dataframe_per_model[model] = gr.DataFrame()
- #TODO download metrics per model
- # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
-
- evaluation_loading = gr.Markdown()
-
- def change_tab():
- return [gr.update(visible=(model in input_data["models"])) for model in model_list]
-
- submit_models_button.click(
- change_tab,
- inputs=[],
- outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
- )
-
- selected_models_display = gr.JSON(label="Final input data", visible=False)
- metrics_df = gr.DataFrame(visible=False)
- metrics_df_out = gr.DataFrame(visible=False)
-
- submit_models_button.click(
- fn=qatch_flow_nl_sql,
- inputs=[],
- outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
- )
-
- submit_models_button.click(
- fn=lambda: gr.update(value=input_data),
- outputs=[selected_models_display]
- )
-
- # Works for METRICS
- metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
-
- proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
- proceed_to_metrics_button.click(
- fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
- outputs=[qatch_acc, metrics_acc]
- )
-
- def allow_download(metrics_df_out):
- #path = os.path.join(".", "data", "data_results", "results.csv")
- path = os.path.join(".", "results.csv")
- metrics_df_out.to_csv(path, index=False)
- return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True)
-
- download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
-
- submit_models_button.click(
- fn=lambda: gr.update(visible=False),
- outputs=[download_metrics]
- )
-
- def refresh():
- global reset_flag
- global flag_TQA
- reset_flag = True
- flag_TQA = False
-
- reset_data = gr.Button("Back to upload data section", interactive=True)
-
- metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
-
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
- #WHY NOT WORKING?
- reset_data.click(
- fn=lambda: gr.update(visible=False),
- outputs=[download_metrics]
- )
- reset_data.click(refresh)
-
- reset_data.click(
- fn=enable_disable,
- inputs=[gr.State(True)],
- outputs=[
- *model_checkboxes,
- submit_models_button,
- preview_output,
- submit_button,
- file_input,
- default_checkbox,
- table_selector,
- *table_outputs,
- open_model_selection
- ]
- )
-
- ##########################################
- # METRICS VISUALIZATION SECTION #
- ##########################################
- with metrics_acc:
- #data_path = 'test_results_metrics1.csv'
-
- @gr.render(inputs=metrics_df_out)
- def function_metrics(metrics_df_out):
-
- ####################################
- # UTILS FUNCTIONS SECTION #
- ####################################
-
- def load_data_csv_es():
-
- if input_data["input_method"]=="default":
- global flag_TQA
- #df = pd.read_csv(pnp_path)
- df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH)
- df = df[df['model'].isin(input_data["models"])]
- df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
-
- df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B')
- df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5')
- df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini')
- df['model'] = df['model'].replace('llama-70', 'Llama-70B')
- df['model'] = df['model'].replace('llama-8', 'Llama-8B')
- df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
- #if (flag_TQA) : flag_TQA = False #TODO delete after make pred
- return df
- return metrics_df_out
-
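-            # Adds an 'avg_metric' column with the row-wise mean of the selected metric columns
-            # ('tuple_order' is excluded because it may contain NULL values).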
- def calculate_average_metrics(df, selected_metrics):
- # Exclude the 'tuple_order' column from the selected metrics
-
- #TODO tuple_order has NULL VALUE
- selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order']
- #print(df[selected_metrics])
- df['avg_metric'] = df[selected_metrics].mean(axis=1)
- return df
-
- def generate_model_colors():
- """Generates a unique color map for models in the dataset."""
- df = load_data_csv_es()
- unique_models = df['model'].unique() # Extract unique models
- num_models = len(unique_models)
-
- # Use the Plotly color scale (you can change it if needed)
- color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
- #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
-
- # If there are more models than colors, cycle through them
- colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
-
- return colors
-
- MODEL_COLORS = generate_model_colors()
-
- def generate_db_category_colors():
- """Assigns 3 distinct colors to db_category groups."""
- return {
-                    "Spider": "#1f77b4",  # blue
-                    "Beaver": "#ff7f0e",  # orange
-                    "Economic": "#2ca02c",  # all the remaining categories share green
-                    "Financial": "#2ca02c",
-                    "Medical": "#2ca02c",
-                    "Miscellaneous": "#2ca02c"
- }
-
- DB_CATEGORY_COLORS = generate_db_category_colors()
-
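-            # Min-max normalisation of the valid efficiency score so it lies in [0, 1] like the other
-            # metrics; missing values are treated as 0 before scaling.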
- def normalize_valid_efficency_score(df):
- df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
- df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
- min_val = df['valid_efficency_score'].min()
- max_val = df['valid_efficency_score'].max()
-
-                if min_val == max_val:
-                    # All values are equal: set the score to 1.0 (or 0 if missing) to avoid division by zero
-                    if pd.isna(min_val):
-                        df['valid_efficency_score'] = 0
-                    else:
-                        df['valid_efficency_score'] = 1.0
- else:
- df['valid_efficency_score'] = (
- df['valid_efficency_score'] - min_val
- ) / (max_val - min_val)
-
- return df
-
-
- ####################################
- # GRAPH FUNCTIONS SECTION #
- ####################################
-
- # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
- def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
- df = df[df['model'].isin(selected_models)]
- df = normalize_valid_efficency_score(df)
-
-                # Map human-readable metric names to internal column names
- qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
- external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
-
- selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
-
- df = calculate_average_metrics(df, selected_metrics)
-
- if group_by == ["model"]:
- # Bar plot per "model"
- avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
- avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
-
- fig = px.bar(
- avg_metrics,
- x="model",
- y="avg_metric",
- color="model",
- color_discrete_map=MODEL_COLORS,
-                        title='Average metrics per Model 🧠',
- labels={"model": "Model", "avg_metric": "Average Metrics"},
- template='simple_white',
- #template='plotly_dark',
- text='text_label'
- )
- else:
- if group_by != ["tbl_name", "model"]:
- group_by = ["tbl_name", "model"]
-
- avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
- avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
-
- fig = px.bar(
- avg_metrics,
- x=group_by[0],
- y='avg_metric',
- color='model',
- color_discrete_map=MODEL_COLORS,
- barmode='group',
-                        title=f'Average metrics per {group_by[0]} 📊',
- labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
- template='simple_white',
- #template='plotly_dark',
- text='text_label'
- )
-
- fig.update_traces(textposition='outside', textfont_size=10)
-
-                # Apply the Inter font to the whole layout
- fig.update_layout(
- margin=dict(t=80),
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=22,
- #color="white"
- ),
- x=0.5
- ),
- xaxis=dict(
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=18,
- #color="white"
- )
- ),
- tickfont=dict(
- family="Inter, sans-serif",
- #color="white"
- size=16
- )
- ),
- yaxis=dict(
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=18,
- #color="white"
- )
- ),
- tickfont=dict(
- family="Inter, sans-serif",
- #color="white"
- )
- ),
- legend=dict(
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=16,
- #color="white"
- )
- ),
- font=dict(
- family="Inter, sans-serif",
- #color="white"
- )
- )
- )
-
- return gr.Plot(fig, visible=True)
-
- def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models):
- df = load_data_csv_es()
- return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
-
-            # BAR CHART FOR PROPRIETARY DATASETS WITH AVERAGE METRICS, WITH UPDATE FUNCTION
- def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
- if selected_models == "All":
- selected_models = models
- else:
- selected_models = [selected_models]
-
- df = df[df['model'].isin(selected_models)]
- df = normalize_valid_efficency_score(df)
-
-                # Convert human-readable metric names to internal column names
- qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
- external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
-
- selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
-
- df = calculate_average_metrics(df, selected_metrics)
-
- avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
- avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
- fig = px.bar(
- avg_metrics,
- x='db_category',
- y='avg_metric',
- color='model',
- color_discrete_map=MODEL_COLORS,
- barmode='group',
-                    title='Average metrics per database type 📊',
-                    labels={'db_category': 'Database Category', 'avg_metric': 'Average Metrics'},
- template='simple_white',
- text='text_label'
- )
-
- fig.update_traces(textposition='outside', textfont_size=14)
-
-                # Update the layout with the Inter font
- fig.update_layout(
- margin=dict(t=80),
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=24,
- color="black"
- ),
- x=0.5
- ),
- xaxis=dict(
- title=dict(
- text='Database Category',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- color='black'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- color='black',
- size=20
- )
- ),
- yaxis=dict(
- title=dict(
- text='Average Metrics',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- color='black'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- color='black'
- )
- ),
- legend=dict(
- title=dict(
- text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=20,
- color='black'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- color='black',
- size=18
- )
- )
- )
-
- return gr.Plot(fig, visible=True)
-
- def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
- df = load_data_csv_es()
- return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
-
-            # DUMBBELL (LOLLIPOP) CHART: PROPRIETARY VS NON-PROPRIETARY DATASETS, WITH UPDATE FUNCTION
-
- def lollipop_propietary(selected_models):
- df = load_data_csv_es()
-
-                # Keep only the relevant database categories
- target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"]
- df = df[df['db_category'].isin(target_cats)]
- df = df[df['model'].isin(selected_models)]
-
- df = normalize_valid_efficency_score(df)
- df = calculate_average_metrics(df, qatch_metrics)
-
-                # Compute the mean per db_category and model
- avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
-
-                # Split Spider from the other categories
- spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
- other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
-
-                # Average of the other categories for each model
- other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
- other_mean_df["db_category"] = "Others"
-
-                # Rename for clarity and consistency
- spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
- other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
-
-                # Merge the two datasets
- merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
-
-                # Sort by model (or by value, if preferred)
- merged_df = merged_df.sort_values(by="model")
-
- fig = go.Figure()
-
-                # Add horizontal connector lines between Spider and Others
- for _, row in merged_df.iterrows():
- fig.add_trace(go.Scatter(
- x=[row["Spider"], row["Others"]],
- y=[row["model"]] * 2,
- mode='lines',
- line=dict(color='gray', width=2),
- showlegend=False
- ))
-
-                # Marker for Spider
- fig.add_trace(go.Scatter(
- x=merged_df["Spider"],
- y=merged_df["model"],
- mode='markers',
- name='Non-Proprietary (Spider)',
- marker=dict(size=10, color='#C84630')
- ))
-
-                # Marker for Others (mean of the other categories)
- fig.add_trace(go.Scatter(
- x=merged_df["Others"],
- y=merged_df["model"],
- mode='markers',
- name='Proprietary Databases',
- marker=dict(size=10, color='#0077B6')
- ))
-
- fig.update_layout(
- xaxis_title='Average Metrics',
- yaxis_title='Models',
- template='simple_white',
- #template='plotly_dark',
- margin=dict(t=80),
- title=dict(
- font=dict(
- family="Inter, sans-serif",
- size=22,
- color="black"
- ),
- x=0.5,
-                        text='Dumbbell graph: Non-Proprietary (Spider 🕷️) vs Proprietary Databases'
- ),
- legend_title='Type of Databases:',
- height=600,
- xaxis=dict(
- title=dict(
-                            text='Average Metrics',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- color='black'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- color='black'
- )
- ),
- yaxis=dict(
- title=dict(
-                            text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- color='black'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- color='black'
- )
- ),
- legend=dict(
- title=dict(
-                            text='Type of Databases',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- color='black'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- color='black',
- size=14
- )
- )
- )
-
- return gr.Plot(fig, visible=True)
-
-
- # RADAR OR BAR CHART BASED ON CATEGORY COUNT
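-            # With fewer than 3 selected categories a radar chart would degenerate, so a grouped bar chart is
-            # drawn instead; otherwise each model becomes one Scatterpolar trace over the chosen categories.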
- def plot_radar(df, selected_models, selected_metrics, selected_categories):
- if "External" in selected_metrics:
- selected_metrics = ["execution_accuracy", "valid_efficency_score"]
- else:
- selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
-
-                # Filter models and normalize
- df = df[df['model'].isin(selected_models)]
- df = normalize_valid_efficency_score(df)
- df = calculate_average_metrics(df, selected_metrics)
-
- avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
-
- if avg_metrics.empty:
- print("Error: No data available to compute averages.")
- return go.Figure()
-
- categories = selected_categories
-
- if len(categories) < 3:
-                    # 📊 BAR PLOT
- fig = go.Figure()
- for model in selected_models:
- model_data = avg_metrics[avg_metrics['model'] == model]
- values = [
- model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
- if cat in model_data['test_category'].values else 0
- for cat in categories
- ]
- fig.add_trace(go.Bar(
- x=categories,
- y=values,
- name=model,
- marker=dict(color=MODEL_COLORS.get(model, "gray"))
- ))
-
- fig.update_layout(
- barmode='group',
- title=dict(
-                            text='📊 Bar Plot of Metrics per Model (Few Categories)',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- #color='white'
- ),
- x=0.5
- ),
- template='simple_white',
- #template='plotly_dark',
- xaxis=dict(
- title=dict(
- text='Test Category',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- #color='white'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- size=16
- #color='white'
- )
- ),
- yaxis=dict(
- title=dict(
- text='Average Metrics',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- #color='white'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- ),
- legend=dict(
- title=dict(
- text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=16,
- #color='white'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- )
- )
- else:
-                    # RADAR PLOT
- fig = go.Figure()
- for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
- model_data = avg_metrics[avg_metrics['model'] == model]
- values = [
- model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
- if cat in model_data['test_category'].values else 0
- for cat in categories
- ]
- fig.add_trace(go.Scatterpolar(
- r=values,
- theta=categories,
- fill='toself',
- name=model,
- line=dict(color=MODEL_COLORS.get(model, "gray"))
- ))
-
- fig.update_layout(
- polar=dict(
- radialaxis=dict(
- visible=True,
- range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
- tickfont=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- ),
- angularaxis=dict(
- tickfont=dict(
- family='Inter, sans-serif',
- size=16
- #color='white'
- )
- )
- ),
- title=dict(
-                            text='Radar Plot of Metrics per Model (Average per SQL Category)',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- #color='white'
- ),
- x=0.5
- ),
- legend=dict(
- title=dict(
- text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- #color='white'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- size=16
- #color='white'
- )
- ),
- template='simple_white'
- #template='plotly_dark'
- )
-
- return fig
-
- def update_radar(selected_models, selected_metrics, selected_categories):
- df = load_data_csv_es()
- return plot_radar(df, selected_models, selected_metrics, selected_categories)
-
- # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
- def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
- if "External" in selected_metrics:
- selected_metrics = ["execution_accuracy", "valid_efficency_score"]
- else:
- selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
-
- df = df[df['model'].isin(selected_models)]
- df = normalize_valid_efficency_score(df)
- df = calculate_average_metrics(df, selected_metrics)
-
- if isinstance(selected_category, str):
- selected_category = [selected_category]
-
- df = df[df['test_category'].isin(selected_category)]
- avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
-
- if avg_metrics.empty:
- print("Error: No data available to compute averages.")
- return go.Figure()
-
- categories = df['sql_tag'].unique().tolist()
-
- if len(categories) < 3:
-                    # 📊 BAR PLOT
- fig = go.Figure()
- for model in selected_models:
- model_data = avg_metrics[avg_metrics['model'] == model]
- values = [
- model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
- if cat in model_data['sql_tag'].values else 0
- for cat in categories
- ]
- fig.add_trace(go.Bar(
- x=categories,
- y=values,
- name=model,
- marker=dict(color=MODEL_COLORS.get(model, "gray"))
- ))
-
- fig.update_layout(
- barmode='group',
- title=dict(
-                            text='📊 Bar Plot of Metrics per Model (Few Sub-Categories)',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- #color='white'
- ),
- x=0.5
- ),
- template='simple_white',
- #template='plotly_dark',
- xaxis=dict(
- title=dict(
- text='SQL Tag (Sub Category)',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- #color='white'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- ),
- yaxis=dict(
- title=dict(
- text='Average Metrics',
- font=dict(
- family='Inter, sans-serif',
- size=18,
- #color='white'
- )
- ),
- tickfont=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- ),
- legend=dict(
- title=dict(
- text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=16,
- #color='white'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- size=14
- #color='white'
- )
- )
- )
- else:
-                    # RADAR PLOT
- fig = go.Figure()
-
- for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
- model_data = avg_metrics[avg_metrics['model'] == model]
- values = [
- model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
- if cat in model_data['sql_tag'].values else 0
- for cat in categories
- ]
-
- fig.add_trace(go.Scatterpolar(
- r=values,
- theta=categories,
- fill='toself',
- name=model,
- line=dict(color=MODEL_COLORS.get(model, "gray"))
- ))
-
- fig.update_layout(
- polar=dict(
- radialaxis=dict(
- visible=True,
- range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
- tickfont=dict(
- family='Inter, sans-serif',
- #color='white'
- )
- ),
- angularaxis=dict(
- tickfont=dict(
- family='Inter, sans-serif',
- size=16
- #color='white'
- )
- )
- ),
- title=dict(
-                            text='Radar Plot of Metrics per Model (Average per SQL Sub-Category)',
- font=dict(
- family='Inter, sans-serif',
- size=22,
- #color='white'
- ),
- x=0.5
- ),
- legend=dict(
- title=dict(
- text='Models',
- font=dict(
- family='Inter, sans-serif',
- size=16,
- #color='white'
- )
- ),
- font=dict(
- family='Inter, sans-serif',
- size=14,
- #color='white'
- )
- ),
- template='simple_white'
- #template='plotly_dark'
- )
-
- return fig
-
- def update_radar_sub(selected_models, selected_metrics, selected_category):
- df = load_data_csv_es()
- return plot_radar_sub(df, selected_models, selected_metrics, selected_category)
-
- # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
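-            # Ranks every (model, table, category, question) group by its average metric and reports the
-            # three lowest-scoring cases, plus the corresponding raw model answers, as Markdown strings.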
- def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
- global flag_TQA
- if selected_models == "All":
- selected_models = models
- else:
- selected_models = [selected_models]
-
- if selected_categories == "All":
- selected_categories = principal_categories
- else:
- selected_categories = [selected_categories]
-
- df = df[df['model'].isin(selected_models)]
- df = df[df['test_category'].isin(selected_categories)]
-
- if "external" in selected_metrics:
- selected_metrics = ["execution_accuracy", "valid_efficency_score"]
- else:
- selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
-
- df = normalize_valid_efficency_score(df)
- df = calculate_average_metrics(df, selected_metrics)
-
- if flag_TQA:
-                    df["target_answer"] = df["target_answer"].apply(lambda x: "[" + ", ".join(map(str, x)) + "]")
-
- worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
- else:
- worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
-
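-        # Ascending sort puts the lowest average-metric (worst) cases first; only the top 3 are reported.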
- worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
-
-        worst_cases_top_3 = worst_cases_df.head(3).copy()  # copy() avoids a pandas SettingWithCopyWarning when rounding below
-
- worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
-
- worst_str = []
- answer_str = []
-
- medals = ["π₯", "π₯", "π₯"]
-
- for i, row in worst_cases_top_3.iterrows():
- if flag_TQA:
- entry = (
- f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
- f"- Question: {row['question']} \n"
- f"- Original Answer: `{row['target_answer']}` \n"
- f"- Predicted Answer: `{eval(row['predicted_answer'])}` \n\n"
- )
-
- worst_str.append(entry)
- else:
- entry = (
- f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
- f"- Question: {row['question']} \n"
- f"- Original Query: `{row['query']}` \n"
- f"- Predicted SQL: `{row['predicted_sql']}` \n\n"
- )
-
- worst_str.append(entry)
-
- raw_answer = (
- f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
-                            f"- Raw Answer: `{row['answer']}` \n"
- )
-
- answer_str.append(raw_answer)
-
- return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
-
- def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
- df = load_data_csv_es()
- return worst_cases_text(df, selected_models, selected_metrics, selected_categories)
-
- # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
- def plot_cumulative_flow(df, selected_models, max_points):
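-            # Builds one line per selected model: cumulative execution time on x, cumulative price on y, with per-question details in the hover tooltip.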
- df = df[df['model'].isin(selected_models)]
- df = normalize_valid_efficency_score(df)
-
- fig = go.Figure()
-
- for model in selected_models:
- model_df = df[df['model'] == model].copy()
-
-            # Limit the number of points if requested
- if max_points is not None:
- model_df = model_df.head(max_points + 1)
-
-            # Custom hover tooltip
- model_df['hover_info'] = model_df.apply(
- lambda row:
-                    f"Id question: {row['number_question']}<br>"
-                    f"Question: {row['question']}<br>"
-                    f"Target: {row['query']}<br>"
-                    f"Prediction: {row['predicted_sql']}<br>"
- f"Category: {row['test_category']}",
- axis=1
- )
-
-            # Cumulative sums of time and price
- model_df['cumulative_time'] = model_df['time'].cumsum()
- model_df['cumulative_price'] = model_df['price'].cumsum()
-
-            # Model color
- color = MODEL_COLORS.get(model, "gray")
-
- fig.add_trace(go.Scatter(
- x=model_df['cumulative_time'],
- y=model_df['cumulative_price'],
- mode='lines+markers',
- name=model,
- line=dict(width=2, color=color),
- customdata=model_df['hover_info'],
- hovertemplate=
-                    "Model: " + model + "<br>" +
-                    "Cumulative Time: %{x}s<br>" +
-                    "Cumulative Price: $%{y:.2f}<br>" +
-                    "<br>Details:<br>%{customdata}"
- ))
-
-        # Layout with a clean font
- fig.update_layout(
- title=dict(
- text="Cumulative Price Flow Chart π°",
- font=dict(
- family="Inter, sans-serif",
- size=24,
- #color="white"
- ),
- x=0.5
- ),
- xaxis=dict(
- title=dict(
- text="Cumulative Time (s)",
- font=dict(
- family="Inter, sans-serif",
- size=20,
- #color="white"
- )
- ),
- tickfont=dict(
- family="Inter, sans-serif",
- size=18
- #color="white"
- )
- ),
- yaxis=dict(
- title=dict(
- text="Cumulative Price ($)",
- font=dict(
- family="Inter, sans-serif",
- size=20,
- #color="white"
- )
- ),
- tickfont=dict(
- family="Inter, sans-serif",
- size=18
- #color="white"
- )
- ),
- legend=dict(
- title=dict(
- text="Models",
- font=dict(
- family="Inter, sans-serif",
- size=18,
- #color="white"
- )
- ),
- font=dict(
- family="Inter, sans-serif",
- size=16,
- #color="white"
- )
- ),
- template='simple_white',
- #template="plotly_dark"
- )
-
- return fig
-
- def update_query_rate(selected_models, max_points):
- df = load_data_csv_es()
- return plot_cumulative_flow(df, selected_models, max_points)
-
-
-
-
- #######################
- # PARAMETER SECTION #
- #######################
- qatch_metrics_dict = {
- "Cell Precision": "cell_precision",
- "Cell Recall": "cell_recall",
- "Tuple Order": "tuple_order",
- "Tuple Cardinality": "tuple_cardinality",
- "Tuple Constraint": "tuple_constraint"
- }
-
- qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
-            last_valid_qatch_metrics_selection = qatch_metrics.copy()  # Store the last valid selection
-            def enforce_qatch_metrics_selection(selected):
-                global last_valid_qatch_metrics_selection
-                if not selected:  # If no metric is selected, restore the last valid selection
-                    return gr.update(value=last_valid_qatch_metrics_selection)
-                last_valid_qatch_metrics_selection = selected  # Otherwise update the valid selection
- return gr.update(value=selected)
-
- external_metrics_dict = {
- "Execution Accuracy": "execution_accuracy",
- "Valid Efficency Score": "valid_efficency_score"
- }
-
- external_metric = ["execution_accuracy", "valid_efficency_score"]
- last_valid_external_metric_selection = external_metric.copy()
- def enforce_external_metric_selection(selected):
- global last_valid_external_metric_selection
-                if not selected:  # If no metric is selected, restore the last valid selection
-                    return gr.update(value=last_valid_external_metric_selection)
-                last_valid_external_metric_selection = selected  # Otherwise update the valid selection
- return gr.update(value=selected)
-
- all_metrics = {
- "Qatch": ["qatch"],
- "External": ["external"]
- }
-
- group_options = {
- "Table": ["tbl_name", "model"],
- "Model": ["model"]
- }
-
- df_initial = load_data_csv_es()
-            models = df_initial['model'].unique().tolist()
-            last_valid_model_selection = models.copy()  # Store the last valid selection
-            def enforce_model_selection(selected):
-                global last_valid_model_selection
-                if not selected:  # If no model is selected, restore the last valid selection
-                    return gr.update(value=last_valid_model_selection)
-                last_valid_model_selection = selected  # Otherwise update the valid selection
- return gr.update(value=selected)
-
- all_categories = df_initial['sql_tag'].unique().tolist()
-
- principal_categories = df_initial['test_category'].unique().tolist()
-            last_valid_category_selection = principal_categories.copy()  # Store the last valid selection
-            def enforce_category_selection(selected):
-                global last_valid_category_selection
-                if not selected:  # If no category is selected, restore the last valid selection
-                    return gr.update(value=last_valid_category_selection)
-                last_valid_category_selection = selected  # Otherwise update the valid selection
- return gr.update(value=selected)
-
- all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories}
-
- all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories}
- all_categories_as_dic_ranking["All"] = principal_categories
-
- all_model_as_dic = {cat: [f"{cat}"] for cat in models}
- all_model_as_dic["All"] = models
-
- ###########################
- # VISUALIZATION SECTION #
- ###########################
- gr.Markdown("""# Model Performance Analysis""")
-
- #FOR BAR
- gr.Markdown("""## Section 1: Model - Data""")
-
- with gr.Row():
- with gr.Column(scale=1):
- with gr.Row():
- choose_metrics_bar = gr.Radio(
- choices=list(all_metrics.keys()),
- label="Select the metrics group that you want to use:",
- value="Qatch"
- )
-
- with gr.Row():
- qatch_info = gr.HTML("""
-
- Qatch metric info βΉοΈ
-
- """, visible=True)
-
- external_info = gr.HTML("""
-
- External metric info βΉοΈ
-
- """, visible=False)
-
- qatch_metric_multiselect_bar = gr.CheckboxGroup(
- choices=list(qatch_metrics_dict.keys()),
-                        label="Select one or more Qatch metrics:",
- value=list(qatch_metrics_dict.keys()),
- visible=True
- )
-
- external_metric_select_bar = gr.CheckboxGroup(
- choices=list(external_metrics_dict.keys()),
- label="Select one or more External metrics:",
- visible=False
- )
-
- if(input_data['input_method'] == 'default'):
- model_radio_bar = gr.Radio(
- choices=list(all_model_as_dic.keys()),
- label="Select the model that you want to use:",
- value="All"
- )
- else:
- model_multiselect_bar = gr.CheckboxGroup(
- choices=models,
- label="Select one or more models:",
- value=models,
- interactive=len(models) > 1
- )
-
- group_radio = gr.Radio(
- choices=list(group_options.keys()),
- label="Select the grouping view:",
- value="Table"
- )
-
- def toggle_metric_selector(selected_type):
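-                # Outputs, in order: qatch info HTML, external info HTML, Qatch checkbox group, External checkbox group.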
- if selected_type == "Qatch":
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
- else:
- return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
-
- output_plot = gr.Plot(visible=False)
-
- if(input_data['input_method'] == 'default'):
- with gr.Row():
- lollipop_propietary(models)
-
- #FOR RADAR
- gr.Markdown("""## Section 2: Model - Category""")
- with gr.Row():
- all_metrics_radar = gr.Radio(
- choices=list(all_metrics.keys()),
- label="Select the metrics group that you want to use:",
- value="Qatch"
- )
-
- model_multiselect_radar = gr.CheckboxGroup(
- choices=models,
- label="Select one or more models:",
- value=models,
- interactive=len(models) > 1
- )
-
- with gr.Row():
- with gr.Column(scale=1):
- category_multiselect_radar = gr.CheckboxGroup(
- choices=principal_categories,
- label="Select one or more categories:",
- value=principal_categories
- )
- with gr.Column(scale=1):
- category_radio_radar = gr.Radio(
- choices=list(all_categories_as_dic.keys()),
- label="Select the metrics that you want to use:",
- value=list(all_categories_as_dic.keys())[0]
- )
-
- with gr.Row():
- with gr.Column(scale=1):
- radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
-
- with gr.Column(scale=1):
- radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
-
- #FOR RANKING
- with gr.Row():
- all_metrics_ranking = gr.Radio(
- choices=list(all_metrics.keys()),
- label="Select the metrics group that you want to use:",
- value="Qatch"
- )
- model_choices = list(all_model_as_dic.keys())
-
- if len(model_choices) == 2:
-                model_choices = [model_choices[0]] # assume the single model is in the first position
- selected_value = model_choices[0]
- else:
- selected_value = "All"
-
- model_radio_ranking = gr.Radio(
- choices=model_choices,
- label="Select the model that you want to use:",
- value=selected_value
- )
-
- category_radio_ranking = gr.Radio(
- choices=list(all_categories_as_dic_ranking.keys()),
- label="Select the category that you want to use",
- value="All"
- )
-
- with gr.Row():
- with gr.Column(scale=1):
- gr.Markdown("## β 3 Worst Cases\n")
-
- worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
-
- with gr.Row():
- first = gr.Markdown(worst_first)
-
- with gr.Row():
- first_button = gr.Button("Show raw answer for π₯")
-
- with gr.Row():
- second = gr.Markdown(worst_second)
-
- with gr.Row():
- second_button = gr.Button("Show raw answer for π₯")
-
- with gr.Row():
- third = gr.Markdown(worst_third)
-
- with gr.Row():
- third_button = gr.Button("Show raw answer for π₯")
-
- with gr.Column(scale=1):
- gr.Markdown("""## Raw Answer""")
- row_answer_first = gr.Markdown(value=raw_first, visible=True)
- row_answer_second = gr.Markdown(value=raw_second, visible=False)
- row_answer_third = gr.Markdown(value=raw_third, visible=False)
-
- #FOR RATE
- gr.Markdown("""## Section 3: Time - Price""")
- with gr.Row():
- model_multiselect_rate = gr.CheckboxGroup(
- choices=models,
- label="Select one or more models:",
- value=models,
- interactive=len(models) > 1
- )
-
-
- with gr.Row():
- slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider")
-
- query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
-
-
- #FOR RESET
- reset_data = gr.Button("Back to upload data section")
-
-
-
-
- ###############################
- # CALLBACK FUNCTION SECTION #
- ###############################
-
- #FOR BAR
- def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
- return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
-
- def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
- return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
-
- #FOR RADAR
- def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
- return update_radar(selected_models, selected_metrics, selected_categories)
-
- def on_radar_radio_change(selected_models, selected_metrics, selected_category):
- return update_radar_sub(selected_models, selected_metrics, selected_category)
-
- #FOR RANKING
- def on_ranking_change(selected_models, selected_metrics, selected_categories):
- return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
-
- def show_first():
- return (
- gr.update(visible=True),
- gr.update(visible=False),
- gr.update(visible=False)
- )
-
- def show_second():
- return (
- gr.update(visible=False),
- gr.update(visible=True),
- gr.update(visible=False)
- )
-
- def show_third():
- return (
- gr.update(visible=False),
- gr.update(visible=False),
- gr.update(visible=True)
- )
-
-
-
-
- ######################
- # ON CLICK SECTION #
- ######################
-
- #FOR BAR
- if(input_data['input_method'] == 'default'):
- proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
- qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
- external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
- model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
- qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
- choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
- external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
-
- else:
- proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
- qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
- external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
- group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
- model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
- qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
- model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
- choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
- external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
-
-
- #FOR RADAR MULTISELECT
- model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
- all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
- category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
- model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
- category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
-
- #FOR RADAR RADIO
- model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
- all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
- category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
-
- #FOR RANKING
- model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
- model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
- all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
- all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
- category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
- category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
- model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking)
- category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking)
- first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
- second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
- third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
-
- #FOR RATE
- model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
- proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
- model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
- slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
-
- #FOR RESET
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
- reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
- reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
-
-
+import os
+import sys
+import time
+import re
+import csv
+import gradio as gr
+import pandas as pd
+import numpy as np
+import plotly.express as px
+import plotly.graph_objects as go
+import plotly.colors as pc
+from qatch.connectors.sqlite_connector import SqliteConnector
+from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
+from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
+import qatch.evaluate_dataset.orchestrator_evaluator as eva
+from prediction import ModelPrediction
+import utils_get_db_tables_info
+import utilities as us
+# @spaces.GPU
+# def model_prediction():
+# pass
+# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
+# if os.environ.get("SPACES_ZERO_GPU") is not None:
+# import spaces
+# else:
+# class spaces:
+# @staticmethod
+# def GPU(func):
+# def wrapper(*args, **kwargs):
+# return func(*args, **kwargs)
+# return wrapper
+#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
+pnp_path = "concatenated_output.csv"
+PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
+PNP_TQA_PATH = 'concatenated_output_tqa.csv'
+js_func = """
+function refresh() {
+ const url = new URL(window.location);
+
+ if (url.searchParams.get('__theme') !== 'light') {
+ url.searchParams.set('__theme', 'light');
+ window.location.href = url.href;
+ }
+}
+"""
+reset_flag = False
+flag_TQA = False
+
+with open('style.css', 'r') as file:
+ css = file.read()
+
+# DataFrame di default
+df_default = pd.DataFrame({
+ 'Name': ['Alice', 'Bob', 'Charlie'],
+ 'Age': [25, 30, 35],
+ 'City': ['New York', 'Los Angeles', 'Chicago']
+})
+models_path ="models.csv"
+
+# Variabile globale per tenere traccia dei dati correnti
+df_current = df_default.copy()
+
+description = """## π Comparison of Proprietary and Non-Proprietary Databases
+ ### β€ **Proprietary** :
+ ### β Economic π°, Medical π₯, Financial π³, Miscellaneous π
+ ### β BEAVER (FAC BUILDING ADDRESS π’ , TIME QUARTER β±οΈ)
+ ### β€ **Non-Proprietary**
+ ### β Spider 1.0 π·οΈ"""
+prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
+prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as .\n Question \n {question}\n Database Schema\n {db_schema}\n"
+
+
+input_data = {
+ 'input_method': "",
+ 'data_path': "",
+ 'db_name': "",
+ 'data': {
+ 'data_frames': {}, # dictionary of dataframes
+ 'db': None, # SQLITE3 database object
+ 'selected_tables' :[]
+ },
+ 'models': [],
+ 'prompt': prompt_default
+}
+
+def load_data(file, path, use_default):
+ """Carica i dati da un file, un percorso o usa il DataFrame di default."""
+ global df_current
+ if file is not None:
+ try:
+ input_data["input_method"] = 'uploaded_file'
+ input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
+ if file.endswith('.sqlite'):
+ #return 'Error: The uploaded file is not a valid SQLite database.'
+ input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
+ else:
+ #change path
+ input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
+ input_data["data"] = us.load_data(file, input_data["db_name"])
+
+            df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
+ if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
+ table2primary_key = {}
+ for table_name, df in input_data["data"]['data_frames'].items():
+ # Assign primary keys for each table
+ table2primary_key[table_name] = 'id'
+ input_data["data"]["db"] = SqliteConnector(
+ relative_db_path=input_data["data_path"],
+ db_name=input_data["db_name"],
+ tables= input_data["data"]['data_frames'],
+ table2primary_key=table2primary_key
+ )
+ return input_data["data"]['data_frames']
+ except Exception as e:
+            return f'Error while loading the file: {e}'
+ if use_default:
+ if(use_default == 'Custom'):
+ input_data["input_method"] = 'custom'
+ #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
+ input_data["data_path"] = os.path.join(".","mytable_0.sqlite")
+ #if file already exist
+ while os.path.exists(input_data["data_path"]):
+ input_data["data_path"] = us.increment_filename(input_data["data_path"])
+ input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
+ input_data["data"]['data_frames'] = {'MyTable': df_current}
+
+ if(input_data["data"]['data_frames']):
+ table2primary_key = {}
+ for table_name, df in input_data["data"]['data_frames'].items():
+ # Assign primary keys for each table
+ table2primary_key[table_name] = 'id'
+ input_data["data"]["db"] = SqliteConnector(
+ relative_db_path=input_data["data_path"],
+ db_name=input_data["db_name"],
+ tables= input_data["data"]['data_frames'],
+ table2primary_key=table2primary_key
+ )
+            df_current = df_default.copy() # Restore the default data
+ return input_data["data"]['data_frames']
+
+ if(use_default == 'Proprietary vs Non-proprietary'):
+ input_data["input_method"] = 'default'
+ #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
+ #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
+ #input_data["db_name"] = "default"
+ #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
+ input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES)
+ return input_data["data"]['data_frames']
+
+ selected_inputs = sum([file is not None, bool(path), use_default])
+ if selected_inputs > 1:
+ return 'Error: Select only one input method at a time.'
+
+ return input_data["data"]['data_frames']
+
+def preview_default(use_default, file):
+ if file:
+        return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## β File successfully uploaded!", visible=True)
+ else :
+ if use_default == 'Custom':
+ return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## π Toy Table", visible=True)
+ else:
+ return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True)
+    #return gr.DataFrame(interactive=True, value = df_current) # Show the current DataFrame, which may have been modified
+
+def update_df(new_df):
+    """Update the current DataFrame."""
+    global df_current # Use the global variable so the update persists
+ df_current = new_df
+ return df_current
+
+def open_accordion(target):
+    # Open one accordion and close the others
+ if target == "reset":
+ df_current = df_default.copy()
+ input_data['input_method'] = ""
+ input_data['data_path'] = ""
+ input_data['db_name'] = ""
+ input_data['data']['data_frames'] = {}
+ input_data['data']['selected_tables'] = []
+ input_data['data']['db'] = None
+ input_data['models'] = []
+ return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
+ elif target == "model_selection":
+ return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
+
+# Gradio interface
+#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
+with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Image(
+ value=os.path.join(".", "qatch_logo.png"),
+ show_label=False,
+ container=False,
+ interactive=False,
+ show_fullscreen_button=False,
+ show_download_button=False,
+ show_share_button=False,
+ height=150, # in pixel
+ width=300
+ )
+ with gr.Column(scale=1):
+ pass
+    data_state = gr.State(None) # Stores the loaded data
+ upload_acc = gr.Accordion("Upload data section", open=True, visible=True)
+ select_table_acc = gr.Accordion("Select tables section", open=False, visible=False)
+ select_model_acc = gr.Accordion("Select models section", open=False, visible=False)
+ qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False)
+ metrics_acc = gr.Accordion("Metrics section", open=False, visible=False)
+
+ #################################
+ # DATABASE INSERTION #
+ #################################
+ with upload_acc:
+ gr.Markdown("## π₯Choose data input method")
+ with gr.Row():
+ default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
+ #default_checkbox = gr.Checkbox(label="Use default DataFrame"
+
+ table_default = gr.Markdown(description, visible=True)
+ preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
+
+ gr.Markdown("## π Or upload your data")
+ file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
+ submit_button = gr.Button("Load Data") # Disabled by default
+ output = gr.JSON(visible=False) # Dictionary output
+
+ # Function to enable the button if there is data to load
+ def enable_submit(file, use_default):
+ return gr.update(interactive=bool(file or use_default))
+
+ # Function to uncheck the checkbox if a file is uploaded
+ def deselect_default(file):
+ if file:
+ return gr.update(value='Proprietary vs Non-proprietary')
+ return gr.update()
+
+ def enable_disable_first(enable):
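+        # Toggles interactivity of the upload widgets (preview table, submit button, file input, default radio).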
+ return (
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable)
+ )
+
+ # Enable the button when inputs are provided
+ #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
+ #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
+
+ # Show preview of the default DataFrame when checkbox is selected
+ default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
+ file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
+ preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
+
+ # Uncheck the checkbox when a file is uploaded
+ file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
+
+ def handle_output(file, use_default):
+ """Handles the output when the 'Load Data' button is pressed."""
+ result = load_data(file, None, use_default)
+
+ if isinstance(result, dict): # If result is a dictionary of DataFrames
+ if len(result) == 1: # If there's only one table
+ input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
+ return (
+ gr.update(visible=False), # Hide JSON output
+ result, # Save the data state
+ gr.update(visible=False), # Hide table selection
+ result, # Maintain the data state
+ gr.update(interactive=False), # Disable the submit button
+ gr.update(visible=True, open=True), # Proceed to select_model_acc
+ gr.update(visible=True, open=False)
+ )
+ else:
+ return (
+ gr.update(visible=False),
+ result,
+ gr.update(open=True, visible=True),
+ result,
+ gr.update(interactive=False),
+ gr.update(visible=False), # Keep current behavior
+ gr.update(visible=True, open=False)
+ )
+ else:
+ return (
+ gr.update(visible=False),
+ None,
+ gr.update(open=False, visible=True),
+ None,
+ gr.update(interactive=True),
+ gr.update(visible=False),
+ gr.update(visible=True, open=False)
+ )
+
+ submit_button.click(
+ fn=handle_output,
+ inputs=[file_input, default_checkbox],
+ outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
+ )
+
+ submit_button.click(
+ fn=enable_disable_first,
+ inputs=[gr.State(False)],
+ outputs=[
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox
+ ]
+ )
+
+ ######################################
+ # TABLE SELECTION PART #
+ ######################################
+ with select_table_acc:
+ previous_selection = gr.State([])
+        table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the chosen database", value=[])
+ excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
+ table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
+ selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
+
+ # Model selection button (initially disabled)
+ open_model_selection = gr.Button("Choose your models", interactive=False)
+ def update_table_list(data):
+ """Dynamically updates the list of available tables and excluded ones."""
+ if isinstance(data, dict) and data:
+ table_names = []
+ excluded_tables = []
+
+ data_frames = input_data['data'].get('data_frames', {})
+
+ available_tables = []
+ for name, df in data.items():
+ df_real = data_frames.get(name, None)
+ if input_data['input_method'] != "default":
+ if df_real is not None and df_real.shape[1] > 15:
+ excluded_tables.append(name)
+ else:
+ available_tables.append(name)
+ else:
+ available_tables.append(name)
+
+
+ if input_data['input_method'] == "default":
+ table_names.append("All")
+ excluded_tables = []
+ elif len(available_tables) < 6:
+ table_names.append("All")
+
+ table_names.extend(available_tables)
+ if excluded_tables and input_data['input_method'] != "default" :
+                excluded_text = "β οΈ The following tables have more than 15 columns and cannot be selected:<br>" + "<br>".join(f"- {t}" for t in excluded_tables)
+ excluded_visible = True
+ else:
+ excluded_text = ""
+ excluded_visible = False
+
+ return [
+ gr.update(choices=table_names, value=[]), # CheckboxGroup update
+ gr.update(value=excluded_text, visible=excluded_visible) # HTML display update
+ ]
+
+ return [
+ gr.update(choices=[], value=[]),
+ gr.update(value="", visible=False)
+ ]
+
+ def show_selected_tables(data, selected_tables):
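+            # Rebuild the table previews: for non-default inputs only tables with at most 15 columns are selectable, and at most 5 tables can be chosen unless "All" is available.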
+ updates = []
+ data_frames = input_data['data'].get('data_frames', {})
+
+ available_tables = []
+ for name, df in data.items():
+ df_real = data_frames.get(name)
+ if input_data['input_method'] != "default" :
+ if df_real is not None and df_real.shape[1] <= 15:
+ available_tables.append(name)
+ else:
+ available_tables.append(name)
+
+ input_method = input_data['input_method']
+ allow_all = input_method == "default" or len(available_tables) < 6
+
+ selected_set = set(selected_tables)
+ tables_set = set(available_tables)
+
+ if allow_all:
+ if "All" in selected_set:
+ selected_tables = ["All"] + available_tables
+ elif selected_set == tables_set:
+ selected_tables = []
+ else:
+ selected_tables = [t for t in selected_tables if t in available_tables]
+ else:
+ selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
+
+ tables = {name: data[name] for name in selected_tables if name in data}
+
+ for i, (name, df) in enumerate(tables.items()):
+ updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
+
+ for _ in range(len(tables), 50):
+ updates.append(gr.update(visible=False))
+
+ updates.append(gr.update(interactive=bool(tables)))
+
+ if allow_all:
+ updates.insert(0, gr.update(
+ choices=["All"] + available_tables,
+ value=selected_tables
+ ))
+ else:
+ if len(selected_tables) >= 5:
+ updates.insert(0, gr.update(
+ choices=selected_tables,
+ value=selected_tables
+ ))
+ else:
+ updates.insert(0, gr.update(
+ choices=available_tables,
+ value=selected_tables
+ ))
+
+ return updates
+
+ def show_selected_table_names(data, selected_tables):
+ """Displays the names of the selected tables when the button is pressed."""
+ if selected_tables:
+ available_tables = list(data.keys()) # Actually available names
+ if "All" in selected_tables:
+ selected_tables = available_tables
+ if (input_data['input_method'] != "default") : selected_tables = [t for t in selected_tables if len(data[t].columns) <= 15]
+
+ input_data['data']['selected_tables'] = selected_tables
+ return gr.update(value=", ".join(selected_tables), visible=False)
+ return gr.update(value="", visible=False)
+
+ # Automatically updates the checkbox list when `data_state` changes
+ data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])
+
+ # Updates the visible tables and the button state based on user selections
+ #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
+ table_selector.change(
+ fn=show_selected_tables,
+ inputs=[data_state, table_selector],
+ outputs=[table_selector] + table_outputs + [open_model_selection]
+ )
+ # Shows the list of selected tables when "Choose your models" is clicked
+ open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
+ open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
+
+ reset_data = gr.Button("Back to upload data section")
+
+ reset_data.click(
+ fn=enable_disable_first,
+ inputs=[gr.State(True)],
+ outputs=[
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox
+ ]
+ )
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+
+ ####################################
+ # MODEL SELECTION PART #
+ ####################################
+ with select_model_acc:
+ gr.Markdown("# Model Selection")
+
+ # Assume that `us.read_models_csv` also returns the image path
+ model_list_dict = us.read_models_csv(models_path)
+ model_list = [model["code"] for model in model_list_dict]
+ model_images = [model["image_path"] for model in model_list_dict]
+ model_names = [model["name"] for model in model_list_dict]
+ # Create a mapping between model_list and model_images_names
+ model_mapping = dict(zip(model_list, model_names))
+ model_mapping_reverse = dict(zip(model_names, model_list))
+
+ model_checkboxes = []
+ rows = []
+
+ # Dynamically create checkboxes with images (3 per row)
+ for i in range(0, len(model_list), 3):
+ with gr.Row():
+ cols = []
+ for j in range(3):
+ if i + j < len(model_list):
+ model = model_list[i + j]
+ image_path = model_images[i + j]
+ with gr.Column():
+ gr.Image(image_path,
+ show_label=False,
+ container=False,
+ interactive=False,
+ show_fullscreen_button=False,
+ show_download_button=False,
+ show_share_button=False)
+ checkbox = gr.Checkbox(label=model_mapping[model], value=False)
+ model_checkboxes.append(checkbox)
+ cols.append(checkbox)
+ rows.append(cols)
+
+ selected_models_output = gr.JSON(visible=False)
+
+ # Function to get selected models
+ def get_selected_models(*model_selections):
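+            # The submit button is enabled only when at least one model is selected and the prompt contains both {db_schema} and {question}.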
+ selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
+ input_data['models'] = selected_models
+ button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
+ return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
+
+ # Add the Textbox to the interface
+ with gr.Row():
+ button_prompt_nlsql = gr.Button("Choose NL2SQL task")
+ button_prompt_tqa = gr.Button("Choose TQA task")
+
+ prompt = gr.TextArea(
+ label="Customise the prompt for selected models here or leave the default one.",
+ placeholder=prompt_default,
+ elem_id="custom-textarea"
+ )
+
+ warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
+
+ # Submit button (initially disabled)
+ with gr.Row():
+ submit_models_button = gr.Button("Submit Models", interactive=False)
+
+ def check_prompt(prompt):
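+        # An empty prompt falls back to the task default (NL2SQL or TQA); otherwise both {db_schema} and {question} placeholders must be present.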
+ #TODO
+ missing_elements = []
+ if(prompt==""):
+ global flag_TQA
+ if not flag_TQA:
+ input_data["prompt"] = prompt_default
+ else:
+ input_data["prompt"] = prompt_default_tqa
+ button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
+ else:
+ input_data["prompt"] = prompt
+ if "{db_schema}" not in prompt:
+ missing_elements.append("{db_schema}")
+ if "{question}" not in prompt:
+ missing_elements.append("{question}")
+ button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
+ if missing_elements:
+ return gr.update(
+                    value=f"β Missing {', '.join(missing_elements)} in the prompt β",
+ visible=True
+ ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
+ return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
+
+ prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
+ # Link checkboxes to selection events
+ for checkbox in model_checkboxes:
+ checkbox.change(
+ fn=get_selected_models,
+ inputs=model_checkboxes,
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
+ )
+ prompt.change(
+ fn=get_selected_models,
+ inputs=model_checkboxes,
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
+ )
+
+ submit_models_button.click(
+ fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
+ inputs=model_checkboxes,
+ outputs=[selected_models_output, select_model_acc, qatch_acc]
+ )
+
+ def change_flag():
+ global flag_TQA
+ flag_TQA = True
+
+ def dis_flag():
+ global flag_TQA
+ flag_TQA = False
+
+ button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[])
+
+ button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[])
+
+ button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
+
+ button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
+
+
+ def enable_disable(enable):
+ return (
+ *[gr.update(interactive=enable) for _ in model_checkboxes],
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ *[gr.update(interactive=enable) for _ in table_outputs],
+ gr.update(interactive=enable)
+ )
+
+ reset_data = gr.Button("Back to upload data section")
+
+ submit_models_button.click(
+ fn=enable_disable,
+ inputs=[gr.State(False)],
+ outputs=[
+ *model_checkboxes,
+ submit_models_button,
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox,
+ table_selector,
+ *table_outputs,
+ open_model_selection
+ ]
+ )
+
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+
+ reset_data.click(
+ fn=enable_disable,
+ inputs=[gr.State(True)],
+ outputs=[
+ *model_checkboxes,
+ submit_models_button,
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox,
+ table_selector,
+ *table_outputs,
+ open_model_selection
+ ]
+ )
+
+ #############################
+ # QATCH EXECUTION #
+ #############################
+ with qatch_acc:
+ def change_text(text):
+ return text
+
+ loading_symbols= {1:"π",
+ 2: "π π",
+ 3: "π π π",
+ 4: "π π π π",
+ 5: "π π π π π",
+ 6: "π π π π π π",
+ 7: "π π π π π π π",
+ 8: "π π π π π π π π",
+ 9: "π π π π π π π π π",
+ 10:"π π π π π π π π π π",
+ }
+
+ def generate_loading_text(percent):
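+            # The number of symbols cycles with the rounded percentage (modulo 11), giving a simple animated progress indicator.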
+ num_symbols = (round(percent) % 11) + 1
+ symbols = loading_symbols.get(num_symbols, "π")
+ mirrored_symbols = f'{symbols.strip()}'
+ css_symbols = f'{symbols.strip()}'
+ return f"""
+
+ {css_symbols}
+
+ Generation {percent}%
+
+ {mirrored_symbols}
+
+ """
+
+ def generate_eval_text(text):
+ symbols = "π‘ "
+ mirrored_symbols = f'{symbols.strip()}'
+ css_symbols = f'{symbols.strip()}'
+ return f"""
+
+ {css_symbols}
+
+ {text}
+
+ {mirrored_symbols}
+
+ """
+
+ def qatch_flow_nl_sql():
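+            # Generator used for Gradio streaming: each yield pushes updated status text, model logo, question, prediction, metrics and per-model tables to the UI.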
+ global reset_flag
+ global flag_TQA
+ predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
+ metrics_conc = pd.DataFrame()
+ columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
+ if (input_data['input_method']=="default"):
+ #target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
+ target_df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH)
+ #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
+ target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
+ target_df = target_df[target_df["model"].isin(input_data['models'])]
+ predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
+ reset_flag = False
+ for model in input_data['models']:
+ model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
+ yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+ count=1
+ for _, row in predictions_dict[model].iterrows():
+ #for index, row in target_df.iterrows():
+ if (reset_flag == False):
+ percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
+ count=count+1
+ load_text = f"{generate_loading_text(percent_complete)}"
+ question = row['question']
+
+ display_question = f"""Natural Language:
+
+ """
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+
+ prediction = row['predicted_sql']
+
+ display_prediction = f"""Predicted SQL:
+
+
β‘οΈ
+
{prediction}
+
+ """
+
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+ metrics_conc = target_df
+ if 'valid_efficency_score' not in metrics_conc.columns:
+ metrics_conc['valid_efficency_score'] = metrics_conc['VES']
+ if 'VES' not in metrics_conc.columns:
+ metrics_conc['VES'] = metrics_conc['valid_efficency_score']
+ eval_text = generate_eval_text("End evaluation")
+ yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+
+ else:
+ # global flag_TQA
+ orchestrator_generator = OrchestratorGenerator()
+ target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
+
+ #create target_df[target_answer]
+ if flag_TQA :
+ # if (input_data["prompt"] == prompt_default):
+ # input_data["prompt"] = prompt_default_tqa
+
+ target_df['db_schema'] = target_df.apply(
+ lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string(
+ db_id=input_data["db_name"],
+ base_path=input_data["data_path"],
+ normalize=False,
+ sql=row["query"],
+ get_insert_into=True,
+ model=None,
+ prompt=input_data["prompt"].format(question=row["question"], db_schema="")
+ ),
+ axis=1
+ )
+
+ target_df = us.extract_answer(target_df)
+
+ predictor = ModelPrediction()
+ reset_flag = False
+ for model in input_data["models"]:
+ model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
+ yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ count=0
+ for index, row in target_df.iterrows():
+ if (reset_flag == False):
+ percent_complete = round(((index+1) / len(target_df)) * 100, 2)
+ load_text = f"{generate_loading_text(percent_complete)}"
+
+ question = row['question']
+ display_question = f"""Natural Language:
+
+ """
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
+ model_to_send = None if not flag_TQA else model
+
+ db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
+ db_id = input_data["db_name"],
+ base_path = input_data["data_path"],
+ normalize=False,
+ sql=row["query"],
+ get_insert_into=True,
+ model = model_to_send,
+ prompt = input_data["prompt"].format(question=question, db_schema=""),
+ )
+
+ #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
+ prompt_to_send = input_data["prompt"]
+ #PREDICTION SQL
+
+ # TODO add button for QA or SP and pass to .make_prediction parameter TASK
+ if flag_TQA: task="QA"
+ else: task="SP"
+ start_time = time.time()
+ response = predictor.make_prediction(
+ question=question,
+ db_schema=db_schema_text,
+ model_name=model,
+ prompt=f"{prompt_to_send}",
+ task=task
+ )
+ #if flag_TQA: response = {'response_parsed': "[['Alice'],['Bob'],['Charlie']]", 'cost': 0, 'response': "[['Alice'],['Bob'],['Charlie']]"} # TODO remove this line
+ #else : response = {'response_parsed': "SELECT * FROM 'MyTable'", 'cost': 0, 'response': "SQL_QUERY"}
+ end_time = time.time()
+ prediction = response['response_parsed']
+ price = response['cost']
+ answer = response['response']
+
+ if flag_TQA:
+ task_string = "Answer"
+ else:
+ task_string = "SQL"
+
+ display_prediction = f"""Predicted {task_string}:
+
+
β‘οΈ
+
{prediction}
+
+ """
+ # Create a new row as dataframe
+ new_row = pd.DataFrame([{
+ 'id': index,
+ 'question': question,
+ 'predicted_sql': prediction,
+ 'time': end_time - start_time,
+ 'query': row["query"],
+ 'db_path': input_data["data_path"],
+ 'price':price,
+ 'answer': answer,
+ 'number_question':count,
+ 'target_answer' : row["target_answer"] if flag_TQA else None,
+
+ }]).dropna(how="all") # Remove only completely empty rows
+ count=count+1
+ # TODO: use a for loop
+ if (flag_TQA) :
+ new_row['predicted_answer'] = prediction
+ for col in target_df.columns:
+ if col not in new_row.columns:
+ new_row[col] = row[col]
+ # Update model's prediction dataframe incrementally
+ if not new_row.empty:
+ predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
+
+ # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+ yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+ # END
+ eval_text = generate_eval_text("Evaluation")
+ yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+
+ evaluator = OrchestratorEvaluator()
+
+ for model in input_data["models"]:
+ if not flag_TQA:
+ metrics_df_model = evaluator.evaluate_df(
+ df=predictions_dict[model],
+ target_col_name="query",
+ prediction_col_name="predicted_sql",
+ db_path_name="db_path"
+ )
+ else:
+ metrics_df_model = us.evaluate_answer(predictions_dict[model])
+ metrics_df_model['model'] = model
+ metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
+
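+            # Keep the 'VES' and 'valid_efficency_score' columns in sync, whichever name the evaluator produced (both default to 0 if neither exists).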
+ if 'VES' not in metrics_conc.columns and 'valid_efficency_score' not in metrics_conc.columns:
+ metrics_conc['VES'] = 0
+ metrics_conc['valid_efficency_score'] = 0
+
+ if 'valid_efficency_score' not in metrics_conc.columns:
+ metrics_conc['valid_efficency_score'] = metrics_conc['VES']
+
+ if 'VES' not in metrics_conc.columns:
+ metrics_conc['VES'] = metrics_conc['valid_efficency_score']
+
+ eval_text = generate_eval_text("End evaluation")
+ yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+
+ # Loading Bar
+ with gr.Row():
+ # progress = gr.Progress()
+ variable = gr.Markdown()
+
+ # NL -> MODEL -> Generated Query
+ with gr.Row():
+ with gr.Column():
+ with gr.Column():
+ question_display = gr.Markdown()
+ with gr.Column():
+ model_logo = gr.Image(visible=True,
+ show_label=False,
+ container=False,
+ interactive=False,
+ show_fullscreen_button=False,
+ show_download_button=False,
+ show_share_button=False)
+ with gr.Column():
+ with gr.Column():
+ prediction_display = gr.Markdown()
+
+ dataframe_per_model = {}
+
+ with gr.Tabs() as model_tabs:
+ tab_dict = {}
+
+ for model, model_name in zip(model_list, model_names):
+ with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
+ gr.Markdown(f"**Results for {model}**")
+ tab_dict[model] = tab
+ dataframe_per_model[model] = gr.DataFrame()
+ #TODO download metrics per model
+ # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
+
+ evaluation_loading = gr.Markdown()
+
+ def change_tab():
+ return [gr.update(visible=(model in input_data["models"])) for model in model_list]
+
+ submit_models_button.click(
+ change_tab,
+ inputs=[],
+ outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
+ )
+
+ selected_models_display = gr.JSON(label="Final input data", visible=False)
+ metrics_df = gr.DataFrame(visible=False)
+ metrics_df_out = gr.DataFrame(visible=False)
+
+ submit_models_button.click(
+ fn=qatch_flow_nl_sql,
+ inputs=[],
+ outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
+ )
+
+ submit_models_button.click(
+ fn=lambda: gr.update(value=input_data),
+ outputs=[selected_models_display]
+ )
+
+ # Works for METRICS
+ metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
+
+ proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
+ proceed_to_metrics_button.click(
+ fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
+ outputs=[qatch_acc, metrics_acc]
+ )
+
+ def allow_download(metrics_df_out):
+ #path = os.path.join(".", "data", "data_results", "results.csv")
+ path = os.path.join(".", "results.csv")
+ metrics_df_out.to_csv(path, index=False)
+ return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True)
+
+ download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
+
+ submit_models_button.click(
+ fn=lambda: gr.update(visible=False),
+ outputs=[download_metrics]
+ )
+
+ def refresh():
+ global reset_flag
+ global flag_TQA
+ reset_flag = True
+ flag_TQA = False
+
+ reset_data = gr.Button("Back to upload data section", interactive=True)
+
+ metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
+
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+ #WHY NOT WORKING?
+ reset_data.click(
+ fn=lambda: gr.update(visible=False),
+ outputs=[download_metrics]
+ )
+ reset_data.click(refresh)
+
+ reset_data.click(
+ fn=enable_disable,
+ inputs=[gr.State(True)],
+ outputs=[
+ *model_checkboxes,
+ submit_models_button,
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox,
+ table_selector,
+ *table_outputs,
+ open_model_selection
+ ]
+ )
+
+ ##########################################
+ # METRICS VISUALIZATION SECTION #
+ ##########################################
+ with metrics_acc:
+ #data_path = 'test_results_metrics1.csv'
+
+ @gr.render(inputs=metrics_df_out)
+ def function_metrics(metrics_df_out):
+
+ ####################################
+ # UTILS FUNCTIONS SECTION #
+ ####################################
+
+ def load_data_csv_es():
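+                            """Returns the evaluation data: the pre-computed CSV for the default databases, otherwise the metrics of the current run."""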
+
+ if input_data["input_method"]=="default":
+ global flag_TQA
+ #df = pd.read_csv(pnp_path)
+ df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH)
+ df = df[df['model'].isin(input_data["models"])]
+ df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
+
+ df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B')
+ df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5')
+ df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini')
+ df['model'] = df['model'].replace('llama-70', 'Llama-70B')
+ df['model'] = df['model'].replace('llama-8', 'Llama-8B')
+ df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
+ #if (flag_TQA) : flag_TQA = False #TODO delete after make pred
+ return df
+ return metrics_df_out
+
+ def calculate_average_metrics(df, selected_metrics):
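+                            """Adds an 'avg_metric' column with the row-wise mean of the selected metrics."""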
+ # Exclude the 'tuple_order' column from the selected metrics
+
+                            # TODO: tuple_order contains NULL values
+ selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order']
+ #print(df[selected_metrics])
+ df['avg_metric'] = df[selected_metrics].mean(axis=1)
+ return df
+
+ def generate_model_colors():
+ """Generates a unique color map for models in the dataset."""
+ df = load_data_csv_es()
+ unique_models = df['model'].unique() # Extract unique models
+ num_models = len(unique_models)
+
+                            # Custom color palette (a Plotly qualitative scale could be used instead)
+ color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
+ #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
+
+ # If there are more models than colors, cycle through them
+ colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
+
+ return colors
+
+ MODEL_COLORS = generate_model_colors()
+
+ def generate_db_category_colors():
+ """Assigns 3 distinct colors to db_category groups."""
+ return {
+                                "Spider": "#1f77b4",   # blue
+                                "Beaver": "#ff7f0e",   # orange
+                                "Economic": "#2ca02c", # all the others green
+ "Financial": "#2ca02c",
+ "Medical": "#2ca02c",
+ "Miscellaneous": "#2ca02c"
+ }
+
+ DB_CATEGORY_COLORS = generate_db_category_colors()
+
+ def normalize_valid_efficency_score(df):
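+                            """Min-max normalizes the 'valid_efficency_score' column into the [0, 1] range."""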
+ df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
+ df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
+ min_val = df['valid_efficency_score'].min()
+ max_val = df['valid_efficency_score'].max()
+
+                            if min_val == max_val:
+                                # All values are equal: avoid division by zero by assigning a constant score
+                                if min_val is None:
+                                    df['valid_efficency_score'] = 0
+                                else:
+                                    df['valid_efficency_score'] = 1.0
+ else:
+ df['valid_efficency_score'] = (
+ df['valid_efficency_score'] - min_val
+ ) / (max_val - min_val)
+
+ return df
+
+
+ ####################################
+ # GRAPH FUNCTIONS SECTION #
+ ####################################
+
+ # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
+ def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
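+                            """Bar chart of the average of the selected metrics, grouped by model or by table and model."""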
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficency_score(df)
+
+                            # Map human-readable metric labels to internal column names
+ qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+ external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+ selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
+ df = calculate_average_metrics(df, selected_metrics)
+
+ if group_by == ["model"]:
+ # Bar plot per "model"
+ avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+ fig = px.bar(
+ avg_metrics,
+ x="model",
+ y="avg_metric",
+ color="model",
+ color_discrete_map=MODEL_COLORS,
+ title='Average metrics per Model π§ ',
+ labels={"model": "Model", "avg_metric": "Average Metrics"},
+ template='simple_white',
+ #template='plotly_dark',
+ text='text_label'
+ )
+ else:
+ if group_by != ["tbl_name", "model"]:
+ group_by = ["tbl_name", "model"]
+
+ avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+ fig = px.bar(
+ avg_metrics,
+ x=group_by[0],
+ y='avg_metric',
+ color='model',
+ color_discrete_map=MODEL_COLORS,
+ barmode='group',
+ title=f'Average metrics per {group_by[0]} π',
+ labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
+ template='simple_white',
+ #template='plotly_dark',
+ text='text_label'
+ )
+
+ fig.update_traces(textposition='outside', textfont_size=10)
+
+                            # Apply the Inter font to the whole layout
+ fig.update_layout(
+ margin=dict(t=80),
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=22,
+ #color="white"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=18,
+ #color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Inter, sans-serif",
+ #color="white"
+ size=16
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=18,
+ #color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Inter, sans-serif",
+ #color="white"
+ )
+ ),
+ legend=dict(
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=16,
+ #color="white"
+ )
+ ),
+ font=dict(
+ family="Inter, sans-serif",
+ #color="white"
+ )
+ )
+ )
+
+ return gr.Plot(fig, visible=True)
+
+ def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models):
+ df = load_data_csv_es()
+ return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
+
+ # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
+ def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
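+                            """Bar chart of the average of the selected metrics per database category (proprietary setting)."""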
+ if selected_models == "All":
+ selected_models = models
+ else:
+ selected_models = [selected_models]
+
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficency_score(df)
+
+                            # Convert human-readable metric labels to internal column names
+ qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+ external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+ selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
+ df = calculate_average_metrics(df, selected_metrics)
+
+ avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+ fig = px.bar(
+ avg_metrics,
+ x='db_category',
+ y='avg_metric',
+ color='model',
+ color_discrete_map=MODEL_COLORS,
+ barmode='group',
+ title='Average metrics per database types π',
+ labels={'db_path': 'DB Path', 'avg_metric': 'Average Metrics'},
+ template='simple_white',
+ text='text_label'
+ )
+
+ fig.update_traces(textposition='outside', textfont_size=14)
+
+                            # Update the layout with the Inter font
+ fig.update_layout(
+ margin=dict(t=80),
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=24,
+ color="black"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ text='Database Category',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ color='black',
+ size=20
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metrics',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ color='black'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=20,
+ color='black'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ color='black',
+ size=18
+ )
+ )
+ )
+
+ return gr.Plot(fig, visible=True)
+
+ def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+ df = load_data_csv_es()
+ return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
+
+ # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
+
+ def lollipop_propietary(selected_models):
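+                            """Dumbbell chart comparing average metrics on Spider vs. the proprietary database categories, per model."""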
+ df = load_data_csv_es()
+
+                            # Keep only the relevant database categories
+ target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"]
+ df = df[df['db_category'].isin(target_cats)]
+ df = df[df['model'].isin(selected_models)]
+
+ df = normalize_valid_efficency_score(df)
+ df = calculate_average_metrics(df, qatch_metrics)
+
+                            # Compute the average metric per db_category and model
+ avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+
+                            # Split Spider from the other categories
+ spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
+ other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
+
+                            # Average the other categories per model
+ other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
+ other_mean_df["db_category"] = "Others"
+
+                            # Rename for clarity and consistency
+ spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
+ other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
+
+                            # Merge the two datasets
+ merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
+
+                            # Sort by model (sorting by value is also possible)
+ merged_df = merged_df.sort_values(by="model")
+
+ fig = go.Figure()
+
+                            # Add horizontal connector lines between Spider and Others
+ for _, row in merged_df.iterrows():
+ fig.add_trace(go.Scatter(
+ x=[row["Spider"], row["Others"]],
+ y=[row["model"]] * 2,
+ mode='lines',
+ line=dict(color='gray', width=2),
+ showlegend=False
+ ))
+
+                            # Marker for Spider
+ fig.add_trace(go.Scatter(
+ x=merged_df["Spider"],
+ y=merged_df["model"],
+ mode='markers',
+ name='Non-Proprietary (Spider)',
+ marker=dict(size=10, color='#C84630')
+ ))
+
+                            # Marker for Others (mean of the other categories)
+ fig.add_trace(go.Scatter(
+ x=merged_df["Others"],
+ y=merged_df["model"],
+ mode='markers',
+ name='Proprietary Databases',
+ marker=dict(size=10, color='#0077B6')
+ ))
+
+ fig.update_layout(
+ xaxis_title='Average Metrics',
+ yaxis_title='Models',
+ template='simple_white',
+ #template='plotly_dark',
+ margin=dict(t=80),
+ title=dict(
+ font=dict(
+ family="Inter, sans-serif",
+ size=22,
+ color="black"
+ ),
+ x=0.5,
+ text='Dumbbell graph: Non-Proprietary (Spider π·οΈ) vs Proprietary Databases π'
+ ),
+ legend_title='Type of Databases:',
+ height=600,
+ xaxis=dict(
+ title=dict(
+ text='DB Category',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ color='black'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metrics',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ color='black'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ color='black'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ color='black',
+ size=14
+ )
+ )
+ )
+
+ return gr.Plot(fig, visible=True)
+
+
+ # RADAR OR BAR CHART BASED ON CATEGORY COUNT
+ def plot_radar(df, selected_models, selected_metrics, selected_categories):
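+                            """Radar chart (or bar chart when fewer than 3 categories are selected) of average metrics per test category."""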
+ if "External" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+                            # Filter models and normalize
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficency_score(df)
+ df = calculate_average_metrics(df, selected_metrics)
+
+ avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
+
+ if avg_metrics.empty:
+ print("Error: No data available to compute averages.")
+ return go.Figure()
+
+ categories = selected_categories
+
+ if len(categories) < 3:
+ # π BAR PLOT
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
+ if cat in model_data['test_category'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Bar(
+ x=categories,
+ y=values,
+ name=model,
+ marker=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ barmode='group',
+ title=dict(
+ text='π Bar Plot of Metrics per Model (Few Categories)',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ #color='white'
+ ),
+ x=0.5
+ ),
+ template='simple_white',
+ #template='plotly_dark',
+ xaxis=dict(
+ title=dict(
+ text='Test Category',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ #color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ size=16
+ #color='white'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metrics',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ #color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=16,
+ #color='white'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ )
+ )
+ else:
+ # π§ RADAR PLOT
+ fig = go.Figure()
+ for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
+ if cat in model_data['test_category'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Scatterpolar(
+ r=values,
+ theta=categories,
+ fill='toself',
+ name=model,
+ line=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ polar=dict(
+ radialaxis=dict(
+ visible=True,
+ range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
+ tickfont=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ ),
+ angularaxis=dict(
+ tickfont=dict(
+ family='Inter, sans-serif',
+ size=16
+ #color='white'
+ )
+ )
+ ),
+ title=dict(
+ text='βοΈ Radar Plot of Metrics per Model (Average per SQL Category)',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ #color='white'
+ ),
+ x=0.5
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ #color='white'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ size=16
+ #color='white'
+ )
+ ),
+ template='simple_white'
+ #template='plotly_dark'
+ )
+
+ return fig
+
+ def update_radar(selected_models, selected_metrics, selected_categories):
+ df = load_data_csv_es()
+ return plot_radar(df, selected_models, selected_metrics, selected_categories)
+
+ # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
+ def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
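+                            """Radar chart (or bar chart when fewer than 3 sub-categories exist) of average metrics per SQL sub-category (sql_tag)."""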
+ if "External" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficency_score(df)
+ df = calculate_average_metrics(df, selected_metrics)
+
+ if isinstance(selected_category, str):
+ selected_category = [selected_category]
+
+ df = df[df['test_category'].isin(selected_category)]
+ avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
+
+ if avg_metrics.empty:
+ print("Error: No data available to compute averages.")
+ return go.Figure()
+
+ categories = df['sql_tag'].unique().tolist()
+
+ if len(categories) < 3:
+ # π BAR PLOT
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
+ if cat in model_data['sql_tag'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Bar(
+ x=categories,
+ y=values,
+ name=model,
+ marker=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ barmode='group',
+ title=dict(
+ text='π Bar Plot of Metrics per Model (Few Sub-Categories)',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ #color='white'
+ ),
+ x=0.5
+ ),
+ template='simple_white',
+ #template='plotly_dark',
+ xaxis=dict(
+ title=dict(
+ text='SQL Tag (Sub Category)',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ #color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metrics',
+ font=dict(
+ family='Inter, sans-serif',
+ size=18,
+ #color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=16,
+ #color='white'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ size=14
+ #color='white'
+ )
+ )
+ )
+ else:
+ # π§ RADAR PLOT
+ fig = go.Figure()
+
+ for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
+ if cat in model_data['sql_tag'].values else 0
+ for cat in categories
+ ]
+
+ fig.add_trace(go.Scatterpolar(
+ r=values,
+ theta=categories,
+ fill='toself',
+ name=model,
+ line=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ polar=dict(
+ radialaxis=dict(
+ visible=True,
+ range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
+ tickfont=dict(
+ family='Inter, sans-serif',
+ #color='white'
+ )
+ ),
+ angularaxis=dict(
+ tickfont=dict(
+ family='Inter, sans-serif',
+ size=16
+ #color='white'
+ )
+ )
+ ),
+ title=dict(
+ text='βοΈ Radar Plot of Metrics per Model (Average per SQL Sub-Category)',
+ font=dict(
+ family='Inter, sans-serif',
+ size=22,
+ #color='white'
+ ),
+ x=0.5
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Inter, sans-serif',
+ size=16,
+ #color='white'
+ )
+ ),
+ font=dict(
+ family='Inter, sans-serif',
+ size=14,
+ #color='white'
+ )
+ ),
+ template='simple_white'
+ #template='plotly_dark'
+ )
+
+ return fig
+
+ def update_radar_sub(selected_models, selected_metrics, selected_category):
+ df = load_data_csv_es()
+ return plot_radar_sub(df, selected_models, selected_metrics, selected_category)
+
+ # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
+ def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
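+                            """Builds Markdown summaries of the 3 worst cases (lowest average metric) and their raw model answers."""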
+ global flag_TQA
+ if selected_models == "All":
+ selected_models = models
+ else:
+ selected_models = [selected_models]
+
+ if selected_categories == "All":
+ selected_categories = principal_categories
+ else:
+ selected_categories = [selected_categories]
+
+ df = df[df['model'].isin(selected_models)]
+ df = df[df['test_category'].isin(selected_categories)]
+
+ if "external" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+ df = normalize_valid_efficency_score(df)
+ df = calculate_average_metrics(df, selected_metrics)
+
+ if flag_TQA:
+                                df["target_answer"] = df["target_answer"].apply(lambda x: "[" + ", ".join(map(str, x)) + "]")
+
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
+ else:
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
+
+ worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
+
+ worst_cases_top_3 = worst_cases_df.head(3)
+
+ worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
+
+ worst_str = []
+ answer_str = []
+
+                            medals = ["🥇", "🥈", "🥉"]
+
+ for i, row in worst_cases_top_3.iterrows():
+ if flag_TQA:
+ entry = (
+ f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
+ f"- Question: {row['question']} \n"
+ f"- Original Answer: `{row['target_answer']}` \n"
+ f"- Predicted Answer: `{eval(row['predicted_answer'])}` \n\n"
+ )
+
+ worst_str.append(entry)
+ else:
+ entry = (
+ f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
+ f"- Question: {row['question']} \n"
+ f"- Original Query: `{row['query']}` \n"
+ f"- Predicted SQL: `{row['predicted_sql']}` \n\n"
+ )
+
+ worst_str.append(entry)
+
+ raw_answer = (
+ f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
+                                    f"- Raw Answer: <br> `{row['answer']}` \n"
+ )
+
+ answer_str.append(raw_answer)
+
+ return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
+
+ def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
+ df = load_data_csv_es()
+ return worst_cases_text(df, selected_models, selected_metrics, selected_categories)
+
+ # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
+ def plot_cumulative_flow(df, selected_models, max_points):
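+                            """Line chart of cumulative price over cumulative time per model, limited to the first max_points questions."""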
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficency_score(df)
+
+ fig = go.Figure()
+
+ for model in selected_models:
+ model_df = df[df['model'] == model].copy()
+
+                                # Limit the number of points if requested
+ if max_points is not None:
+ model_df = model_df.head(max_points + 1)
+
+                                # Custom tooltip
+ model_df['hover_info'] = model_df.apply(
+ lambda row:
+                                    f"Id question: {row['number_question']}<br>"
+                                    f"Question: {row['question']}<br>"
+                                    f"Target: {row['query']}<br>"
+                                    f"Prediction: {row['predicted_sql']}<br>"
+                                    f"Category: {row['test_category']}",
+ axis=1
+ )
+
+                                # Cumulative computations
+ model_df['cumulative_time'] = model_df['time'].cumsum()
+ model_df['cumulative_price'] = model_df['price'].cumsum()
+
+                                # Model color
+ color = MODEL_COLORS.get(model, "gray")
+
+ fig.add_trace(go.Scatter(
+ x=model_df['cumulative_time'],
+ y=model_df['cumulative_price'],
+ mode='lines+markers',
+ name=model,
+ line=dict(width=2, color=color),
+ customdata=model_df['hover_info'],
+ hovertemplate=
+                                    "Model: " + model + "<br>" +
+                                    "Cumulative Time: %{x}s<br>" +
+                                    "Cumulative Price: $%{y:.2f}<br>" +
+                                    "<br>Details:<br>%{customdata}"
+ ))
+
+                            # Layout with the Inter font
+ fig.update_layout(
+ title=dict(
+ text="Cumulative Price Flow Chart π°",
+ font=dict(
+ family="Inter, sans-serif",
+ size=24,
+ #color="white"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ text="Cumulative Time (s)",
+ font=dict(
+ family="Inter, sans-serif",
+ size=20,
+ #color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Inter, sans-serif",
+ size=18
+ #color="white"
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text="Cumulative Price ($)",
+ font=dict(
+ family="Inter, sans-serif",
+ size=20,
+ #color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Inter, sans-serif",
+ size=18
+ #color="white"
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text="Models",
+ font=dict(
+ family="Inter, sans-serif",
+ size=18,
+ #color="white"
+ )
+ ),
+ font=dict(
+ family="Inter, sans-serif",
+ size=16,
+ #color="white"
+ )
+ ),
+ template='simple_white',
+ #template="plotly_dark"
+ )
+
+ return fig
+
+ def update_query_rate(selected_models, max_points):
+ df = load_data_csv_es()
+ return plot_cumulative_flow(df, selected_models, max_points)
+
+
+
+
+ #######################
+ # PARAMETER SECTION #
+ #######################
+ qatch_metrics_dict = {
+ "Cell Precision": "cell_precision",
+ "Cell Recall": "cell_recall",
+ "Tuple Order": "tuple_order",
+ "Tuple Cardinality": "tuple_cardinality",
+ "Tuple Constraint": "tuple_constraint"
+ }
+
+ qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+                        last_valid_qatch_metrics_selection = qatch_metrics.copy() # Store the last valid selection
+ def enforce_qatch_metrics_selection(selected):
+ global last_valid_qatch_metrics_selection
+                            if not selected: # If no metric is selected
+                                return gr.update(value=last_valid_qatch_metrics_selection)
+                            last_valid_qatch_metrics_selection = selected # Otherwise update the last valid selection
+ return gr.update(value=selected)
+
+ external_metrics_dict = {
+ "Execution Accuracy": "execution_accuracy",
+                            "Valid Efficiency Score": "valid_efficency_score"
+ }
+
+ external_metric = ["execution_accuracy", "valid_efficency_score"]
+ last_valid_external_metric_selection = external_metric.copy()
+ def enforce_external_metric_selection(selected):
+ global last_valid_external_metric_selection
+                            if not selected: # If no metric is selected
+                                return gr.update(value=last_valid_external_metric_selection)
+                            last_valid_external_metric_selection = selected # Otherwise update the last valid selection
+ return gr.update(value=selected)
+
+ all_metrics = {
+ "Qatch": ["qatch"],
+ "External": ["external"]
+ }
+
+ group_options = {
+ "Table": ["tbl_name", "model"],
+ "Model": ["model"]
+ }
+
+ df_initial = load_data_csv_es()
+                        models = df_initial['model'].unique().tolist()
+                        last_valid_model_selection = models.copy() # Store the last valid selection
+                        def enforce_model_selection(selected):
+                            global last_valid_model_selection
+                            if not selected: # If no model is selected
+                                return gr.update(value=last_valid_model_selection)
+                            last_valid_model_selection = selected # Otherwise update the last valid selection
+ return gr.update(value=selected)
+
+ all_categories = df_initial['sql_tag'].unique().tolist()
+
+ principal_categories = df_initial['test_category'].unique().tolist()
+                        last_valid_category_selection = principal_categories.copy() # Store the last valid selection
+                        def enforce_category_selection(selected):
+                            global last_valid_category_selection
+                            if not selected: # If no category is selected
+                                return gr.update(value=last_valid_category_selection)
+                            last_valid_category_selection = selected # Otherwise update the last valid selection
+ return gr.update(value=selected)
+
+ all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories}
+
+ all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories}
+ all_categories_as_dic_ranking["All"] = principal_categories
+
+ all_model_as_dic = {cat: [f"{cat}"] for cat in models}
+ all_model_as_dic["All"] = models
+
+ ###########################
+ # VISUALIZATION SECTION #
+ ###########################
+ gr.Markdown("""# Model Performance Analysis""")
+
+ #FOR BAR
+ gr.Markdown("""## Section 1: Model - Data""")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ with gr.Row():
+ choose_metrics_bar = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
+
+ with gr.Row():
+ qatch_info = gr.HTML("""
+
+ Qatch metric info βΉοΈ
+
+ """, visible=True)
+
+ external_info = gr.HTML("""
+
+ External metric info βΉοΈ
+
+ """, visible=False)
+
+ qatch_metric_multiselect_bar = gr.CheckboxGroup(
+ choices=list(qatch_metrics_dict.keys()),
+                                label="Select one or more Qatch metrics:",
+ value=list(qatch_metrics_dict.keys()),
+ visible=True
+ )
+
+ external_metric_select_bar = gr.CheckboxGroup(
+ choices=list(external_metrics_dict.keys()),
+ label="Select one or more External metrics:",
+ visible=False
+ )
+
+ if(input_data['input_method'] == 'default'):
+ model_radio_bar = gr.Radio(
+ choices=list(all_model_as_dic.keys()),
+ label="Select the model that you want to use:",
+ value="All"
+ )
+ else:
+ model_multiselect_bar = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models,
+ interactive=len(models) > 1
+ )
+
+ group_radio = gr.Radio(
+ choices=list(group_options.keys()),
+ label="Select the grouping view:",
+ value="Table"
+ )
+
+ def toggle_metric_selector(selected_type):
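+                                """Toggles visibility and default values of the Qatch/External metric selectors."""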
+ if selected_type == "Qatch":
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
+ else:
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
+
+ output_plot = gr.Plot(visible=False)
+
+ if(input_data['input_method'] == 'default'):
+ with gr.Row():
+ lollipop_propietary(models)
+
+ #FOR RADAR
+ gr.Markdown("""## Section 2: Model - Category""")
+ with gr.Row():
+ all_metrics_radar = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
+
+ model_multiselect_radar = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models,
+ interactive=len(models) > 1
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ category_multiselect_radar = gr.CheckboxGroup(
+ choices=principal_categories,
+ label="Select one or more categories:",
+ value=principal_categories
+ )
+ with gr.Column(scale=1):
+ category_radio_radar = gr.Radio(
+ choices=list(all_categories_as_dic.keys()),
+                                label="Select the category that you want to use:",
+ value=list(all_categories_as_dic.keys())[0]
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
+
+ with gr.Column(scale=1):
+ radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
+
+ #FOR RANKING
+ with gr.Row():
+ all_metrics_ranking = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
+ model_choices = list(all_model_as_dic.keys())
+
+ if len(model_choices) == 2:
+                            model_choices = [model_choices[0]] # assume the model is in the first position
+ selected_value = model_choices[0]
+ else:
+ selected_value = "All"
+
+ model_radio_ranking = gr.Radio(
+ choices=model_choices,
+ label="Select the model that you want to use:",
+ value=selected_value
+ )
+
+ category_radio_ranking = gr.Radio(
+ choices=list(all_categories_as_dic_ranking.keys()),
+ label="Select the category that you want to use",
+ value="All"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown("## β 3 Worst Cases\n")
+
+ worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
+
+ with gr.Row():
+ first = gr.Markdown(worst_first)
+
+ with gr.Row():
+                                    first_button = gr.Button("Show raw answer for 🥇")
+
+ with gr.Row():
+ second = gr.Markdown(worst_second)
+
+ with gr.Row():
+                                    second_button = gr.Button("Show raw answer for 🥈")
+
+ with gr.Row():
+ third = gr.Markdown(worst_third)
+
+ with gr.Row():
+                                    third_button = gr.Button("Show raw answer for 🥉")
+
+ with gr.Column(scale=1):
+ gr.Markdown("""## Raw Answer""")
+ row_answer_first = gr.Markdown(value=raw_first, visible=True)
+ row_answer_second = gr.Markdown(value=raw_second, visible=False)
+ row_answer_third = gr.Markdown(value=raw_third, visible=False)
+
+ #FOR RATE
+ gr.Markdown("""## Section 3: Time - Price""")
+ with gr.Row():
+ model_multiselect_rate = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models,
+ interactive=len(models) > 1
+ )
+
+
+ with gr.Row():
+ slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider")
+
+ query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
+
+
+ #FOR RESET
+ reset_data = gr.Button("Back to upload data section")
+
+
+
+
+ ###############################
+ # CALLBACK FUNCTION SECTION #
+ ###############################
+
+ #FOR BAR
+ def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
+ return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
+
+ def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
+ return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
+
+ #FOR RADAR
+ def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
+ return update_radar(selected_models, selected_metrics, selected_categories)
+
+ def on_radar_radio_change(selected_models, selected_metrics, selected_category):
+ return update_radar_sub(selected_models, selected_metrics, selected_category)
+
+ #FOR RANKING
+ def on_ranking_change(selected_models, selected_metrics, selected_categories):
+ return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
+
+ def show_first():
+ return (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False)
+ )
+
+ def show_second():
+ return (
+ gr.update(visible=False),
+ gr.update(visible=True),
+ gr.update(visible=False)
+ )
+
+ def show_third():
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=True)
+ )
+
+
+
+
+ ######################
+ # ON CLICK SECTION #
+ ######################
+
+ #FOR BAR
+ if(input_data['input_method'] == 'default'):
+ proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+ external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+ model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
+ external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+ else:
+ proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+ model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
+ external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+
+ #FOR RADAR MULTISELECT
+ model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
+ category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
+
+ #FOR RADAR RADIO
+ model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+ all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+ category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+
+ #FOR RANKING
+ model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking)
+ category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking)
+ first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
+
+ #FOR RATE
+ model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+ proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+ model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
+ slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+
+ #FOR RESET
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+ reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
+ reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
+
+
interface.launch(share = True)
\ No newline at end of file