# qatch-demo / app.py
# Last commit 696b8fe (verified) by franceth:
#   "More stable version. Link all acc, but still miss prediction"
# (Header copied from the Hugging Face file view: raw | history | blame, 40.8 kB.)
import gradio as gr
import pandas as pd
import os
import sys
from qatch.connectors.sqlite_connector import SqliteConnector
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
#from predictor.orchestrator_predictor import OrchestratorPredictor
import utils_get_db_tables_info
import utilities as us
import time
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
# Custom stylesheet, decoded explicitly as UTF-8 so that reading it does not
# depend on the platform's default encoding.
# NOTE(review): `css` is not referenced in the visible code; the Blocks are
# created with css_paths='style.css' instead.
with open('style.css', 'r', encoding='utf-8') as file:
    css = file.read()

# Default DataFrame shown in the preview and used when "use default" is ticked.
df_default = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [25, 30, 35],
    'City': ['New York', 'Los Angeles', 'Chicago'],
})

# CSV describing the selectable models (read via us.read_models_csv).
models_path = "models.csv"

# Global holding the current (possibly user-edited) DataFrame.
df_current = df_default.copy()

# Shared application state, filled in by load_data and the model-selection
# callbacks and cleared by open_accordion("reset").
input_data = {
    'input_method': "",
    'data_path': "",
    'db_name': "",
    'data': {
        'data_frames': {},  # dictionary of dataframes (table name -> DataFrame)
        'db': None,         # SQLITE3 database object
    },
    'models': [],
}
def _build_sqlite_connector():
    """Attach a SqliteConnector for the currently loaded tables to ``input_data``.

    Every table is registered with 'id' as its primary key. Does nothing when
    no tables are loaded.
    """
    data_frames = input_data["data"]['data_frames']
    if not data_frames:
        return
    # Assign the primary key for each table
    table2primary_key = {table_name: 'id' for table_name in data_frames}
    input_data["data"]["db"] = SqliteConnector(
        relative_db_path=input_data["data_path"],
        db_name=input_data["db_name"],
        tables=data_frames,
        table2primary_key=table2primary_key,
    )


def load_data(file, path, use_default):
    """Load data from an uploaded file, or use the default DataFrame.

    Updates the module-level ``input_data`` (input method, SQLite path, db
    name, table DataFrames and SqliteConnector) and ``df_current``.

    Args:
        file: path of the uploaded file as provided by the gr.File component,
            or None.
        path: a filesystem path. Loading from a path is currently disabled:
            a non-empty ``path`` only takes part in the "one input method at
            a time" check.
        use_default: when truthy, the current preview DataFrame is used as the
            single table 'MyTable'. Takes precedence over the other inputs.

    Returns:
        The dict mapping table name -> DataFrame on success, or an error
        message string.
    """
    global df_current
    if use_default:
        input_data["input_method"] = 'default'
        input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable.sqlite")
        input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
        input_data["data"]['data_frames'] = {'MyTable': df_current}
        _build_sqlite_connector()
        df_current = df_default.copy()  # Restore the default data
        return input_data["data"]['data_frames']

    # use_default is falsy here, so only the file and the path can clash.
    # Counting them alone also avoids a TypeError when use_default is None.
    selected_inputs = sum([file is not None, bool(path)])
    if selected_inputs > 1:
        return 'Errore: Selezionare solo un metodo di input alla volta.'

    if file is not None:
        try:
            input_data["input_method"] = 'uploaded_file'
            input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
            input_data["data_path"] = os.path.join(
                ".", "data", "data_interface", f"{input_data['db_name']}.sqlite"
            )
            input_data["data"] = us.load_data(file, input_data["db_name"])
            # Keep the preview DataFrame in sync with the uploaded 'MyTable', if any
            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)
            _build_sqlite_connector()
            return input_data["data"]['data_frames']
        except Exception as e:
            # UI boundary: report the failure to the caller instead of crashing
            return f'Errore nel caricamento del file: {e}'

    # Nothing new to load (path-based loading is disabled): return the tables
    # already held in input_data.
    return input_data["data"]['data_frames']
def preview_default(use_default):
    """Pick the DataFrame to display in the preview table.

    Args:
        use_default: state of the "use default" checkbox.

    Returns:
        The pristine default DataFrame when the checkbox is ticked, otherwise
        the current DataFrame (which the user may have edited).
    """
    return df_default if use_default else df_current
def update_df(new_df):
    """Record the edited preview table as the current DataFrame.

    Args:
        new_df: value coming from the preview DataFrame component.

    Returns:
        The same object, which is echoed back to the preview component.
    """
    global df_current  # rebind the module-level variable, not a local
    df_current = new_df
    return new_df
def open_accordion(target):
    """Return the component updates for a navigation action between sections.

    Args:
        target: "reset" clears all loaded state and goes back to the upload
            section; "model_selection" opens the model-selection accordion.

    Returns:
        For "reset": 7 updates, for (upload_acc, select_table_acc,
        select_model_acc, qatch_acc, metrics_acc, default_checkbox,
        file_input). For "model_selection": 5 updates, one per accordion.
        Any other target returns None.
    """
    # Without this declaration the assignment below would only create a local
    # variable, and the module-level df_current would keep the old data.
    global df_current
    if target == "reset":
        df_current = df_default.copy()
        input_data['input_method'] = ""
        input_data['data_path'] = ""
        input_data['db_name'] = ""
        input_data['data']['data_frames'] = {}
        input_data['data']['db'] = None
        input_data['models'] = []
        return (
            gr.update(open=True),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(value=False),  # untick "use default"
            gr.update(value=None),   # clear the uploaded file
        )
    if target == "model_selection":
        # Open the model accordion and close the others
        return (
            gr.update(open=False),
            gr.update(open=False),
            gr.update(open=True, visible=True),
            gr.update(open=False),
            gr.update(open=False),
        )
    return None
# Gradio interface
with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
    gr.Markdown("# QATCH")
    data_state = gr.State(None)  # Holds the loaded data (table name -> DataFrame)
    upload_acc = gr.Accordion("Upload your data section", open=True, visible=True)
    select_table_acc = gr.Accordion("Select tables", open=False, visible=False)
    select_model_acc = gr.Accordion("Select models", open=False, visible=False)
    qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
    metrics_acc = gr.Accordion("Metrics", open=False, visible=False)

    #################################
    #    DATABASE UPLOAD SECTION    #
    #################################
    with upload_acc:
        gr.Markdown("## Caricamento dei Dati")
        file_input = gr.File(label="Trascina e rilascia un file", file_types=[".csv", ".xlsx", ".sqlite"])
        with gr.Row():
            default_checkbox = gr.Checkbox(label="Usa DataFrame di default")
            preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
        submit_button = gr.Button("Carica Dati", interactive=False)  # Disabled until there is data to load
        output = gr.JSON(visible=False)  # Hidden JSON output

        def enable_submit(file, use_default):
            """Enable the load button only when a file is uploaded or the default data is chosen."""
            return gr.update(interactive=bool(file or use_default))

        def deselect_default(file):
            """Untick the "use default" checkbox as soon as a file is uploaded."""
            if file:
                return gr.update(value=False)
            return gr.update()

        # Enable the button whenever an input is provided
        file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
        default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
        # Show the default DataFrame preview when the checkbox is ticked
        default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output])
        preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
        # Untick the checkbox when a file is uploaded
        file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])

        def handle_output(file, use_default):
            """Handle a click on 'Carica Dati'.

            Returns 7 values, matching the outputs of submit_button.click:
            (output, output, select_table_acc, data_state, submit_button,
            select_model_acc, upload_acc).
            """
            result = load_data(file, None, use_default)
            if not isinstance(result, dict):
                # load_data returned an error message: stay on the upload section
                return (
                    gr.update(visible=False),
                    None,
                    gr.update(open=False, visible=True),
                    None,
                    gr.update(interactive=True),
                    gr.update(visible=False),
                    gr.update(visible=True, open=True),
                )
            if len(result) == 1:
                # A single table: skip table selection, go straight to model selection
                return (
                    gr.update(visible=False),            # hide the JSON output
                    result,
                    gr.update(visible=False),            # hide table selection
                    result,                              # data state
                    gr.update(interactive=False),        # disable the submit button
                    gr.update(visible=True, open=True),  # open select_model_acc
                    gr.update(visible=True, open=False),
                )
            return (
                gr.update(visible=False),
                result,
                gr.update(open=True, visible=True),      # open table selection
                result,
                gr.update(interactive=False),
                gr.update(visible=False),
                gr.update(visible=True, open=True),
            )

        # NOTE(review): `output` is listed twice, so the loaded tables dict (2nd
        # value) is sent to the hidden JSON component; confirm the intended target.
        submit_button.click(
            fn=handle_output,
            inputs=[file_input, default_checkbox],
            outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc],
        )

    ######################################
    #       TABLE SELECTION SECTION      #
    ######################################
    with select_table_acc:
        table_selector = gr.CheckboxGroup(choices=[], label="Seleziona le tabelle da visualizzare", value=[])
        table_outputs = [gr.DataFrame(label=f"Tabella {i+1}", interactive=True, visible=False) for i in range(5)]
        selected_table_names = gr.Textbox(label="Tabelle selezionate", visible=False, interactive=False)
        # Model selection button (disabled until at least one table is selected)
        open_model_selection = gr.Button("Choose your models", interactive=False)

        def update_table_list(data):
            """Refresh the table checkboxes from the loaded data, clearing the selection."""
            if isinstance(data, dict) and data:
                table_names = list(data.keys())
                return gr.update(choices=table_names, value=[])
            return gr.update(choices=[], value=[])

        def show_selected_tables(data, selected_tables):
            """Show the selected tables and enable the button if anything is selected.

            Returns one update per entry of table_outputs, followed by the
            update for open_model_selection.
            """
            updates = []
            if isinstance(data, dict) and data:
                # Keep only selections that still exist in the data
                selected_tables = [t for t in selected_tables if t in data]
                # Only len(table_outputs) preview slots exist. Extra selections
                # are not previewed; otherwise the number of returned updates
                # would not match the declared outputs.
                shown = selected_tables[:len(table_outputs)]
                for name in shown:
                    updates.append(gr.update(value=data[name], label=f"Tabella: {name}", visible=True))
                for _ in range(len(shown), len(table_outputs)):
                    updates.append(gr.update(visible=False))
            else:
                updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in table_outputs]
            # Enable the button only when at least one table is selected
            updates.append(gr.update(interactive=bool(selected_tables)))
            return updates

        def show_selected_table_names(selected_tables):
            """Store the selected table names (hidden textbox) when the button is pressed."""
            if selected_tables:
                return gr.update(value=", ".join(selected_tables), visible=False)
            return gr.update(value="", visible=False)

        # Refresh the checkbox list whenever data_state changes
        data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
        # Update the visible tables and the button state on every selection change
        table_selector.change(
            fn=show_selected_tables,
            inputs=[data_state, table_selector],
            outputs=table_outputs + [open_model_selection],
        )
        # Record the selected tables and move on when "Choose your models" is pressed
        open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
        open_model_selection.click(
            open_accordion,
            inputs=gr.State("model_selection"),
            outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc],
        )

    ####################################
    #      MODEL SELECTION SECTION     #
    ####################################
    with select_model_acc:
        gr.Markdown("**Model Selection**")
        # us.read_models_csv returns one dict per model with its code and image path
        model_list_dict = us.read_models_csv(models_path)
        model_list = [model["code"] for model in model_list_dict]
        model_images = [model["image_path"] for model in model_list_dict]
        model_checkboxes = []
        rows = []
        # Build image + checkbox cards dynamically, 3 per row
        for i in range(0, len(model_list), 3):
            with gr.Row():
                cols = []
                for model, image_path in zip(model_list[i:i + 3], model_images[i:i + 3]):
                    with gr.Column():
                        gr.Image(image_path, show_label=False)
                        checkbox = gr.Checkbox(label=model, value=False)
                        model_checkboxes.append(checkbox)
                        cols.append(checkbox)
                rows.append(cols)
        selected_models_output = gr.JSON(visible=False)

        def get_selected_models(*model_selections):
            """Store the ticked models in input_data['models'].

            Returns (selected model codes, update keeping this accordion open,
            update enabling the submit button iff at least one model is ticked).
            """
            selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
            input_data['models'] = selected_models
            return selected_models, gr.update(open=True, visible=True), gr.update(interactive=bool(selected_models))

        # Submit button (disabled until a model is selected)
        submit_models_button = gr.Button("Submit Models", interactive=False)
        for checkbox in model_checkboxes:
            checkbox.change(
                fn=get_selected_models,
                inputs=model_checkboxes,
                outputs=[selected_models_output, select_model_acc, submit_models_button],
            )
        # Only the selected model list (first returned value) goes to
        # selected_models_output; the accordion updates come from the lambda.
        submit_models_button.click(
            fn=lambda *args: (
                get_selected_models(*args)[0],
                gr.update(open=False, visible=True),
                gr.update(open=True, visible=True),
            ),
            inputs=model_checkboxes,
            outputs=[selected_models_output, select_model_acc, qatch_acc],
        )
        reset_data = gr.Button("Back to upload data section")
        reset_data.click(
            open_accordion,
            inputs=gr.State("reset"),
            outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input],
        )

    ###############################
    #   QATCH EXECUTION SECTION   #
    ###############################
    with qatch_acc:
        def change_text(text):
            """Identity callback: copy an input value to an output component."""
            return text

        def qatch_flow():
            """Generate the QATCH test set, predict with each selected model, then evaluate.

            Generator used as a streaming Gradio handler. Every yield is a tuple
            for the outputs (variable, question_display, prediction_display,
            metrics_df, one DataFrame per model in model_list). Predictions are
            still simulated with a placeholder.
            """
            orchestrator_generator = OrchestratorGenerator()
            # TODO: add target_df["columns_used"] and honour the table selection
            target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'])
            # Schema text for the real predictor (not wired in yet)
            schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
                db_id=input_data["db_name"],
                base_path=input_data["data_path"],
                normalize=False,
                sql=None,
            )
            # TODO: QUERY PREDICTION
            predictions_dict = {
                model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path'])
                for model in model_list
            }
            metrics_conc = pd.DataFrame()

            def model_frames():
                # One DataFrame per tab, in the order of dataframe_per_model's outputs
                return [predictions_dict[m] for m in model_list]

            n_rows = len(target_df)
            for model in input_data["models"]:
                # enumerate gives the 0-based position, so the percentage is right
                # even when target_df's index labels are not 0..n-1
                for position, (index, row) in enumerate(target_df.iterrows()):
                    load_value = f"##Loading... {round((position + 1) / n_rows * 100, 2)}%"
                    # Natural-language question for the model; fall back to the
                    # SQL query if the generated dataset has no 'question' column
                    question = row.get('question', row['query'])
                    yield gr.Markdown(value=load_value), gr.Textbox(question), gr.Textbox(), metrics_conc, *model_frames()
                    start_time = time.time()
                    # Simulated prediction.
                    # TODO: prediction = predictor.run(model, schema_text, question)
                    time.sleep(0.03)
                    prediction = "Prediction_placeholder"
                    end_time = time.time()
                    new_row = pd.DataFrame([{
                        'id': index,
                        'question': question,
                        'predicted_sql': prediction,
                        'time': end_time - start_time,
                        'query': row["query"],
                        'db_path': input_data["data_path"],
                    }]).dropna(how="all")  # Drop only completely empty rows
                    # Carry over every other column produced by the generator
                    for col in target_df.columns:
                        if col not in new_row.columns:
                            new_row[col] = row[col]
                    # Grow this model's results as we go
                    if not new_row.empty:
                        predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
                    yield gr.Markdown(value=load_value), gr.Textbox(), gr.Textbox(prediction), metrics_conc, *model_frames()

            evaluator = OrchestratorEvaluator()
            for model in input_data["models"]:
                metrics_df_model = evaluator.evaluate_df(
                    df=predictions_dict[model],
                    target_col_name="query",
                    prediction_col_name="predicted_sql",
                    db_path_name="db_path",
                )
                metrics_df_model['model'] = model
                metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
            # Expose VES under the name the metrics section expects, when available
            if 'valid_efficiency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
                metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
            yield gr.Markdown(), gr.Textbox(), gr.Textbox(), metrics_conc, *model_frames()

        # Loading / progress text
        with gr.Row():
            variable = gr.Markdown()
        # NL question -> model -> generated query
        with gr.Row():
            with gr.Column():
                question_display = gr.Textbox()
            with gr.Column():
                gr.Image()
            with gr.Column():
                prediction_display = gr.Textbox()
        dataframe_per_model = {}
        with gr.Tabs() as model_tabs:
            # TODO: show tabs only for the selected models
            for model in model_list:
                with gr.TabItem(model):
                    gr.Markdown(f"**Results for {model}**")
                    dataframe_per_model[model] = gr.DataFrame()
        selected_models_display = gr.JSON(label="Modelli selezionati")
        metrics_df = gr.DataFrame(visible=False)
        metrics_df_out = gr.DataFrame(visible=False)
        submit_models_button.click(
            fn=qatch_flow,
            inputs=[],
            outputs=[variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values()),
        )
        submit_models_button.click(
            fn=lambda: gr.update(value=input_data),
            outputs=[selected_models_display],
        )
        # Forward the computed metrics to the component the metrics section renders on
        metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
        proceed_to_metrics_button = gr.Button("Proceed to Metrics")
        proceed_to_metrics_button.click(
            fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
            outputs=[qatch_acc, metrics_acc],
        )
        reset_data = gr.Button("Back to upload data section")
        reset_data.click(
            open_accordion,
            inputs=gr.State("reset"),
            outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input],
        )

    #######################################
    #    METRICS VISUALIZATION SECTION    #
    #######################################
    with metrics_acc:
        data_path = 'test_results.csv'

        @gr.render(inputs=metrics_df_out)
        def function_metrics(metrics_df_out):
            """Render the metrics dashboard whenever metrics_df_out changes.

            NOTE(review): every chart below reads its data from `data_path`
            ('test_results.csv'), not from the metrics_df_out argument.
            """
            # Metric columns averaged into 'avg_metric' by the charts and rankings
            metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
            group_options = {
                "Table": ["tbl_name", "model"],
                "Model": ["model"],
            }

            def load_data_csv_es():
                return pd.read_csv(data_path)

            def calculate_average_metrics(df, selected_metrics):
                """Add 'avg_metric' (row-wise mean of selected_metrics); mutates and returns df."""
                df['avg_metric'] = df[selected_metrics].mean(axis=1)
                return df

            def generate_model_colors():
                """Map each model in the results file to a Plotly qualitative colour, cycling if needed."""
                unique_models = load_data_csv_es()['model'].unique()
                color_palette = pc.qualitative.Plotly
                return {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}

            MODEL_COLORS = generate_model_colors()

            # BAR CHART FOR AVERAGE METRICS
            def plot_metric(df, selected_metrics, group_by, selected_models):
                # .copy() so the new column is written to a real frame, not a view
                df = df[df['model'].isin(selected_models)].copy()
                df = calculate_average_metrics(df, selected_metrics)
                # Fall back to the default grouping for unknown values
                if group_by not in [["tbl_name", "model"], ["model"]]:
                    group_by = ["tbl_name", "model"]
                avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
                fig = px.bar(
                    avg_metrics,
                    x=group_by[0],
                    y='avg_metric',
                    color='model',
                    color_discrete_map=MODEL_COLORS,
                    barmode='group',
                    title=f'Average metric per {group_by[0]} πŸ“Š',
                    labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
                    template='plotly_dark',
                )
                return fig

            def update_plot(selected_metrics, group_by, selected_models):
                return plot_metric(load_data_csv_es(), selected_metrics, group_by, selected_models)

            # RADAR CHART FOR AVERAGE METRICS PER MODEL
            def plot_radar(df, selected_models):
                df = df[df['model'].isin(selected_models)].copy()
                # Average metrics per test_category and model
                df = calculate_average_metrics(df, metrics)
                avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
                if avg_metrics.empty:
                    print("Error: No data available to compute averages.")
                    return go.Figure()
                fig = go.Figure()
                categories = avg_metrics['test_category'].unique()
                for model in selected_models:
                    model_data = avg_metrics[avg_metrics['model'] == model]
                    # One value per category; 0 when the model has no data for it
                    values = [
                        model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
                        if cat in model_data['test_category'].values else 0
                        for cat in categories
                    ]
                    fig.add_trace(go.Scatterpolar(
                        r=values,
                        theta=categories,
                        fill='toself',
                        name=model,
                        line=dict(color=MODEL_COLORS.get(model, "gray")),
                    ))
                fig.update_layout(
                    polar=dict(radialaxis=dict(visible=True, range=[0, max(avg_metrics['avg_metric'].max(), 0.5)])),
                    title='❇️ Radar Plot of Metrics per Model (Average per Category) ❇️ ',
                    template='plotly_dark',
                    width=700, height=700,
                )
                return fig

            def update_radar(selected_models):
                return plot_radar(load_data_csv_es(), selected_models)

            # LINE CHART FOR CUMULATIVE TIME
            def plot_cumulative_flow(df, selected_models):
                df = df[df['model'].isin(selected_models)]
                fig = go.Figure()
                for model in selected_models:
                    model_df = df[df['model'] == model].copy()
                    model_df['cumulative_time'] = model_df['time'].cumsum()
                    model_df['cumulative_queries'] = range(1, len(model_df) + 1)
                    fig.add_trace(go.Scatter(
                        x=model_df['cumulative_time'],
                        y=model_df['cumulative_queries'],
                        mode='lines+markers',
                        name=model,
                        line=dict(width=2, color=MODEL_COLORS.get(model, "gray")),
                    ))
                fig.update_layout(
                    title="Cumulative Query Flow Chart πŸ“ˆ",
                    xaxis_title="Cumulative Time (s)",
                    yaxis_title="Number of Queries Completed",
                    template='plotly_dark',
                    legend_title="Models",
                )
                return fig

            def update_query_rate(selected_models):
                return plot_cumulative_flow(load_data_csv_es(), selected_models)

            # RANKING OF THE TOP 3 MODELS
            def ranking_text(df, selected_models, ranking_type):
                df = df[df['model'].isin(selected_models)].copy()
                df['valid_efficiency_score'] = pd.to_numeric(df['valid_efficiency_score'], errors='coerce')
                if ranking_type == "valid_efficiency_score":
                    rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
                    ascending_order = False  # Higher is better
                elif ranking_type == "time":
                    rank_df = df.groupby('model')['time'].sum().reset_index()
                    rank_df["Ranking Value"] = rank_df["time"].round(2).astype(str) + " s"
                    ascending_order = True  # Lower is better
                elif ranking_type == "metrics":
                    df = calculate_average_metrics(df, metrics)
                    rank_df = df.groupby('model')['avg_metric'].mean().reset_index()
                    ascending_order = False  # Higher is better
                if ranking_type != "time":
                    rank_df.rename(columns={rank_df.columns[1]: "Ranking Value"}, inplace=True)
                    rank_df["Ranking Value"] = rank_df["Ranking Value"].round(2)
                # For time, "Ranking Value" is a display string ("12.3 s"), which
                # would sort lexicographically; sort on the numeric column instead.
                sort_column = "time" if ranking_type == "time" else "Ranking Value"
                rank_df = rank_df.sort_values(by=sort_column, ascending=ascending_order).reset_index(drop=True)
                rank_df = rank_df.head(3)
                medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"]
                rank_df.insert(0, "Rank", medals[:len(rank_df)])
                ranking_str = "## πŸ† Model Ranking\n"
                for _, row in rank_df.iterrows():
                    ranking_str += f"<span style='font-size:18px;'>{row['Rank']} {row['model']} ({row['Ranking Value']})</span><br>\n"
                return ranking_str

            def update_ranking_text(selected_models, ranking_type):
                return ranking_text(load_data_csv_es(), selected_models, ranking_type)

            # THE 3 WORST RESULTS
            def worst_cases_text(df, selected_models):
                df = df[df['model'].isin(selected_models)].copy()
                df = calculate_average_metrics(df, metrics)
                worst_cases_df = df.groupby(
                    ['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql']
                )['avg_metric'].mean().reset_index()
                worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
                worst_cases_top_3 = worst_cases_df.head(3).copy()
                worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
                worst_str = "## ❌ Top 3 Worst Cases\n"
                medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"]
                for i, row in worst_cases_top_3.iterrows():
                    worst_str += (
                        f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']}</b> ({row['avg_metric']})</span> \n"
                        f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
                        f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
                        f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
                    )
                return worst_str

            def update_worst_cases_text(selected_models):
                return worst_cases_text(load_data_csv_es(), selected_models)

            df_initial = load_data_csv_es()
            models = df_initial['model'].unique().tolist()
            gr.Markdown("""## πŸ“Š Model Performance Analysis πŸ“Š
Select one or more metrics to calculate the average and visualize histograms and radar plots.
""")
            # Options selection section
            with gr.Row():
                metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Select metrics", value=metrics)
                model_multiselect = gr.CheckboxGroup(choices=models, label="Select models", value=models)
                group_radio = gr.Radio(choices=list(group_options.keys()), label="Select grouping", value="Model")
            # Every chart gets its initial figure at render time
            output_plot = gr.Plot(value=update_plot(metrics, group_options["Model"], models))
            query_rate_plot = gr.Plot(value=update_query_rate(models))
            with gr.Row():
                with gr.Column(scale=1):
                    radar_plot = gr.Plot(value=update_radar(models))
                with gr.Column(scale=1):
                    ranking_type_radio = gr.Radio(
                        ["valid_efficiency_score", "time", "metrics"],
                        label="Choose ranking criteria",
                        value="valid_efficiency_score",
                    )
                    ranking_text_display = gr.Markdown(value=update_ranking_text(models, "valid_efficiency_score"))
                    worst_cases_display = gr.Markdown(value=update_worst_cases_text(models))

            def on_change(selected_metrics, selected_group, selected_models):
                """Redraw the bar chart for the chosen metrics, grouping and models."""
                return update_plot(selected_metrics, group_options[selected_group], selected_models)

            bar_inputs = [metric_multiselect, group_radio, model_multiselect]
            metric_multiselect.change(on_change, inputs=bar_inputs, outputs=output_plot)
            group_radio.change(on_change, inputs=bar_inputs, outputs=output_plot)
            model_multiselect.change(on_change, inputs=bar_inputs, outputs=output_plot)
            model_multiselect.change(update_radar, inputs=model_multiselect, outputs=radar_plot)
            model_multiselect.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
            ranking_type_radio.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
            model_multiselect.change(update_worst_cases_text, inputs=model_multiselect, outputs=worst_cases_display)
            model_multiselect.change(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
            reset_data = gr.Button("Back to upload data section")
            reset_data.click(
                open_accordion,
                inputs=gr.State("reset"),
                outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input],
            )

interface.launch()