diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,17 +1,34 @@
import gradio as gr
import pandas as pd
import os
-
+# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
+# if os.environ.get("SPACES_ZERO_GPU") is not None:
+# import spaces
+# else:
+# class spaces:
+# @staticmethod
+# def GPU(func):
+# def wrapper(*args, **kwargs):
+# return func(*args, **kwargs)
+# return wrapper
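+# (When re-enabled, the fallback class above makes @spaces.GPU a no-op outside
+# ZeroGPU Spaces, so the same decorated code also runs locally.)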
import sys
from qatch.connectors.sqlite_connector import SqliteConnector
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
+from prediction import ModelPrediction
import utils_get_db_tables_info
import utilities as us
import time
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
+import re
+import csv
+
+# @spaces.GPU
+# def model_prediction():
+# pass
+pnp_path = os.path.join("data", "evaluation_p_metrics.csv")
with open('style.css', 'r') as file:
css = file.read()
@@ -34,39 +51,16 @@ input_data = {
'db_name': "",
'data': {
'data_frames': {}, # dictionary of dataframes
- 'db': None # SQLITE3 database object
+ 'db': None, # SQLITE3 database object
+        'selected_tables': []
},
- 'models': []
+ 'models': [],
+ 'prompt': "{question} {schema}"
}
def load_data(file, path, use_default):
"""Carica i dati da un file, un percorso o usa il DataFrame di default."""
global df_current
- if use_default:
- input_data["input_method"] = 'default'
- input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable.sqlite")
- input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
- input_data["data"]['data_frames'] = {'MyTable': df_current}
-
- if( input_data["data"]['data_frames']):
- table2primary_key = {}
- for table_name, df in input_data["data"]['data_frames'].items():
- # Assign primary keys for each table
- table2primary_key[table_name] = 'id'
- input_data["data"]["db"] = SqliteConnector(
- relative_db_path=input_data["data_path"],
- db_name=input_data["db_name"],
- tables= input_data["data"]['data_frames'],
- table2primary_key=table2primary_key
- )
-
- df_current = df_default.copy() # Ripristina i dati di default
- return input_data["data"]['data_frames']
-
- selected_inputs = sum([file is not None, bool(path), use_default])
- if selected_inputs > 1:
- return 'Errore: Selezionare solo un metodo di input alla volta.'
-
if file is not None:
try:
input_data["input_method"] = 'uploaded_file'
@@ -74,7 +68,7 @@ def load_data(file, path, use_default):
input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
input_data["data"] = us.load_data(file, input_data["db_name"])
            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)  # Load the DataFrame
- if( input_data["data"]['data_frames']):
+            if input_data["data"]['data_frames'] and input_data["data"]["db"] is None:  # for CSV and XLSX files
table2primary_key = {}
for table_name, df in input_data["data"]['data_frames'].items():
# Assign primary keys for each table
@@ -88,31 +82,51 @@ def load_data(file, path, use_default):
return input_data["data"]['data_frames']
except Exception as e:
            return f'Error while loading the file: {e}'
-
- """
- if path:
- if not os.path.exists(path):
- return 'Errore: Il percorso specificato non esiste.'
- try:
- input_data["input_method"] = 'uploaded_file'
- input_data["data_path"] = path
- input_data["db_name"] = os.path.splitext(os.path.basename(path))[0]
- input_data["data"] = us.load_data(input_data["data_path"], input_data["db_name"])
- df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
-
+ if use_default:
+        if use_default == 'Custom':
+ input_data["input_method"] = 'custom'
+ input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
+            # if the file already exists, pick the next free name
+ while os.path.exists(input_data["data_path"]):
+ input_data["data_path"] = us.increment_filename(input_data["data_path"])
+ input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
+ input_data["data"]['data_frames'] = {'MyTable': df_current}
+
+            if input_data["data"]['data_frames']:
+ table2primary_key = {}
+ for table_name, df in input_data["data"]['data_frames'].items():
+ # Assign primary keys for each table
+ table2primary_key[table_name] = 'id'
+ input_data["data"]["db"] = SqliteConnector(
+ relative_db_path=input_data["data_path"],
+ db_name=input_data["db_name"],
+                    tables=input_data["data"]['data_frames'],
+ table2primary_key=table2primary_key
+ )
+            df_current = df_default.copy()  # Restore the default data
+ return input_data["data"]['data_frames']
+
+        if use_default == 'Proprietary vs Non-proprietary':
+ input_data["input_method"] = 'default'
+ #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
+ #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
+ #input_data["db_name"] = "default"
+ #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
+ input_data["data"]['data_frames'] = us.extract_tables_dict(pnp_path)
return input_data["data"]['data_frames']
- except Exception as e:
- return f'Errore nel caricamento del file dal percorso: {e}'
- """
-
+ selected_inputs = sum([file is not None, bool(path), use_default])
+ if selected_inputs > 1:
+ return 'Error: Select only one input method at a time.'
+
return input_data["data"]['data_frames']
def preview_default(use_default):
- """Mostra il DataFrame di default se il checkbox Γ¨ selezionato."""
- if use_default:
- return df_default # Mostra il DataFrame di default
- return df_current # Mostra il DataFrame corrente, che potrebbe essere stato modificato
+    """Show the editable default table for 'Custom'; otherwise hide it and show the comparison description."""
+    if use_default == 'Custom':
+        return gr.DataFrame(interactive=True, visible=True, value=df_default), gr.update(visible=False)
+    else:
+        return gr.DataFrame(interactive=False, visible=False, value=df_default), gr.update(visible=True)
+    # return gr.DataFrame(interactive=True, value=df_current)  # Show the current DataFrame, which may have been edited
def update_df(new_df):
"""Aggiorna il DataFrame corrente."""
@@ -128,14 +142,16 @@ def open_accordion(target):
input_data['data_path'] = ""
input_data['db_name'] = ""
input_data['data']['data_frames'] = {}
+ input_data['data']['selected_tables'] = []
input_data['data']['db'] = None
input_data['models'] = []
- return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value=False), gr.update(value=None)
+ return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
elif target == "model_selection":
return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
# Gradio interface
with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
+#with gr.Blocks(theme='NoCrypt/miku/light', css_paths='style.css') as interface:
with gr.Row():
gr.Column(scale=1)
gr.Image(
@@ -153,21 +169,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
select_model_acc = gr.Accordion("Select models", open=False, visible=False)
qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
- #metrics_acc = gr.Accordion("Metrics", open=False, visible=False, render=False)
-
-
#################################
# DATABASE INSERTION #
#################################
with upload_acc:
- gr.Markdown("## Data Upload")
-
- file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
+ gr.Markdown("## Choose data input method")
with gr.Row():
- default_checkbox = gr.Checkbox(label="Use default DataFrame")
- preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
- submit_button = gr.Button("Load Data", interactive=False) # Disabled by default
+            default_checkbox = gr.Radio(label="Use the default DataFrame or a custom table", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
+ #default_checkbox = gr.Checkbox(label="Use default DataFrame")
+ preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
+ description = """## Comparison of proprietary and non-proprietary databases
+ - Proprietary (Economic, Medical, Financial, Miscellaneous)
+ - Non-proprietary (Spider 1.0)"""
+
+ table_default = gr.Markdown(description, visible=True)
+ gr.Markdown("## Or upload your data")
+ file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
+        submit_button = gr.Button("Load Data")  # Always enabled; inputs are validated on click
output = gr.JSON(visible=False) # Dictionary output
# Function to enable the button if there is data to load
@@ -177,15 +196,23 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
# Function to uncheck the checkbox if a file is uploaded
def deselect_default(file):
if file:
- return gr.update(value=False)
+ return gr.update(value='Proprietary vs Non-proprietary')
return gr.update()
+
+ def enable_disable_first(enable):
+ return (
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable),
+ gr.update(interactive=enable)
+ )
# Enable the button when inputs are provided
- file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
- default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
+ #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
+ #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
# Show preview of the default DataFrame when checkbox is selected
- default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output])
+ default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output, table_default])
preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
# Uncheck the checkbox when a file is uploaded
@@ -197,6 +224,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
if isinstance(result, dict): # If result is a dictionary of DataFrames
if len(result) == 1: # If there's only one table
+ input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
return (
gr.update(visible=False), # Hide JSON output
result, # Save the data state
@@ -214,7 +242,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
result,
gr.update(interactive=False),
gr.update(visible=False), # Keep current behavior
- gr.update(visible=True, open=True)
+ gr.update(visible=True, open=False)
)
else:
return (
@@ -224,7 +252,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
None,
gr.update(interactive=True),
gr.update(visible=False),
- gr.update(visible=True, open=True)
+ gr.update(visible=True, open=False)
)
submit_button.click(
@@ -232,14 +260,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
inputs=[file_input, default_checkbox],
outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
)
+
+ submit_button.click(
+ fn=enable_disable_first,
+ inputs=[gr.State(False)],
+ outputs=[
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox
+ ]
+ )
-
- ######################################
+ ######################################
# TABLE SELECTION PART #
######################################
with select_table_acc:
table_selector = gr.CheckboxGroup(choices=[], label="Select tables to display", value=[])
- table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(5)]
+ table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(10)]
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
# Model selection button (initially disabled)
@@ -265,10 +303,10 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True))
        # If there are fewer than 10 tables, hide the remaining DataFrames
- for _ in range(len(tables), 5):
+ for _ in range(len(tables), 10):
updates.append(gr.update(visible=False))
else:
- updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(5)]
+ updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(10)]
# Enable/disable the button based on selections
button_state = bool(selected_tables) # True if at least one table is selected, False otherwise
@@ -279,6 +317,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
def show_selected_table_names(selected_tables):
"""Displays the names of the selected tables when the button is pressed."""
if selected_tables:
+ input_data['data']['selected_tables'] = selected_tables
return gr.update(value=", ".join(selected_tables), visible=False)
return gr.update(value="", visible=False)
@@ -291,7 +330,22 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
# Shows the list of selected tables when "Choose your models" is clicked
open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
-
+
+ reset_data = gr.Button("Back to upload data section")
+
+ reset_data.click(
+ fn=enable_disable_first,
+ inputs=[gr.State(True)],
+ outputs=[
+ preview_output,
+ submit_button,
+ file_input,
+ default_checkbox
+ ]
+ )
+
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+
####################################
# MODEL SELECTION PART #
@@ -322,18 +376,42 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
cols.append(checkbox)
rows.append(cols)
- selected_models_output = gr.JSON(visible=False)
+ selected_models_output = gr.JSON(visible=True)
# Function to get selected models
def get_selected_models(*model_selections):
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
input_data['models'] = selected_models
- button_state = bool(selected_models) # True if at least one model is selected, False otherwise
+ button_state = bool(selected_models and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
+ # Add the Textbox to the interface
+        prompt = gr.TextArea(label="Customise the prompt for the selected models here, or leave the default one. The prompt must contain {question} and {schema}, which are automatically replaced during SQL generation.",
+                             placeholder='Default prompt: {question} and the db {schema} are filled in automatically')
+ warning_prompt = gr.Markdown(value="# Error in the prompt format", visible=False)
+
# Submit button (initially disabled)
+
submit_models_button = gr.Button("Submit Models", interactive=False)
+        def check_prompt(prompt):
+            missing_elements = []
+            if prompt == "":
+                input_data["prompt"] = "{question} {schema}"
+            else:
+                input_data["prompt"] = prompt
+                if "{schema}" not in prompt:
+                    missing_elements.append("{schema}")
+                if "{question}" not in prompt:
+                    missing_elements.append("{question}")
+            button_state = bool(len(input_data['models']) > 0 and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
+            if missing_elements:
+                return gr.update(value=f"## ❌ Missing {', '.join(missing_elements)} in the prompt ❌", visible=True), gr.update(interactive=button_state)
+            return gr.update(visible=False), gr.update(interactive=button_state)
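+        # Example (illustrative): "Given the schema {schema}, translate '{question}' into SQL"
+        # passes the check; "Translate '{question}' into SQL" raises the missing-{schema}
+        # warning and keeps the submit button disabled.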
+
+ prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
# Link checkboxes to selection events
for checkbox in model_checkboxes:
checkbox.change(
@@ -341,13 +419,18 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
inputs=model_checkboxes,
outputs=[selected_models_output, select_model_acc, submit_models_button]
)
+ prompt.change(
+ fn=get_selected_models,
+ inputs=model_checkboxes,
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
+ )
submit_models_button.click(
fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
inputs=model_checkboxes,
outputs=[selected_models_output, select_model_acc, qatch_acc]
)
-
+
def enable_disable(enable):
return (
*[gr.update(interactive=enable) for _ in model_checkboxes],
@@ -397,7 +480,6 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
]
)
-
#############################
# QATCH EXECUTION #
#############################
@@ -422,89 +504,172 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
symbols = loading_symbols.get(num_symbols, "π")
mirrored_symbols = f'{symbols.strip()}'
css_symbols = f'{symbols.strip()}'
- return f"
{css_symbols} Generation {percent}%{mirrored_symbols}
"
+ return f"""
+
+ {css_symbols}
+
+ Generation {percent}%
+
+ {mirrored_symbols}
+
+ """
+
#return f"{css_symbols}"+f"# Loading {percent}% #"+f"{mirrored_symbols}"
def qatch_flow():
- orchestrator_generator = OrchestratorGenerator()
- # TODO: add to target_df column target_df["columns_used"], tables selection
- # print(input_data['data']['db'])
- target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'])
-
- schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
- db_id = input_data["db_name"],
- base_path = input_data["data_path"],
- normalize=False,
- sql=None
- )
-
- # TODO: QUERY PREDICTION
+ #caching
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
metrics_conc = pd.DataFrame()
-
- for model in input_data["models"]:
- model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
- yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
-
- for index, row in target_df.iterrows():
-
- percent_complete = round(((index+1) / len(target_df)) * 100, 2)
- load_text = f"{generate_loading_text(percent_complete)}"
-
- question = row['question']
-            display_question = f"Natural Language: {row['question']}"
- # yield gr.Textbox(question), gr.Textbox(), *[predictions_dict[model] for model in input_data["models"]], None
-
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
- start_time = time.time()
-
- # Simulate prediction
- time.sleep(0.4)
- prediction = "Prediction_placeholder"
-            display_prediction = f"Generated SQL: {prediction}"
- # Run real prediction here
- # prediction = predictor.run(model, schema_text, question)
-
- end_time = time.time()
- # Create a new row as dataframe
- new_row = pd.DataFrame([{
- 'id': index,
- 'question': question,
- 'predicted_sql': prediction,
- 'time': end_time - start_time,
- 'query': row["query"],
- 'db_path': input_data["data_path"]
- }]).dropna(how="all") # Remove only completely empty rows
-
- # TODO: use a for loop
- for col in target_df.columns:
- if col not in new_row.columns:
- new_row[col] = row[col]
-
- # Update model's prediction dataframe incrementally
- if not new_row.empty:
- predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
-
- # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
+        if input_data['input_method'] == "default":
+ target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
+ #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
+ target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
+ target_df = target_df[target_df["model"].isin(input_data['models'])]
+ predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
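+        # Cached mode: the selected tables and models are filtered from the precomputed
+        # CSV and replayed below, so no model is actually invoked in this branch.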
+ for model in target_df["model"].unique():
+ model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
+ yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+            count = 1
+            for _, row in predictions_dict[model].iterrows():
+                percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
+                count += 1
+                load_text = f"{generate_loading_text(percent_complete)}"
+ question = row['question']
+
+ # display_question = f"""
+ #
+ # Natural Language:
+ #
+ #
+ # {row['question']}
+ #
+ # """
+ display_question = f"""Natural Language:
+
+ """
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ #time.sleep(0.02)
+ prediction = row['predicted_sql']
+
+ # display_prediction = f"""
+ #
+ # Generated SQL:
+ #
+ #
+ # {prediction}
+ #
+ # """
+ display_prediction = f"""Natural Language:
+
+
β‘οΈ
+
{prediction}
+
+ """
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+ metrics_conc = target_df
+ if 'valid_efficiency_score' not in metrics_conc.columns:
+ metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
+ yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ else:
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
- # END
- evaluator = OrchestratorEvaluator()
- for model in input_data["models"]:
- metrics_df_model = evaluator.evaluate_df(
- df=predictions_dict[model],
- target_col_name="query",
- prediction_col_name="predicted_sql",
- db_path_name="db_path"
+ orchestrator_generator = OrchestratorGenerator()
+ # TODO: add to target_df column target_df["columns_used"], tables selection
+ # print(input_data['data']['db'])
+ #print(input_data['data']['selected_tables'])
+ target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables'])
+ #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
+
+ schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
+ db_id = input_data["db_name"],
+ base_path = input_data["data_path"],
+ normalize=False,
+ sql=None
)
- metrics_df_model['model'] = model
- metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
-
+
+ predictor = ModelPrediction()
+
+ for model in input_data["models"]:
+ model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
+ yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+            count = 0
+ for index, row in target_df.iterrows():
+
+ percent_complete = round(((index+1) / len(target_df)) * 100, 2)
+ load_text = f"{generate_loading_text(percent_complete)}"
+
+ question = row['question']
+ #display_question = f"Natural Language:
{row['question']}
"
+ display_question = f"""Natural Language:
+
+ """
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ start_time = time.time()
+ samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
+ prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
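+                # us.prepare_prompt is assumed to substitute the placeholders; a minimal
+                # sketch under that assumption (the real helper lives in utilities.py):
+                # def prepare_prompt(prompt, question, schema, samples):
+                #     return prompt.replace("{question}", question).replace("{schema}", f"{schema} Samples: {samples}")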
+ #PREDICTION SQL
+                # Real prediction (currently disabled):
+                # response = predictor.make_prediction(question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}")
+                prediction = row["query"]  # placeholder; would be response["response_parsed"]
+                price = 0.0  # placeholder; would be response["cost"]
+                answer = "Answer"  # placeholder; would be response["response"]
+
+ end_time = time.time()
+ #display_prediction = f"Generated SQL:
{prediction}
"
+ display_prediction = f"""Natural Language:
+
+
β‘οΈ
+
{prediction}
+
+ """
+ # Create a new row as dataframe
+ new_row = pd.DataFrame([{
+ 'id': index,
+ 'question': question,
+ 'predicted_sql': prediction,
+ 'time': end_time - start_time,
+ 'query': row["query"],
+ 'db_path': input_data["data_path"],
+                    'price': price,
+                    'answer': answer,
+                    'number_question': count
+                }]).dropna(how="all")  # Remove only completely empty rows
+                count += 1
+ # TODO: use a for loop
+ for col in target_df.columns:
+ if col not in new_row.columns:
+ new_row[col] = row[col]
+
+ # Update model's prediction dataframe incrementally
+ if not new_row.empty:
+ predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
+
+ # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+ # END
+ evaluator = OrchestratorEvaluator()
+ for model in input_data["models"]:
+ metrics_df_model = evaluator.evaluate_df(
+ df=predictions_dict[model],
+ target_col_name="query",
+ prediction_col_name="predicted_sql",
+ db_path_name="db_path"
+ )
+ metrics_df_model['model'] = model
+ metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
+
if 'valid_efficiency_score' not in metrics_conc.columns:
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
- yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+ yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
# Loading Bar
with gr.Row():
@@ -516,15 +681,11 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
with gr.Column():
with gr.Column():
question_display = gr.Markdown()
- with gr.Column():
- gr.Markdown("‴
")
with gr.Column():
model_logo = gr.Image(visible=True, show_label=False)
with gr.Column():
with gr.Column():
prediction_display = gr.Markdown()
- with gr.Column():
- gr.Markdown("‴
")
dataframe_per_model = {}
@@ -611,23 +772,35 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
open_model_selection
]
)
-
-
+
+
+
+
##########################################
# METRICS VISUALIZATION SECTION #
##########################################
with metrics_acc:
- #confirmation_text = gr.Markdown("## Metrics successfully loaded")
-
- data_path = 'test_results.csv'
+        data_path = pnp_path  # data/evaluation_p_metrics.csv; avoid machine-specific absolute paths
@gr.render(inputs=metrics_df_out)
def function_metrics(metrics_df_out):
+
+ ####################################
+ # UTILS FUNCTIONS SECTION #
+ ####################################
+
def load_data_csv_es():
- return pd.read_csv(data_path)
- #return metrics_df_out
+            # return pd.read_csv(data_path)
+            return metrics_df_out
def calculate_average_metrics(df, selected_metrics):
+            # Exclude 'tuple_order' from the selected metrics (TODO: tuple_order contains NULL values)
+            selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order']
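+            # e.g. with ["cell_precision", "cell_recall"] selected, a row with values
+            # 0.8 and 0.6 gets avg_metric = (0.8 + 0.6) / 2 = 0.7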
df['avg_metric'] = df[selected_metrics].mean(axis=1)
return df
@@ -646,198 +819,723 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
return colors
MODEL_COLORS = generate_model_colors()
+
+ def generate_db_category_colors():
+ """Assigns 3 distinct colors to db_category groups."""
+ return {
+ "Spider": "#1f77b4", # blu
+ "Beaver": "#ff7f0e", # arancione
+ "Economic": "#2ca02c", # tutti gli altri verdi
+ "Financial": "#2ca02c",
+ "Medical": "#2ca02c",
+ "Miscellaneous": "#2ca02c"
+ }
+
+ DB_CATEGORY_COLORS = generate_db_category_colors()
+
+            def normalize_valid_efficiency_score(df):
+                # Min-max normalize valid_efficiency_score; empty strings count as 0
+                df['valid_efficiency_score'] = df['valid_efficiency_score'].replace('', 0)
+                df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(float)  # float: int would truncate fractional scores
+                min_val = df['valid_efficiency_score'].min()
+                max_val = df['valid_efficiency_score'].max()
+
+                if min_val == max_val:
+                    # All values are equal: set 1.0 everywhere to avoid division by zero
+                    df['valid_efficiency_score'] = 1.0
+ else:
+ df['valid_efficiency_score'] = (
+ df['valid_efficiency_score'] - min_val
+ ) / (max_val - min_val)
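+                # e.g. raw scores [2, 4, 8] map to [0.0, 1/3, 1.0]: (4 - 2) / (8 - 2) = 1/3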
+
+ return df
+
+
+
+
+ ####################################
+ # GRAPH FUNCTIONS SECTION #
+ ####################################
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
- def plot_metric(df, selected_metrics, group_by, selected_models):
+ def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficiency_score(df)
+
+            # Map human-readable labels -> internal metric names
+ qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+ external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+ selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
df = calculate_average_metrics(df, selected_metrics)
+
+ if group_by == ["model"]:
+                # Bar plot grouped by model only
+ avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+ fig = px.bar(
+ avg_metrics,
+ x="model",
+ y="avg_metric",
+ color="model",
+ color_discrete_map=MODEL_COLORS,
+                    title='Average metric per Model 🧠',
+ labels={"model": "Model", "avg_metric": "Average Metric"},
+ template='plotly_dark',
+ text='text_label'
+ )
+ else:
+ if group_by != ["tbl_name", "model"]:
+ group_by = ["tbl_name", "model"]
+
+ avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
+
+ fig = px.bar(
+ avg_metrics,
+ x=group_by[0],
+ y='avg_metric',
+ color='model',
+ color_discrete_map=MODEL_COLORS,
+ barmode='group',
+                    title=f'Average metric per {group_by[0]} 📊',
+ labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
+ template='plotly_dark',
+ text='text_label'
+ )
+
+ fig.update_traces(textposition='outside', textfont_size=10)
+
+ # font Playfair Display
+ fig.update_layout(
+ margin=dict(t=80),
+ title=dict(
+ font=dict(
+ family="Playfair Display, serif",
+ size=22,
+ color="white"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ font=dict(
+ family="Playfair Display, serif",
+ size=16,
+ color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ font=dict(
+ family="Playfair Display, serif",
+ size=16,
+ color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ ),
+ legend=dict(
+ title=dict(
+ font=dict(
+ family="Playfair Display, serif",
+ size=14,
+ color="white"
+ )
+ ),
+ font=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ )
+ )
+
+ return gr.Plot(fig, visible=True)
+
+            def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
+ df = load_data_csv_es()
+ return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
+
+            # BAR CHART FOR PROPRIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
+ def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+ if selected_models == "All":
+ selected_models = models
+ else:
+ selected_models = [selected_models]
- # Ensure the group_by value is always valid
- if group_by not in [["tbl_name", "model"], ["model"]]:
- group_by = ["tbl_name", "model"] # Default
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficiency_score(df)
- avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
+            # Convert human-readable labels -> internal metric names
+ qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
+ external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
+
+ selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
+
+ df = calculate_average_metrics(df, selected_metrics)
+
+ avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
fig = px.bar(
avg_metrics,
- x=group_by[0],
+ x='db_category',
y='avg_metric',
color='model',
color_discrete_map=MODEL_COLORS,
barmode='group',
- title=f'Average metric per {group_by[0]} π',
- labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
- template='plotly_dark'
+                title='Average metric per db_category 📊',
+ labels={'db_path': 'DB Path', 'avg_metric': 'Average Metric'},
+ template='simple_white',
+ text='text_label'
)
-
- return gr.Plot(fig, visible=True)
- def update_plot(selected_metrics, group_by, selected_models):
- df = load_data_csv_es()
- return plot_metric(df, selected_metrics, group_by, selected_models)
+ fig.update_traces(textposition='outside', textfont_size=10)
+ #Playfair Display
+ fig.update_layout(
+ margin=dict(t=80),
+ title=dict(
+ font=dict(
+ family="Playfair Display, serif",
+ size=22,
+ color="black"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ text='DB Category',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='black'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metric',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='black'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='black'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Playfair Display, serif',
+ size=14,
+ color='black'
+ )
+ ),
+ font=dict(
+ family='Playfair Display, serif',
+ color='black'
+ )
+ )
+ )
- # RADAR CHART FOR AVERAGE METRICS PER MODEL WITH UPDATE FUNCTION
- def plot_radar(df, selected_models):
- # Filter only selected models
- df = df[df['model'].isin(selected_models)]
+ return gr.Plot(fig, visible=True)
+
+ """
+ def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+ if selected_models == "All":
+ selected_models = models
+ else:
+ selected_models = [selected_models]
- # Select relevant metrics
- selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficiency_score(df)
- # Compute average metrics per test_category and model
+ if radio_metric == "Qatch":
+ selected_metrics = qatch_selected_metrics
+ else:
+ selected_metrics = external_selected_metric
+
df = calculate_average_metrics(df, selected_metrics)
- avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
- # Check if data is available
- if avg_metrics.empty:
- print("Error: No data available to compute averages.")
- return go.Figure()
+            # Group by model and category
+ avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index()
+ avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
- fig = go.Figure()
- categories = avg_metrics['test_category'].unique()
-
- for model in selected_models:
- model_data = avg_metrics[avg_metrics['model'] == model]
-
- # Build a list of values for each category (if a value is missing, set it to 0)
- values = [
- model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
- if cat in model_data['test_category'].values else 0
- for cat in categories
- ]
-
- fig.add_trace(go.Scatterpolar(
- r=values,
- theta=categories,
- fill='toself',
- name=model,
- line=dict(color=MODEL_COLORS.get(model, "gray"))
- ))
+            # Horizontal plot with model on the Y axis
+ fig = px.bar(
+ avg_metrics,
+ x='avg_metric',
+ y='model',
+ color='db_category', # categoria come colore
+ text='text_label',
+ barmode='group',
+ orientation='h',
+                color_discrete_map=DB_CATEGORY_COLORS,  # requires this dict, analogous to MODEL_COLORS
+ title='Average metric per model and db_category π',
+ labels={'avg_metric': 'AVG Metric', 'model': 'Model'},
+ template='plotly_dark'
+ )
+ fig.update_traces(textposition='outside', textfont_size=10)
fig.update_layout(
- polar=dict(radialaxis=dict(visible=True, range=[0, max(avg_metrics['avg_metric'].max(), 0.5)])), # Set the radar range
- title='βοΈ Radar Plot of Metrics per Model (Average per Category) βοΈ ',
- template='plotly_dark',
- width=700, height=700
+ margin=dict(t=80),
+ yaxis=dict(title=''),
+ xaxis=dict(title='AVG Metrics'),
+ legend_title='DB Name',
+                height=600  # increase if there are many models
)
- return fig
-
- def update_radar(selected_models):
+ return gr.Plot(fig, visible=True)
+ """
+
+ def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
+ df = load_data_csv_es()
+ return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
+
+
+            # DOT-RANGE (LOLLIPOP) CHART FOR THE PROPRIETARY DATASET: Spider vs Others
+
+ def lollipop_propietary():
df = load_data_csv_es()
- return plot_radar(df, selected_models)
-
- # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
- def plot_cumulative_flow(df, selected_models):
- df = df[df['model'].isin(selected_models)]
+            # Keep only the relevant categories
+ target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"]
+ df = df[df['db_category'].isin(target_cats)]
+ df = normalize_valid_efficiency_score(df)
+ df = calculate_average_metrics(df, qatch_metrics)
+
+            # Average per db_category and model
+ avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
+
+            # Split Spider from the other 4 categories
+ spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
+ other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
+
+            # Average of the other categories per model
+ other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
+ other_mean_df["db_category"] = "Others"
+
+            # Rename for clarity and consistency
+ spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
+ other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
+
+            # Merge the two datasets
+ merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
+
+            # Sort by model (or by value, if preferred)
+ merged_df = merged_df.sort_values(by="model")
+
fig = go.Figure()
-
- for model in selected_models:
- model_df = df[df['model'] == model].copy()
-
- # Calculate cumulative time
- model_df['cumulative_time'] = model_df['time'].cumsum()
-
- # Calculate cumulative number of queries over time
- model_df['cumulative_queries'] = range(1, len(model_df) + 1)
-
- # Select a color for the model
- color = MODEL_COLORS.get(model, "gray") # Assigned model color
- fillcolor = color.replace("rgb", "rgba").replace(")", ", 0.2)") # πΉ Makes the area semi-transparent
- #color = f"rgba({hash(model) % 256}, {hash(model * 2) % 256}, {hash(model * 3) % 256}, 1)"
-
+            # Add horizontal connector lines between Spider and Others
+ for _, row in merged_df.iterrows():
fig.add_trace(go.Scatter(
- x=model_df['cumulative_time'],
- y=model_df['cumulative_queries'],
- mode='lines+markers',
- name=model,
- line=dict(width=2, color=color)
+ x=[row["Spider"], row["Others"]],
+ y=[row["model"]] * 2,
+ mode='lines',
+ line=dict(color='gray', width=2),
+ showlegend=False
))
-
- # Adds the underlying colored area (same color but transparent)
- """
- fig.add_trace(go.Scatter(
- x=model_df['cumulative_time'],
- y=model_df['cumulative_queries'],
- fill='tozeroy',
- mode='none',
- showlegend=False, # Hides the area in the legend
- fillcolor=fillcolor
- ))
- """
+
+            # Marker for Spider
+ fig.add_trace(go.Scatter(
+ x=merged_df["Spider"],
+ y=merged_df["model"],
+ mode='markers',
+ name='Spider',
+ marker=dict(size=10, color='red')
+ ))
+
+            # Marker for Others (mean of the other 4 categories)
+ fig.add_trace(go.Scatter(
+ x=merged_df["Others"],
+ y=merged_df["model"],
+ mode='markers',
+ name='Others Avg',
+ marker=dict(size=10, color='blue')
+ ))
fig.update_layout(
- title="Cumulative Query Flow Chart π",
- xaxis_title="Cumulative Time (s)",
- yaxis_title="Number of Queries Completed",
- template='plotly_dark',
- legend_title="Models"
+                title='Dot-Range Plot: Spider vs Others 🏷️📊',
+                xaxis_title='Average Metric',
+                yaxis_title='Model',
+                template='simple_white',
+                #template='plotly_dark',
+                margin=dict(t=80),
+                legend_title='Category',
+ height=600
)
-
+
+ return gr.Plot(fig, visible=True)
+
+
+ # RADAR OR BAR CHART BASED ON CATEGORY COUNT
+ def plot_radar(df, selected_models, selected_metrics, selected_categories):
+ if "external" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+            # Filter models and normalize
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficiency_score(df)
+ df = calculate_average_metrics(df, selected_metrics)
+
+ avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
+
+ if avg_metrics.empty:
+ print("Error: No data available to compute averages.")
+ return go.Figure()
+
+ categories = selected_categories
+
+ if len(categories) < 3:
+                # 📊 BAR PLOT
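+                # (a radar chart needs at least three axes to form a polygon, so with
+                # fewer than three categories we fall back to a grouped bar chart)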
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
+ if cat in model_data['test_category'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Bar(
+ x=categories,
+ y=values,
+ name=model,
+ marker=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ barmode='group',
+ title=dict(
+                        text='📊 Bar Plot of Metrics per Model (Few Categories)',
+ font=dict(
+ family='Playfair Display, serif',
+ size=22,
+ color='white'
+ ),
+ x=0.5
+ ),
+ template='plotly_dark',
+ xaxis=dict(
+ title=dict(
+ text='Test Category',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metric',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Playfair Display, serif',
+ size=14,
+ color='white'
+ )
+ ),
+ font=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ )
+ )
+ else:
+ # π§ RADAR PLOT
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
+ if cat in model_data['test_category'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Scatterpolar(
+ r=values,
+ theta=categories,
+ fill='toself',
+ name=model,
+ line=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ polar=dict(
+ radialaxis=dict(
+ visible=True,
+ range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ angularaxis=dict(
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ )
+ ),
+ title=dict(
+ text='βοΈ Radar Plot of Metrics per Model (Average per Category)',
+ font=dict(
+ family='Playfair Display, serif',
+ size=22,
+ color='white'
+ ),
+ x=0.5
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Playfair Display, serif',
+ size=14,
+ color='white'
+ )
+ ),
+ font=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ template='plotly_dark'
+ )
+
return fig
- def update_query_rate(selected_models):
+ def update_radar(selected_models, selected_metrics, selected_categories):
df = load_data_csv_es()
- return plot_cumulative_flow(df, selected_models)
+ return plot_radar(df, selected_models, selected_metrics, selected_categories)
+ # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
+ def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
+ if "external" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
- # RANKING FOR THE TOP 3 MODELS WITH UPDATE FUNCTION
- def ranking_text(df, selected_models, ranking_type):
- #df = load_data_csv_es()
df = df[df['model'].isin(selected_models)]
- df['valid_efficiency_score'] = pd.to_numeric(df['valid_efficiency_score'], errors='coerce')
- if ranking_type == "valid_efficiency_score":
- rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
- #rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
- ascending_order = False # Higher is better
- elif ranking_type == "time":
- rank_df = df.groupby('model')['time'].sum().reset_index()
- rank_df["Ranking Value"] = rank_df["time"].round(2).astype(str) + " s" # Adds "s" for seconds
- ascending_order = True # For time, lower is better
- elif ranking_type == "metrics":
- selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
- df = calculate_average_metrics(df, selected_metrics)
- rank_df = df.groupby('model')['avg_metric'].mean().reset_index()
- ascending_order = False # Higher is better
-
- if ranking_type != "time":
- rank_df.rename(columns={rank_df.columns[1]: "Ranking Value"}, inplace=True)
- rank_df["Ranking Value"] = rank_df["Ranking Value"].round(2) # Round values except for time
-
- # Sort based on the selected criterion
- rank_df = rank_df.sort_values(by="Ranking Value", ascending=ascending_order).reset_index(drop=True)
-
- # Select only the top 3 models
- rank_df = rank_df.head(3)
-
- # Add medal icons for the top 3
- medals = ["π₯", "π₯", "π₯"]
- rank_df.insert(0, "Rank", medals[:len(rank_df)])
+ df = normalize_valid_efficiency_score(df)
+ df = calculate_average_metrics(df, selected_metrics)
- # Build the formatted ranking string
- ranking_str = "## π Model Ranking\n"
- for _, row in rank_df.iterrows():
- ranking_str += f"{row['Rank']} {row['model']} ({row['Ranking Value']})
\n"
-
- return ranking_str
+ if isinstance(selected_category, str):
+ selected_category = [selected_category]
- def update_ranking_text(selected_models, ranking_type):
- df = load_data_csv_es()
- return ranking_text(df, selected_models, ranking_type)
+ df = df[df['test_category'].isin(selected_category)]
+ avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
+
+ if avg_metrics.empty:
+ print("Error: No data available to compute averages.")
+ return go.Figure()
+
+ categories = df['sql_tag'].unique().tolist()
+
+ if len(categories) < 3:
+                # 📊 BAR PLOT
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
+ if cat in model_data['sql_tag'].values else 0
+ for cat in categories
+ ]
+ fig.add_trace(go.Bar(
+ x=categories,
+ y=values,
+ name=model,
+ marker=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ barmode='group',
+ title=dict(
+                        text='📊 Bar Plot of Metrics per Model (Few Sub-Categories)',
+ font=dict(
+ family='Playfair Display, serif',
+ size=22,
+ color='white'
+ ),
+ x=0.5
+ ),
+ template='plotly_dark',
+ xaxis=dict(
+ title=dict(
+ text='SQL Tag (Sub Category)',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text='Average Metric',
+ font=dict(
+ family='Playfair Display, serif',
+ size=16,
+ color='white'
+ )
+ ),
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Playfair Display, serif',
+ size=14,
+ color='white'
+ )
+ ),
+ font=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ )
+ )
+ else:
+ # π§ RADAR PLOT
+ fig = go.Figure()
+ for model in selected_models:
+ model_data = avg_metrics[avg_metrics['model'] == model]
+ values = [
+ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
+ if cat in model_data['sql_tag'].values else 0
+ for cat in categories
+ ]
+
+ fig.add_trace(go.Scatterpolar(
+ r=values,
+ theta=categories,
+ fill='toself',
+ name=model,
+ line=dict(color=MODEL_COLORS.get(model, "gray"))
+ ))
+
+ fig.update_layout(
+ polar=dict(
+ radialaxis=dict(
+ visible=True,
+ range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ angularaxis=dict(
+ tickfont=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ )
+ ),
+ title=dict(
+ text='βοΈ Radar Plot of Metrics per Model (Average per Sub-Category)',
+ font=dict(
+ family='Playfair Display, serif',
+ size=22,
+ color='white'
+ ),
+ x=0.5
+ ),
+ legend=dict(
+ title=dict(
+ text='Models',
+ font=dict(
+ family='Playfair Display, serif',
+ size=14,
+ color='white'
+ )
+ ),
+ font=dict(
+ family='Playfair Display, serif',
+ color='white'
+ )
+ ),
+ template='plotly_dark'
+ )
+
+ return fig
+ def update_radar_sub(selected_models, selected_metrics, selected_category):
+ df = load_data_csv_es()
+ return plot_radar_sub(df, selected_models, selected_metrics, selected_category)
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
- def worst_cases_text(df, selected_models):
+ def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
+ if selected_models == "All":
+ selected_models = models
+ else:
+ selected_models = [selected_models]
+
+ if selected_categories == "All":
+ selected_categories = principal_categories
+ else:
+ selected_categories = [selected_categories]
+
df = df[df['model'].isin(selected_models)]
+ df = df[df['test_category'].isin(selected_categories)]
- selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
+ if "external" in selected_metrics:
+ selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
+ else:
+ selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+
+ df = normalize_valid_efficiency_score(df)
df = calculate_average_metrics(df, selected_metrics)
- worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql'])['avg_metric'].mean().reset_index()
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
@@ -845,110 +1543,495 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
- worst_str = "## β Top 3 Worst Cases\n"
+ worst_str = []
+ answer_str = []
+
medals = ["π₯", "π₯", "π₯"]
for i, row in worst_cases_top_3.iterrows():
- worst_str += (
- f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} ({row['avg_metric']}) \n"
+ entry = (
+ f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
f"- Question: {row['question']} \n"
f"- Original Query: `{row['query']}` \n"
f"- Predicted SQL: `{row['predicted_sql']}` \n\n"
)
+
+ worst_str.append(entry)
+
+ raw_answer = (
+ f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n"
+ f"- Raw Answer:
`{row['answer']}` \n"
+ )
+
+ answer_str.append(raw_answer)
- return worst_str
+ return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
- def update_worst_cases_text(selected_models):
+ def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
df = load_data_csv_es()
- return worst_cases_text(df, selected_models)
+ return worst_cases_text(df, selected_models, selected_metrics, selected_categories)
+
+ # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
+ def plot_cumulative_flow(df, selected_models, max_points):
+ df = df[df['model'].isin(selected_models)]
+ df = normalize_valid_efficiency_score(df)
+
+ fig = go.Figure()
+
+ for model in selected_models:
+ model_df = df[df['model'] == model].copy()
+
+                # Limit the number of points if requested
+ if max_points is not None:
+ model_df = model_df.head(max_points + 1)
+
+                # Custom hover tooltip
+                model_df['hover_info'] = model_df.apply(
+                    lambda row:
+                        f"Id question: {row['number_question']}<br>"
+                        f"Question: {row['question']}<br>"
+                        f"Target: {row['query']}<br>"
+                        f"Prediction: {row['predicted_sql']}<br>"
+                        f"Category: {row['test_category']}",
+                    axis=1
+                )
+
+                # Cumulative sums
+ model_df['cumulative_time'] = model_df['time'].cumsum()
+ model_df['cumulative_price'] = model_df['price'].cumsum()
+
+                # Model color
+ color = MODEL_COLORS.get(model, "gray")
+
+ fig.add_trace(go.Scatter(
+ x=model_df['cumulative_time'],
+ y=model_df['cumulative_price'],
+ mode='lines+markers',
+ name=model,
+ line=dict(width=2, color=color),
+ customdata=model_df['hover_info'],
+                    hovertemplate=
+                        "Model: " + model + "<br>" +
+                        "Cumulative Time: %{x}s<br>" +
+                        "Cumulative Price: $%{y:.2f}<br>" +
+                        "<br>Details:<br>%{customdata}"
+ ))
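+
+                # Plotly substitutes %{customdata} with each point's hover_info string;
+                # <br> produces the line breaks inside the hover box.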
+
+            # Layout with an elegant font
+ fig.update_layout(
+ title=dict(
+ text="Cumulative Price Flow Chart π°",
+ font=dict(
+ family="Playfair Display, serif",
+ size=24,
+ color="white"
+ ),
+ x=0.5
+ ),
+ xaxis=dict(
+ title=dict(
+ text="Cumulative Time (s)",
+ font=dict(
+ family="Playfair Display, serif",
+ size=16,
+ color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ ),
+ yaxis=dict(
+ title=dict(
+ text="Cumulative Price ($)",
+ font=dict(
+ family="Playfair Display, serif",
+ size=16,
+ color="white"
+ )
+ ),
+ tickfont=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ ),
+ legend=dict(
+ title=dict(
+ text="Models",
+ font=dict(
+ family="Playfair Display, serif",
+ size=14,
+ color="white"
+ )
+ ),
+ font=dict(
+ family="Playfair Display, serif",
+ color="white"
+ )
+ ),
+ template="plotly_dark"
+ )
+
+ return fig
+
+ def update_query_rate(selected_models, max_points):
+ df = load_data_csv_es()
+ return plot_cumulative_flow(df, selected_models, max_points)
+
+
- metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
+ #######################
+ # PARAMETER SECTION #
+ #######################
+ qatch_metrics_dict = {
+ "Cell Precision": "cell_precision",
+ "Cell Recall": "cell_recall",
+ "Tuple Order": "tuple_order",
+ "Tuple Cardinality": "tuple_cardinality",
+ "Tuple Constraint": "tuple_constraint"
+ }
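+    # Maps the human-readable labels shown in the UI to the metric column
+    # names used in the results dataframe.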
+
+ qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
+    last_valid_qatch_metrics_selection = qatch_metrics.copy()  # Stores the last valid selection
+    def enforce_qatch_metrics_selection(selected):
+        global last_valid_qatch_metrics_selection
+        if not selected:  # If no metric is selected, revert to the last valid selection
+            return gr.update(value=last_valid_qatch_metrics_selection)
+        last_valid_qatch_metrics_selection = selected  # Otherwise record the new valid selection
+        return gr.update(value=selected)
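+    # The enforce_* helpers all follow the same pattern: a component's .change()
+    # event is routed back into the same component, so clearing every option
+    # silently restores the last non-empty selection. Hypothetical wiring:
+    #   some_checkbox.change(fn=enforce_qatch_metrics_selection,
+    #                        inputs=some_checkbox, outputs=some_checkbox)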
+
+ external_metrics_dict = {
+ "Execution Accuracy": "execution_accuracy",
+ "Valid Efficiency Score": "valid_efficiency_score"
+ }
+
+ external_metric = ["execution_accuracy", "valid_efficiency_score"]
+ last_valid_external_metric_selection = external_metric.copy()
+ def enforce_external_metric_selection(selected):
+ global last_valid_external_metric_selection
+        if not selected:  # If no metric is selected, revert to the last valid selection
+            return gr.update(value=last_valid_external_metric_selection)
+        last_valid_external_metric_selection = selected  # Otherwise record the new valid selection
+        return gr.update(value=selected)
+
+ all_metrics = {
+ "Qatch": ["qatch"],
+ "External": ["external"]
+ }
+
group_options = {
"Table": ["tbl_name", "model"],
"Model": ["model"]
}
df_initial = load_data_csv_es()
+
models = df_initial['model'].unique().tolist()
+    last_valid_model_selection = models.copy()  # Stores the last valid selection
+    def enforce_model_selection(selected):
+        global last_valid_model_selection
+        if not selected:  # If no model is selected, revert to the last valid selection
+            return gr.update(value=last_valid_model_selection)
+        last_valid_model_selection = selected  # Otherwise record the new valid selection
+        return gr.update(value=selected)
+
+ all_categories = df_initial['sql_tag'].unique().tolist()
+
+ principal_categories = df_initial['test_category'].unique().tolist()
+    last_valid_category_selection = principal_categories.copy()  # Stores the last valid selection
+    def enforce_category_selection(selected):
+        global last_valid_category_selection
+        if not selected:  # If no category is selected, revert to the last valid selection
+            return gr.update(value=last_valid_category_selection)
+        last_valid_category_selection = selected  # Otherwise record the new valid selection
+        return gr.update(value=selected)
+
+    all_categories_as_dic = {cat: [cat] for cat in principal_categories}
+
+    all_categories_as_dic_ranking = {cat: [cat] for cat in principal_categories}
+    all_categories_as_dic_ranking["All"] = principal_categories
+
+    all_model_as_dic = {model: [model] for model in models}
+    all_model_as_dic["All"] = models
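+    # The "All" keys let a single Radio choice expand to the full list of
+    # categories or models when these dictionaries resolve a selection.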
#with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
- gr.Markdown("""## π Model Performance Analysis π
- Select one or more metrics to calculate the average and visualize histograms and radar plots.
- """)
- # Options selection section
+
+
+ ###########################
+ # VISUALIZATION SECTION #
+ ###########################
+ gr.Markdown("""# Model Performance Analysis""")
+
+ #FOR BAR
+ gr.Markdown("""## Section 1: Model - Data""")
with gr.Row():
+ choose_metrics_bar = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
- metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Select metrics", value=metrics)
- model_multiselect = gr.CheckboxGroup(choices=models, label="Select models", value=models)
- group_radio = gr.Radio(choices=list(group_options.keys()), label="Select grouping", value="Table")
+ qatch_metric_multiselect_bar = gr.CheckboxGroup(
+ choices=list(qatch_metrics_dict.keys()),
+            label="Select one or more Qatch metrics:",
+ value=list(qatch_metrics_dict.keys()),
+ visible=True
+ )
- output_plot = gr.Plot(visible=False)
+ external_metric_select_bar = gr.CheckboxGroup(
+ choices=list(external_metrics_dict.keys()),
+ label="Select one or more External metrics:",
+ visible=False
+ )
+
+    if input_data['input_method'] == 'default':
+ model_radio_bar = gr.Radio(
+ choices=list(all_model_as_dic.keys()),
+ label="Select the model that you want to use:",
+ value="All"
+ )
+ else:
+ model_multiselect_bar = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models
+ )
+
+ group_radio = gr.Radio(
+ choices=list(group_options.keys()),
+ label="Select the grouping view:",
+ value="Table"
+ )
- query_rate_plot = gr.Plot(value=update_query_rate(models))
+ def toggle_metric_selector(selected_type):
+ if selected_type == "Qatch":
+ return gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
+ else:
+ return gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
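+    # Switching the metric group swaps the visibility of the two CheckboxGroups
+    # and empties the hidden one, so stale selections never reach the plots.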
+
+ output_plot = gr.Plot(visible=False)
+
+    if input_data['input_method'] == 'default':
+ with gr.Row():
+ lollipop_propietary()
+ #FOR RADAR
+ gr.Markdown("""## Section 2: Model - Category""")
+ with gr.Row():
+ all_metrics_radar = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
+
+ model_multiselect_radar = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ category_multiselect_radar = gr.CheckboxGroup(
+ choices=principal_categories,
+ label="Select one or more categories:",
+ value=principal_categories
+ )
+ with gr.Column(scale=1):
+ category_radio_radar = gr.Radio(
+ choices=list(all_categories_as_dic.keys()),
+                    label="Select the category that you want to use:",
+ value=list(all_categories_as_dic.keys())[0]
+ )
+
with gr.Row():
with gr.Column(scale=1):
- radar_plot = gr.Plot(value=update_radar(models))
+ radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
with gr.Column(scale=1):
- ranking_type_radio = gr.Radio(
- ["valid_efficiency_score", "time", "metrics"],
- label="Choose ranking criteria",
- value="valid_efficiency_score"
- )
- ranking_text_display = gr.Markdown(value=update_ranking_text(models, "valid_efficiency_score"))
- worst_cases_display = gr.Markdown(value=update_worst_cases_text(models))
-
- # Callback functions for updating charts
- def on_change(selected_metrics, selected_group, selected_models):
- return update_plot(selected_metrics, group_options[selected_group], selected_models)
+ radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
- def on_radar_change(selected_models):
- return update_radar(selected_models)
+ #FOR RANKING
+ with gr.Row():
+ all_metrics_ranking = gr.Radio(
+ choices=list(all_metrics.keys()),
+ label="Select the metrics group that you want to use:",
+ value="Qatch"
+ )
+
+ model_radio_ranking = gr.Radio(
+ choices=list(all_model_as_dic.keys()),
+ label="Select the model that you want to use:",
+ value="All"
+ )
+
+ category_radio_ranking = gr.Radio(
+ choices=list(all_categories_as_dic_ranking.keys()),
+ label="Select the category that you want to use",
+        label="Select the category that you want to use:",
+ )
- #metrics_df_out.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
- proceed_to_metrics_button.click(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
+ with gr.Row():
+ with gr.Column(scale=1):
+            gr.Markdown("## ❌ 3 Worst Cases\n")
+
+ worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
+
+ with gr.Row():
+ first = gr.Markdown(worst_first)
+
+ with gr.Row():
+                first_button = gr.Button("Show raw answer for 🥇")
+
+ with gr.Row():
+ second = gr.Markdown(worst_second)
+
+ with gr.Row():
+                second_button = gr.Button("Show raw answer for 🥈")
+
+ with gr.Row():
+ third = gr.Markdown(worst_third)
+
+ with gr.Row():
+                third_button = gr.Button("Show raw answer for 🥉")
+
+ with gr.Column(scale=1):
+            gr.Markdown("""## Raw Answer""")
+ row_answer_first = gr.Markdown(value=raw_first, visible=True)
+ row_answer_second = gr.Markdown(value=raw_second, visible=False)
+ row_answer_third = gr.Markdown(value=raw_third, visible=False)
+
+ #FOR RATE
+ gr.Markdown("""## Section 3: Time - Price""")
+ with gr.Row():
+ model_multiselect_rate = gr.CheckboxGroup(
+ choices=models,
+ label="Select one or more models:",
+ value=models
+ )
- proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
+ with gr.Row():
+        slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances that you want to visualize")
+
+ query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
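+    # The chart starts with every model and the full question range; the
+    # slicer and model_multiselect_rate .change() hooks in the on-click
+    # section re-render it via update_query_rate.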
- metric_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
- group_radio.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
- model_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
- model_multiselect.change(update_radar, inputs=model_multiselect, outputs=radar_plot)
- model_multiselect.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
- ranking_type_radio.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
- model_multiselect.change(update_worst_cases_text, inputs=model_multiselect, outputs=worst_cases_display)
- model_multiselect.change(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
+ #FOR RESET
reset_data = gr.Button("Back to upload data section")
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+
+
+
+
+ ###############################
+ # CALLBACK FUNCTION SECTION #
+ ###############################
- reset_data.click(
- fn=lambda: gr.update(visible=False),
- outputs=[download_metrics]
- )
- reset_data.click(
- fn=lambda: gr.update(visible=False),
- outputs=[download_metrics]
- )
- reset_data.click(
- fn=enable_disable,
- inputs=[gr.State(True)],
- outputs=[
- *model_checkboxes,
- submit_models_button,
- preview_output,
- submit_button,
- file_input,
- default_checkbox,
- table_selector,
- *table_outputs,
- open_model_selection
- ]
- )
+ #FOR BAR
+ def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
+ return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
+
+ def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
+ return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
+
+ #FOR RADAR
+ def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
+ return update_radar(selected_models, selected_metrics, selected_categories)
+
+ def on_radar_radio_change(selected_models, selected_metrics, selected_category):
+ return update_radar_sub(selected_models, selected_metrics, selected_category)
+
+ #FOR RANKING
+ def on_ranking_change(selected_models, selected_metrics, selected_categories):
+ return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
+
+ def show_first():
+ return (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False)
+ )
+
+ def show_second():
+ return (
+ gr.update(visible=False),
+ gr.update(visible=True),
+ gr.update(visible=False)
+ )
+
+ def show_third():
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=True)
+ )
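+    # show_first/show_second/show_third toggle which of the three raw-answer
+    # Markdown panels is visible; exactly one is shown at a time.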
+
+
+
+
+ ######################
+ # ON CLICK SECTION #
+ ######################
+
+ #FOR BAR
+    if input_data['input_method'] == 'default':
+ proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+        qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+        external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+        model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
+ external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+ else:
+ proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+        qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+        external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+        group_radio.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+        model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+ qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+ model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
+ external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+
+ #FOR RADAR MULTISELECT
+ model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+ model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
+ category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
+
+ #FOR RADAR RADIO
+ model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+ all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+ category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+ #FOR RANKING
+ model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+ category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking)
+ category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking)
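+    # Note: enforce_model_selection / enforce_category_selection are also wired
+    # to the Radio controls above; a Radio always holds a value, so those two
+    # guards act only as a defensive no-op.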
+ first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
+ #FOR RATE
+ model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+ proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+ model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
+ slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+
+ #FOR RESET
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+ reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
+ reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
+
interface.launch()
\ No newline at end of file