diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,17 +1,34 @@ import gradio as gr import pandas as pd import os - +# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10 +# if os.environ.get("SPACES_ZERO_GPU") is not None: +# import spaces +# else: +# class spaces: +# @staticmethod +# def GPU(func): +# def wrapper(*args, **kwargs): +# return func(*args, **kwargs) +# return wrapper import sys from qatch.connectors.sqlite_connector import SqliteConnector from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator +from prediction import ModelPrediction import utils_get_db_tables_info import utilities as us import time import plotly.express as px import plotly.graph_objects as go import plotly.colors as pc +import re +import csv + +# @spaces.GPU +# def model_prediction(): +# pass +pnp_path = os.path.join("data", "evaluation_p_metrics.csv") with open('style.css', 'r') as file: css = file.read() @@ -34,39 +51,16 @@ input_data = { 'db_name': "", 'data': { 'data_frames': {}, # dictionary of dataframes - 'db': None # SQLITE3 database object + 'db': None, # SQLITE3 database object + 'selected_tables' :[] }, - 'models': [] + 'models': [], + 'prompt': "{question} {schema}" } def load_data(file, path, use_default): """Carica i dati da un file, un percorso o usa il DataFrame di default.""" global df_current - if use_default: - input_data["input_method"] = 'default' - input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable.sqlite") - input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0] - input_data["data"]['data_frames'] = {'MyTable': df_current} - - if( input_data["data"]['data_frames']): - table2primary_key = {} - for table_name, df in input_data["data"]['data_frames'].items(): - # Assign primary keys for each table - table2primary_key[table_name] = 'id' - input_data["data"]["db"] = SqliteConnector( - relative_db_path=input_data["data_path"], - db_name=input_data["db_name"], - tables= input_data["data"]['data_frames'], - table2primary_key=table2primary_key - ) - - df_current = df_default.copy() # Ripristina i dati di default - return input_data["data"]['data_frames'] - - selected_inputs = sum([file is not None, bool(path), use_default]) - if selected_inputs > 1: - return 'Errore: Selezionare solo un metodo di input alla volta.' - if file is not None: try: input_data["input_method"] = 'uploaded_file' @@ -74,7 +68,7 @@ def load_data(file, path, use_default): input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite") input_data["data"] = us.load_data(file, input_data["db_name"]) df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame - if( input_data["data"]['data_frames']): + if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files table2primary_key = {} for table_name, df in input_data["data"]['data_frames'].items(): # Assign primary keys for each table @@ -88,31 +82,51 @@ def load_data(file, path, use_default): return input_data["data"]['data_frames'] except Exception as e: return f'Errore nel caricamento del file: {e}' - - """ - if path: - if not os.path.exists(path): - return 'Errore: Il percorso specificato non esiste.' 
-        try:
-            input_data["input_method"] = 'uploaded_file'
-            input_data["data_path"] = path
-            input_data["db_name"] = os.path.splitext(os.path.basename(path))[0]
-            input_data["data"] = us.load_data(input_data["data_path"], input_data["db_name"])
-            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)  # Load the DataFrame
-
+    if use_default:
+        if(use_default == 'Custom'):
+            input_data["input_method"] = 'custom'
+            input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
+            # If the file already exists, keep bumping the suffix until the name is free
+            while os.path.exists(input_data["data_path"]):
+                input_data["data_path"] = us.increment_filename(input_data["data_path"])
+            input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
+            input_data["data"]['data_frames'] = {'MyTable': df_current}
+
+            if(input_data["data"]['data_frames']):
+                table2primary_key = {}
+                for table_name, df in input_data["data"]['data_frames'].items():
+                    # Assign primary keys for each table
+                    table2primary_key[table_name] = 'id'
+                input_data["data"]["db"] = SqliteConnector(
+                    relative_db_path=input_data["data_path"],
+                    db_name=input_data["db_name"],
+                    tables=input_data["data"]['data_frames'],
+                    table2primary_key=table2primary_key
+                )
+            df_current = df_default.copy()  # Restore the default data
+            return input_data["data"]['data_frames']
+
+        if(use_default == 'Proprietary vs Non-proprietary'):
+            input_data["input_method"] = 'default'
+            #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
+            #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
+            #input_data["db_name"] = "default"
+            #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
+            input_data["data"]['data_frames'] = us.extract_tables_dict(pnp_path)
             return input_data["data"]['data_frames']
-        except Exception as e:
-            return f'Error loading the file from the given path: {e}'
-    """
-
+    selected_inputs = sum([file is not None, bool(path), use_default])
+    if selected_inputs > 1:
+        return 'Error: Select only one input method at a time.'
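Note: the `Custom` branch above keeps probing `us.increment_filename` until it finds an unused path, so successive custom sessions land in `mytable_0.sqlite`, `mytable_1.sqlite`, and so on. `utilities.py` is not part of this diff, so the helper's real body is not shown; a minimal sketch of the behavior the `while` loop assumes:

```python
import os

def increment_filename(path: str) -> str:
    """Bump a trailing numeric suffix: 'mytable_0.sqlite' -> 'mytable_1.sqlite'.

    Sketch only: the actual utilities.increment_filename is not shown in this diff.
    """
    base, ext = os.path.splitext(path)        # ('.../mytable_0', '.sqlite')
    stem, sep, suffix = base.rpartition("_")  # ('.../mytable', '_', '0')
    if sep and suffix.isdigit():
        return f"{stem}_{int(suffix) + 1}{ext}"
    return f"{base}_0{ext}"                   # no numeric suffix yet: start at 0
```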
+    return input_data["data"]['data_frames']

 def preview_default(use_default):
-    """Shows the default DataFrame when the checkbox is selected."""
-    if use_default:
-        return df_default  # Show the default DataFrame
-    return df_current  # Show the current DataFrame, which may have been edited
+    if use_default == 'Custom':
+        return gr.DataFrame(interactive=True, visible=True, value=df_default), gr.update(visible=False)
+    else:
+        return gr.DataFrame(interactive=False, visible=False, value=df_default), gr.update(visible=True)
+    #return gr.DataFrame(interactive=True, value=df_current)  # Show the current DataFrame, which may have been edited

 def update_df(new_df):
     """Updates the current DataFrame."""
@@ -128,14 +142,16 @@ def open_accordion(target):
         input_data['data_path'] = ""
         input_data['db_name'] = ""
         input_data['data']['data_frames'] = {}
+        input_data['data']['selected_tables'] = []
         input_data['data']['db'] = None
         input_data['models'] = []
-        return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value=False), gr.update(value=None)
+        return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
     elif target == "model_selection":
         return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)

 # Gradio interface
 with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
+#with gr.Blocks(theme='NoCrypt/miku/light', css_paths='style.css') as interface:
     with gr.Row():
         gr.Column(scale=1)
         gr.Image(
@@ -153,21 +169,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
     select_model_acc = gr.Accordion("Select models", open=False, visible=False)
     qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
     metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
-    #metrics_acc = gr.Accordion("Metrics", open=False, visible=False, render=False)
-
-    #################################
     #      DATABASE INSERTION      #
     #################################
     with upload_acc:
-        gr.Markdown("## Data Upload")
-
-        file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
+        gr.Markdown("## Choose data input method")
         with gr.Row():
-            default_checkbox = gr.Checkbox(label="Use default DataFrame")
-            preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
-        submit_button = gr.Button("Load Data", interactive=False)  # Disabled by default
+            default_checkbox = gr.Radio(label="Use the default DataFrame or a custom table", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
+            #default_checkbox = gr.Checkbox(label="Use default DataFrame")
+            preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
+            description = """## Comparison of proprietary and non-proprietary databases
+            - Proprietary (Economic, Medical, Financial, Miscellaneous)
+            - Non-proprietary (Spider 1.0)"""
+
+            table_default = gr.Markdown(description, visible=True)
+        gr.Markdown("## Or upload your data")
+        file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
+        submit_button = gr.Button("Load Data")
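Both the upload path and the `Custom` branch materialize the in-memory DataFrames into SQLite through QATCH's `SqliteConnector`, always assigning a synthetic `'id'` primary key to every table. Condensed from the two call sites above (the keyword arguments are taken verbatim from this diff; the assumption that an `id` column exists in every frame is the app's, not the connector's):

```python
from qatch.connectors.sqlite_connector import SqliteConnector

def build_db(data_frames: dict, db_path: str, db_name: str) -> SqliteConnector:
    # One synthetic 'id' primary key per table, exactly as load_data does.
    table2primary_key = {table_name: 'id' for table_name in data_frames}
    return SqliteConnector(
        relative_db_path=db_path,
        db_name=db_name,
        tables=data_frames,
        table2primary_key=table2primary_key,
    )
```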
output = gr.JSON(visible=False) # Dictionary output # Function to enable the button if there is data to load @@ -177,15 +196,23 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: # Function to uncheck the checkbox if a file is uploaded def deselect_default(file): if file: - return gr.update(value=False) + return gr.update(value='Proprietary vs Non-proprietary') return gr.update() + + def enable_disable_first(enable): + return ( + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable), + gr.update(interactive=enable) + ) # Enable the button when inputs are provided - file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) - default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) + #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) + #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) # Show preview of the default DataFrame when checkbox is selected - default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output]) + default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output, table_default]) preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output]) # Uncheck the checkbox when a file is uploaded @@ -197,6 +224,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: if isinstance(result, dict): # If result is a dictionary of DataFrames if len(result) == 1: # If there's only one table + input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys()) return ( gr.update(visible=False), # Hide JSON output result, # Save the data state @@ -214,7 +242,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: result, gr.update(interactive=False), gr.update(visible=False), # Keep current behavior - gr.update(visible=True, open=True) + gr.update(visible=True, open=False) ) else: return ( @@ -224,7 +252,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: None, gr.update(interactive=True), gr.update(visible=False), - gr.update(visible=True, open=True) + gr.update(visible=True, open=False) ) submit_button.click( @@ -232,14 +260,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: inputs=[file_input, default_checkbox], outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc] ) + + submit_button.click( + fn=enable_disable_first, + inputs=[gr.State(False)], + outputs=[ + preview_output, + submit_button, + file_input, + default_checkbox + ] + ) - - ###################################### + ###################################### # TABLE SELECTION PART # ###################################### with select_table_acc: table_selector = gr.CheckboxGroup(choices=[], label="Select tables to display", value=[]) - table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(5)] + table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(10)] selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False) # Model selection button (initially disabled) @@ -265,10 +303,10 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: 
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True)) # If there are fewer than 5 tables, hide the other DataFrames - for _ in range(len(tables), 5): + for _ in range(len(tables), 10): updates.append(gr.update(visible=False)) else: - updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(5)] + updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(10)] # Enable/disable the button based on selections button_state = bool(selected_tables) # True if at least one table is selected, False otherwise @@ -279,6 +317,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: def show_selected_table_names(selected_tables): """Displays the names of the selected tables when the button is pressed.""" if selected_tables: + input_data['data']['selected_tables'] = selected_tables return gr.update(value=", ".join(selected_tables), visible=False) return gr.update(value="", visible=False) @@ -291,7 +330,22 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: # Shows the list of selected tables when "Choose your models" is clicked open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names]) open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc]) - + + reset_data = gr.Button("Back to upload data section") + + reset_data.click( + fn=enable_disable_first, + inputs=[gr.State(True)], + outputs=[ + preview_output, + submit_button, + file_input, + default_checkbox + ] + ) + + reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) + #################################### # MODEL SELECTION PART # @@ -322,18 +376,42 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: cols.append(checkbox) rows.append(cols) - selected_models_output = gr.JSON(visible=False) + selected_models_output = gr.JSON(visible=True) # Function to get selected models def get_selected_models(*model_selections): selected_models = [model for model, selected in zip(model_list, model_selections) if selected] input_data['models'] = selected_models - button_state = bool(selected_models) # True if at least one model is selected, False otherwise + button_state = bool(selected_models and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state) + # Add the Textbox to the interface + prompt = gr.TextArea(label="Customise the prompt for selected models here or leave the default one . 
The prompt must contain {question} and {schema} which will be automatically replaced during SQL generation.", + placeholder='Default prompt with a {question} and db {schema} are to be specified') + warning_prompt = gr.Markdown(value="# Error in the prompt format", visible=False) + # Submit button (initially disabled) + submit_models_button = gr.Button("Submit Models", interactive=False) + def check_prompt(prompt): + #TODO + missing_elements = [] + if(prompt==""): + input_data["prompt"]="{question} {schema}" + button_state = bool(len(input_data['models']) > 0 and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) + else: + input_data["prompt"]=prompt + if "{schema}" not in prompt: + missing_elements.append("{schema}") + if "{question}" not in prompt: + missing_elements.append("{question}") + button_state = bool(len(input_data['models']) > 0 and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) + if missing_elements: + return gr.update(value=f"## ❌ Missing {', '.join(missing_elements)} in the prompt ❌", visible=True), gr.update(interactive=button_state) + return gr.update(visible=False), gr.update(interactive=button_state) + + prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button]) # Link checkboxes to selection events for checkbox in model_checkboxes: checkbox.change( @@ -341,13 +419,18 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: inputs=model_checkboxes, outputs=[selected_models_output, select_model_acc, submit_models_button] ) + prompt.change( + fn=get_selected_models, + inputs=model_checkboxes, + outputs=[selected_models_output, select_model_acc, submit_models_button] + ) submit_models_button.click( fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)), inputs=model_checkboxes, outputs=[selected_models_output, select_model_acc, qatch_acc] ) - + def enable_disable(enable): return ( *[gr.update(interactive=enable) for _ in model_checkboxes], @@ -397,7 +480,6 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: ] ) - ############################# # QATCH EXECUTION # ############################# @@ -422,89 +504,172 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: symbols = loading_symbols.get(num_symbols, "π“†Ÿ") mirrored_symbols = f'{symbols.strip()}' css_symbols = f'{symbols.strip()}' - return f"
{css_symbols} Generation {percent}%{mirrored_symbols}"
+        return f"""
+        {css_symbols}
+        Generation {percent}%
+        {mirrored_symbols}
+ """ + #return f"{css_symbols}"+f"# Loading {percent}% #"+f"{mirrored_symbols}" def qatch_flow(): - orchestrator_generator = OrchestratorGenerator() - # TODO: add to target_df column target_df["columns_used"], tables selection - # print(input_data['data']['db']) - target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db']) - - schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( - db_id = input_data["db_name"], - base_path = input_data["data_path"], - normalize=False, - sql=None - ) - - # TODO: QUERY PREDICTION + #caching predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list} metrics_conc = pd.DataFrame() - - for model in input_data["models"]: - model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) - yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - - for index, row in target_df.iterrows(): - - percent_complete = round(((index+1) / len(target_df)) * 100, 2) - load_text = f"{generate_loading_text(percent_complete)}" - - question = row['question'] - display_question = f"
Natural Language: {row['question']}
" - # yield gr.Textbox(question), gr.Textbox(), *[predictions_dict[model] for model in input_data["models"]], None - - yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] - start_time = time.time() - - # Simulate prediction - time.sleep(0.4) - prediction = "Prediction_placeholder" - display_prediction = f"
Generated SQL: {prediction}
" - # Run real prediction here - # prediction = predictor.run(model, schema_text, question) - - end_time = time.time() - # Create a new row as dataframe - new_row = pd.DataFrame([{ - 'id': index, - 'question': question, - 'predicted_sql': prediction, - 'time': end_time - start_time, - 'query': row["query"], - 'db_path': input_data["data_path"] - }]).dropna(how="all") # Remove only completely empty rows - - # TODO: use a for loop - for col in target_df.columns: - if col not in new_row.columns: - new_row[col] = row[col] - - # Update model's prediction dataframe incrementally - if not new_row.empty: - predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) - - # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None + if (input_data['input_method']=="default"): + target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv") + #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list} + target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])] + target_df = target_df[target_df["model"].isin(input_data['models'])] + predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list} + for model in target_df["model"].unique(): + model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) + yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + count=1 + for _, row in predictions_dict[model].iterrows(): + #for index, row in target_df.iterrows(): + percent_complete = round(count / len(predictions_dict[model]) * 100, 2) + count=count+1 + load_text = f"{generate_loading_text(percent_complete)}" + question = row['question'] + + # display_question = f""" + #
+                    # Natural Language: {row['question']}
+                    # """
+                    display_question = f"""
+                    Natural Language:
+                    {question}
+                    ➑️
+ """ + yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + #time.sleep(0.02) + prediction = row['predicted_sql'] + + # display_prediction = f""" + #
+                    # Generated SQL: {prediction}
+                    # """
+                    display_prediction = f"""
+                    Generated SQL:
+                    ➑️
+                    {prediction}
+ """ + yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] + metrics_conc = target_df + if 'valid_efficiency_score' not in metrics_conc.columns: + metrics_conc['valid_efficiency_score'] = metrics_conc['VES'] + yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + else: - yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] - # END - evaluator = OrchestratorEvaluator() - for model in input_data["models"]: - metrics_df_model = evaluator.evaluate_df( - df=predictions_dict[model], - target_col_name="query", - prediction_col_name="predicted_sql", - db_path_name="db_path" + orchestrator_generator = OrchestratorGenerator() + # TODO: add to target_df column target_df["columns_used"], tables selection + # print(input_data['data']['db']) + #print(input_data['data']['selected_tables']) + target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables']) + #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None) + + schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( + db_id = input_data["db_name"], + base_path = input_data["data_path"], + normalize=False, + sql=None ) - metrics_df_model['model'] = model - metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) - + + predictor = ModelPrediction() + + for model in input_data["models"]: + model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) + yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + count=0 + for index, row in target_df.iterrows(): + + percent_complete = round(((index+1) / len(target_df)) * 100, 2) + load_text = f"{generate_loading_text(percent_complete)}" + + question = row['question'] + #display_question = f"
Natural Language: {row['question']}"
+                display_question = f"""
+                Natural Language:
+                {question}
+                ➑️
+ """ + yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + start_time = time.time() + samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"]) + prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples) + #PREDICTION SQL + #response = predictor.make_prediction(question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}") + prediction = row["query"]#"SQL"#response[response_parsed] + price = 0.0#response[cost] + answer = "Answer"#response[response] + + end_time = time.time() + #display_prediction = f"
Generated SQL: {prediction}"
+                display_prediction = f"""
+                Generated SQL:
+                ➑️
+                {prediction}
+ """ + # Create a new row as dataframe + new_row = pd.DataFrame([{ + 'id': index, + 'question': question, + 'predicted_sql': prediction, + 'time': end_time - start_time, + 'query': row["query"], + 'db_path': input_data["data_path"], + 'price':price, + 'answer':answer, + 'number_question':count + }]).dropna(how="all") # Remove only completely empty rows + count=count+1 + # TODO: use a for loop + for col in target_df.columns: + if col not in new_row.columns: + new_row[col] = row[col] + + # Update model's prediction dataframe incrementally + if not new_row.empty: + predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) + + # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None + yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] + + yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] + # END + evaluator = OrchestratorEvaluator() + for model in input_data["models"]: + metrics_df_model = evaluator.evaluate_df( + df=predictions_dict[model], + target_col_name="query", + prediction_col_name="predicted_sql", + db_path_name="db_path" + ) + metrics_df_model['model'] = model + metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) + if 'valid_efficiency_score' not in metrics_conc.columns: metrics_conc['valid_efficiency_score'] = metrics_conc['VES'] - yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] + yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] # Loading Bar with gr.Row(): @@ -516,15 +681,11 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: with gr.Column(): with gr.Column(): question_display = gr.Markdown() - with gr.Column(): - gr.Markdown("
‴
") with gr.Column(): model_logo = gr.Image(visible=True, show_label=False) with gr.Column(): with gr.Column(): prediction_display = gr.Markdown() - with gr.Column(): - gr.Markdown("
‴
") dataframe_per_model = {} @@ -611,23 +772,35 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: open_model_selection ] ) - - + + + + ########################################## # METRICS VISUALIZATION SECTION # ########################################## with metrics_acc: - #confirmation_text = gr.Markdown("## Metrics successfully loaded") - - data_path = 'test_results.csv' + #data_path = 'test_results_metrics1.csv' + data_path = '/Users/francescogiannuzzo/Desktop/EURECOM/semester_project_gradio_git/Automatic-LLM-Benchmark-Analysis-for-Text2SQL-GRADIO/data/evaluation_p_metrics.csv' @gr.render(inputs=metrics_df_out) def function_metrics(metrics_df_out): + + #################################### + # UTILS FUNCTIONS SECTION # + #################################### + def load_data_csv_es(): - return pd.read_csv(data_path) - #return metrics_df_out + #return pd.read_csv(data_path) + #print("---------------->",metrics_df_out) + return metrics_df_out def calculate_average_metrics(df, selected_metrics): + # Exclude the 'tuple_order' column from the selected metrics + + #TODO tuple_order has NULL VALUE + selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order'] + #print(df[selected_metrics]) df['avg_metric'] = df[selected_metrics].mean(axis=1) return df @@ -646,198 +819,723 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: return colors MODEL_COLORS = generate_model_colors() + + def generate_db_category_colors(): + """Assigns 3 distinct colors to db_category groups.""" + return { + "Spider": "#1f77b4", # blu + "Beaver": "#ff7f0e", # arancione + "Economic": "#2ca02c", # tutti gli altri verdi + "Financial": "#2ca02c", + "Medical": "#2ca02c", + "Miscellaneous": "#2ca02c" + } + + DB_CATEGORY_COLORS = generate_db_category_colors() + + def normalize_valid_efficiency_score(df): + #TODO valid_efficiency_score + #print(df['valid_efficiency_score']) + df['valid_efficiency_score'] = df['valid_efficiency_score'].replace('', 0) + df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(int) + min_val = df['valid_efficiency_score'].min() + max_val = df['valid_efficiency_score'].max() + + if min_val == max_val: + # Tutti i valori sono uguali, assegna 1.0 a tutto per evitare divisione per zero + df['valid_efficiency_score'] = 1.0 + else: + df['valid_efficiency_score'] = ( + df['valid_efficiency_score'] - min_val + ) / (max_val - min_val) + + return df + + + + + #################################### + # GRAPH FUNCTIONS SECTION # + #################################### # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION - def plot_metric(df, selected_metrics, group_by, selected_models): + def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models): df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficiency_score(df) + + # Mappatura nomi leggibili -> tecnici + qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] + external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] + + selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal + df = calculate_average_metrics(df, selected_metrics) + + if group_by == ["model"]: + # Bar plot per "model" + avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index() + avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') + + fig = px.bar( + avg_metrics, + 
x="model", + y="avg_metric", + color="model", + color_discrete_map=MODEL_COLORS, + title='Average metric per Model 🧠', + labels={"model": "Model", "avg_metric": "Average Metric"}, + template='plotly_dark', + text='text_label' + ) + else: + if group_by != ["tbl_name", "model"]: + group_by = ["tbl_name", "model"] + + avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index() + avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') + + fig = px.bar( + avg_metrics, + x=group_by[0], + y='avg_metric', + color='model', + color_discrete_map=MODEL_COLORS, + barmode='group', + title=f'Average metric per {group_by[0]} πŸ“Š', + labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'}, + template='plotly_dark', + text='text_label' + ) + + fig.update_traces(textposition='outside', textfont_size=10) + + # font Playfair Display + fig.update_layout( + margin=dict(t=80), + title=dict( + font=dict( + family="Playfair Display, serif", + size=22, + color="white" + ), + x=0.5 + ), + xaxis=dict( + title=dict( + font=dict( + family="Playfair Display, serif", + size=16, + color="white" + ) + ), + tickfont=dict( + family="Playfair Display, serif", + color="white" + ) + ), + yaxis=dict( + title=dict( + font=dict( + family="Playfair Display, serif", + size=16, + color="white" + ) + ), + tickfont=dict( + family="Playfair Display, serif", + color="white" + ) + ), + legend=dict( + title=dict( + font=dict( + family="Playfair Display, serif", + size=14, + color="white" + ) + ), + font=dict( + family="Playfair Display, serif", + color="white" + ) + ) + ) + + return gr.Plot(fig, visible=True) + + def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models): + df = load_data_csv_es() + return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models) + + # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION + def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): + if selected_models == "All": + selected_models = models + else: + selected_models = [selected_models] - # Ensure the group_by value is always valid - if group_by not in [["tbl_name", "model"], ["model"]]: - group_by = ["tbl_name", "model"] # Default + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficiency_score(df) - avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index() + # Converti nomi leggibili -> tecnici + qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] + external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] + + selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal + + df = calculate_average_metrics(df, selected_metrics) + + avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() + avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') fig = px.bar( avg_metrics, - x=group_by[0], + x='db_category', y='avg_metric', color='model', color_discrete_map=MODEL_COLORS, barmode='group', - title=f'Average metric per {group_by[0]} πŸ“Š', - labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'}, - template='plotly_dark' + title='Average metric per db_category πŸ“Š', + labels={'db_path': 'DB Path', 'avg_metric': 'Average Metric'}, + template='simple_white', + text='text_label' ) - - return 
gr.Plot(fig, visible=True) - def update_plot(selected_metrics, group_by, selected_models): - df = load_data_csv_es() - return plot_metric(df, selected_metrics, group_by, selected_models) + fig.update_traces(textposition='outside', textfont_size=10) + #Playfair Display + fig.update_layout( + margin=dict(t=80), + title=dict( + font=dict( + family="Playfair Display, serif", + size=22, + color="black" + ), + x=0.5 + ), + xaxis=dict( + title=dict( + text='DB Category', + font=dict( + family='Playfair Display, serif', + size=16, + color='black' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='black' + ) + ), + yaxis=dict( + title=dict( + text='Average Metric', + font=dict( + family='Playfair Display, serif', + size=16, + color='black' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='black' + ) + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Playfair Display, serif', + size=14, + color='black' + ) + ), + font=dict( + family='Playfair Display, serif', + color='black' + ) + ) + ) - # RADAR CHART FOR AVERAGE METRICS PER MODEL WITH UPDATE FUNCTION - def plot_radar(df, selected_models): - # Filter only selected models - df = df[df['model'].isin(selected_models)] + return gr.Plot(fig, visible=True) + + """ + def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): + if selected_models == "All": + selected_models = models + else: + selected_models = [selected_models] - # Select relevant metrics - selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"] + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficiency_score(df) - # Compute average metrics per test_category and model + if radio_metric == "Qatch": + selected_metrics = qatch_selected_metrics + else: + selected_metrics = external_selected_metric + df = calculate_average_metrics(df, selected_metrics) - avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index() - # Check if data is available - if avg_metrics.empty: - print("Error: No data available to compute averages.") - return go.Figure() + # Raggruppamento per modello e categoria + avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index() + avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') - fig = go.Figure() - categories = avg_metrics['test_category'].unique() - - for model in selected_models: - model_data = avg_metrics[avg_metrics['model'] == model] - - # Build a list of values for each category (if a value is missing, set it to 0) - values = [ - model_data[model_data['test_category'] == cat]['avg_metric'].values[0] - if cat in model_data['test_category'].values else 0 - for cat in categories - ] - - fig.add_trace(go.Scatterpolar( - r=values, - theta=categories, - fill='toself', - name=model, - line=dict(color=MODEL_COLORS.get(model, "gray")) - )) + # Plot orizzontale con modello sull'asse Y + fig = px.bar( + avg_metrics, + x='avg_metric', + y='model', + color='db_category', # categoria come colore + text='text_label', + barmode='group', + orientation='h', + color_discrete_map=DB_CATEGORY_COLORS, # devi avere questo dict come MODEL_COLORS + title='Average metric per model and db_category πŸ“Š', + labels={'avg_metric': 'AVG Metric', 'model': 'Model'}, + template='plotly_dark' + ) + fig.update_traces(textposition='outside', textfont_size=10) fig.update_layout( - polar=dict(radialaxis=dict(visible=True, 
range=[0, max(avg_metrics['avg_metric'].max(), 0.5)])), # Set the radar range - title='❇️ Radar Plot of Metrics per Model (Average per Category) ❇️ ', - template='plotly_dark', - width=700, height=700 + margin=dict(t=80), + yaxis=dict(title=''), + xaxis=dict(title='AVG Metrics'), + legend_title='DB Name', + height=600 # puoi aumentare se ci sono tanti modelli ) - return fig - - def update_radar(selected_models): + return gr.Plot(fig, visible=True) + """ + + def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): + df = load_data_csv_es() + return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models) + + + # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION + + def lollipop_propietary(): df = load_data_csv_es() - return plot_radar(df, selected_models) - - # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION - def plot_cumulative_flow(df, selected_models): - df = df[df['model'].isin(selected_models)] + # Filtra solo le categorie rilevanti + target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"] + df = df[df['db_category'].isin(target_cats)] + df = normalize_valid_efficiency_score(df) + df = calculate_average_metrics(df, qatch_metrics) + + # Calcola la media per db_category e modello + avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() + + # Separa Spider e le altre 4 categorie + spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"] + other_df = avg_metrics[avg_metrics["db_category"] != "Spider"] + + # Calcola media delle altre categorie per ciascun modello + other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index() + other_mean_df["db_category"] = "Others" + + # Rinominare per chiarezza e uniformitΓ  + spider_df = spider_df.rename(columns={"avg_metric": "Spider"}) + other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"}) + + # Unione dei due dataset + merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model") + + # Ordina per modello o per valore se vuoi + merged_df = merged_df.sort_values(by="model") + fig = go.Figure() - - for model in selected_models: - model_df = df[df['model'] == model].copy() - - # Calculate cumulative time - model_df['cumulative_time'] = model_df['time'].cumsum() - - # Calculate cumulative number of queries over time - model_df['cumulative_queries'] = range(1, len(model_df) + 1) - - # Select a color for the model - color = MODEL_COLORS.get(model, "gray") # Assigned model color - fillcolor = color.replace("rgb", "rgba").replace(")", ", 0.2)") # πŸ”Ή Makes the area semi-transparent - #color = f"rgba({hash(model) % 256}, {hash(model * 2) % 256}, {hash(model * 3) % 256}, 1)" - + # Aggiungi linee orizzontali tra Spider e Others + for _, row in merged_df.iterrows(): fig.add_trace(go.Scatter( - x=model_df['cumulative_time'], - y=model_df['cumulative_queries'], - mode='lines+markers', - name=model, - line=dict(width=2, color=color) + x=[row["Spider"], row["Others"]], + y=[row["model"]] * 2, + mode='lines', + line=dict(color='gray', width=2), + showlegend=False )) - - # Adds the underlying colored area (same color but transparent) - """ - fig.add_trace(go.Scatter( - x=model_df['cumulative_time'], - y=model_df['cumulative_queries'], - fill='tozeroy', - mode='none', - showlegend=False, # Hides the area in the legend - fillcolor=fillcolor - )) - """ + + # Punto per Spider + fig.add_trace(go.Scatter( + 
x=merged_df["Spider"], + y=merged_df["model"], + mode='markers', + name='Spider', + marker=dict(size=10, color='red') + )) + + # Punto per Others (media delle altre 4 categorie) + fig.add_trace(go.Scatter( + x=merged_df["Others"], + y=merged_df["model"], + mode='markers', + name='Others Avg', + marker=dict(size=10, color='blue') + )) fig.update_layout( - title="Cumulative Query Flow Chart πŸ“ˆ", - xaxis_title="Cumulative Time (s)", - yaxis_title="Number of Queries Completed", - template='plotly_dark', - legend_title="Models" + title='Dot-Range Plot: Spider vs Altri πŸ•·οΈπŸ“Š', + xaxis_title='Average Metric', + yaxis_title='Model', + template='simple_white', + #template='plotly_dark', + margin=dict(t=80), + legend_title='Categoria', + height=600 ) - + + return gr.Plot(fig, visible=True) + + + # RADAR OR BAR CHART BASED ON CATEGORY COUNT + def plot_radar(df, selected_models, selected_metrics, selected_categories): + if "external" in selected_metrics: + selected_metrics = ["execution_accuracy", "valid_efficiency_score"] + else: + selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] + + # Filtro modelli e normalizzazione + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficiency_score(df) + df = calculate_average_metrics(df, selected_metrics) + + avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index() + + if avg_metrics.empty: + print("Error: No data available to compute averages.") + return go.Figure() + + categories = selected_categories + + if len(categories) < 3: + # πŸ”„ BAR PLOT + fig = go.Figure() + for model in selected_models: + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['test_category'] == cat]['avg_metric'].values[0] + if cat in model_data['test_category'].values else 0 + for cat in categories + ] + fig.add_trace(go.Bar( + x=categories, + y=values, + name=model, + marker=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + barmode='group', + title=dict( + text='πŸ“Š Bar Plot of Metrics per Model (Few Categories)', + font=dict( + family='Playfair Display, serif', + size=22, + color='white' + ), + x=0.5 + ), + template='plotly_dark', + xaxis=dict( + title=dict( + text='Test Category', + font=dict( + family='Playfair Display, serif', + size=16, + color='white' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ), + yaxis=dict( + title=dict( + text='Average Metric', + font=dict( + family='Playfair Display, serif', + size=16, + color='white' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Playfair Display, serif', + size=14, + color='white' + ) + ), + font=dict( + family='Playfair Display, serif', + color='white' + ) + ) + ) + else: + # 🧭 RADAR PLOT + fig = go.Figure() + for model in selected_models: + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['test_category'] == cat]['avg_metric'].values[0] + if cat in model_data['test_category'].values else 0 + for cat in categories + ] + fig.add_trace(go.Scatterpolar( + r=values, + theta=categories, + fill='toself', + name=model, + line=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + polar=dict( + radialaxis=dict( + visible=True, + range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + 
), + angularaxis=dict( + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ) + ), + title=dict( + text='❇️ Radar Plot of Metrics per Model (Average per Category)', + font=dict( + family='Playfair Display, serif', + size=22, + color='white' + ), + x=0.5 + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Playfair Display, serif', + size=14, + color='white' + ) + ), + font=dict( + family='Playfair Display, serif', + color='white' + ) + ), + template='plotly_dark' + ) + return fig - def update_query_rate(selected_models): + def update_radar(selected_models, selected_metrics, selected_categories): df = load_data_csv_es() - return plot_cumulative_flow(df, selected_models) + return plot_radar(df, selected_models, selected_metrics, selected_categories) + # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT + def plot_radar_sub(df, selected_models, selected_metrics, selected_category): + if "external" in selected_metrics: + selected_metrics = ["execution_accuracy", "valid_efficiency_score"] + else: + selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] - # RANKING FOR THE TOP 3 MODELS WITH UPDATE FUNCTION - def ranking_text(df, selected_models, ranking_type): - #df = load_data_csv_es() df = df[df['model'].isin(selected_models)] - df['valid_efficiency_score'] = pd.to_numeric(df['valid_efficiency_score'], errors='coerce') - if ranking_type == "valid_efficiency_score": - rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index() - #rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index() - ascending_order = False # Higher is better - elif ranking_type == "time": - rank_df = df.groupby('model')['time'].sum().reset_index() - rank_df["Ranking Value"] = rank_df["time"].round(2).astype(str) + " s" # Adds "s" for seconds - ascending_order = True # For time, lower is better - elif ranking_type == "metrics": - selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"] - df = calculate_average_metrics(df, selected_metrics) - rank_df = df.groupby('model')['avg_metric'].mean().reset_index() - ascending_order = False # Higher is better - - if ranking_type != "time": - rank_df.rename(columns={rank_df.columns[1]: "Ranking Value"}, inplace=True) - rank_df["Ranking Value"] = rank_df["Ranking Value"].round(2) # Round values except for time - - # Sort based on the selected criterion - rank_df = rank_df.sort_values(by="Ranking Value", ascending=ascending_order).reset_index(drop=True) - - # Select only the top 3 models - rank_df = rank_df.head(3) - - # Add medal icons for the top 3 - medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"] - rank_df.insert(0, "Rank", medals[:len(rank_df)]) + df = normalize_valid_efficiency_score(df) + df = calculate_average_metrics(df, selected_metrics) - # Build the formatted ranking string - ranking_str = "## πŸ† Model Ranking\n" - for _, row in rank_df.iterrows(): - ranking_str += f"{row['Rank']} {row['model']} ({row['Ranking Value']})
\n" - - return ranking_str + if isinstance(selected_category, str): + selected_category = [selected_category] - def update_ranking_text(selected_models, ranking_type): - df = load_data_csv_es() - return ranking_text(df, selected_models, ranking_type) + df = df[df['test_category'].isin(selected_category)] + avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index() + + if avg_metrics.empty: + print("Error: No data available to compute averages.") + return go.Figure() + + categories = df['sql_tag'].unique().tolist() + + if len(categories) < 3: + # πŸ”„ BAR PLOT + fig = go.Figure() + for model in selected_models: + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] + if cat in model_data['sql_tag'].values else 0 + for cat in categories + ] + fig.add_trace(go.Bar( + x=categories, + y=values, + name=model, + marker=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + barmode='group', + title=dict( + text='πŸ“Š Bar Plot of Metrics per Model (Few Sub-Categories)', + font=dict( + family='Playfair Display, serif', + size=22, + color='white' + ), + x=0.5 + ), + template='plotly_dark', + xaxis=dict( + title=dict( + text='SQL Tag (Sub Category)', + font=dict( + family='Playfair Display, serif', + size=16, + color='white' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ), + yaxis=dict( + title=dict( + text='Average Metric', + font=dict( + family='Playfair Display, serif', + size=16, + color='white' + ) + ), + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Playfair Display, serif', + size=14, + color='white' + ) + ), + font=dict( + family='Playfair Display, serif', + color='white' + ) + ) + ) + else: + # 🧭 RADAR PLOT + fig = go.Figure() + for model in selected_models: + model_data = avg_metrics[avg_metrics['model'] == model] + values = [ + model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] + if cat in model_data['sql_tag'].values else 0 + for cat in categories + ] + + fig.add_trace(go.Scatterpolar( + r=values, + theta=categories, + fill='toself', + name=model, + line=dict(color=MODEL_COLORS.get(model, "gray")) + )) + + fig.update_layout( + polar=dict( + radialaxis=dict( + visible=True, + range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ), + angularaxis=dict( + tickfont=dict( + family='Playfair Display, serif', + color='white' + ) + ) + ), + title=dict( + text='❇️ Radar Plot of Metrics per Model (Average per Sub-Category)', + font=dict( + family='Playfair Display, serif', + size=22, + color='white' + ), + x=0.5 + ), + legend=dict( + title=dict( + text='Models', + font=dict( + family='Playfair Display, serif', + size=14, + color='white' + ) + ), + font=dict( + family='Playfair Display, serif', + color='white' + ) + ), + template='plotly_dark' + ) + + return fig + def update_radar_sub(selected_models, selected_metrics, selected_category): + df = load_data_csv_es() + return plot_radar_sub(df, selected_models, selected_metrics, selected_category) # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION - def worst_cases_text(df, selected_models): + def worst_cases_text(df, selected_models, selected_metrics, selected_categories): + if selected_models == "All": + selected_models = models + else: + selected_models = [selected_models] + + if selected_categories == 
"All": + selected_categories = principal_categories + else: + selected_categories = [selected_categories] + df = df[df['model'].isin(selected_models)] + df = df[df['test_category'].isin(selected_categories)] - selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"] + if "external" in selected_metrics: + selected_metrics = ["execution_accuracy", "valid_efficiency_score"] + else: + selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] + + df = normalize_valid_efficiency_score(df) df = calculate_average_metrics(df, selected_metrics) - worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql'])['avg_metric'].mean().reset_index() + worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True) @@ -845,110 +1543,495 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2) - worst_str = "## ❌ Top 3 Worst Cases\n" + worst_str = [] + answer_str = [] + medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"] for i, row in worst_cases_top_3.iterrows(): - worst_str += ( - f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} ({row['avg_metric']}) \n" + entry = ( + f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" f"- Question: {row['question']} \n" f"- Original Query: `{row['query']}` \n" f"- Predicted SQL: `{row['predicted_sql']}` \n\n" ) + + worst_str.append(entry) + + raw_answer = ( + f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" + f"- Raw Answer:
`{row['answer']}`
\n" + ) + + answer_str.append(raw_answer) - return worst_str + return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2] - def update_worst_cases_text(selected_models): + def update_worst_cases_text(selected_models, selected_metrics, selected_categories): df = load_data_csv_es() - return worst_cases_text(df, selected_models) + return worst_cases_text(df, selected_models, selected_metrics, selected_categories) + + # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION + def plot_cumulative_flow(df, selected_models, max_points): + df = df[df['model'].isin(selected_models)] + df = normalize_valid_efficiency_score(df) + + fig = go.Figure() + + for model in selected_models: + model_df = df[df['model'] == model].copy() + + # Limita il numero di punti se richiesto + if max_points is not None: + model_df = model_df.head(max_points + 1) + + # Tooltip personalizzato + model_df['hover_info'] = model_df.apply( + lambda row: + f"Id question: {row['number_question']}
" + f"Question: {row['question']}
" + f"Target: {row['query']}
" + f"Prediction: {row['predicted_sql']}
" + f"Category: {row['test_category']}", + axis=1 + ) + + # Calcoli cumulativi + model_df['cumulative_time'] = model_df['time'].cumsum() + model_df['cumulative_price'] = model_df['price'].cumsum() + + # Colore del modello + color = MODEL_COLORS.get(model, "gray") + + fig.add_trace(go.Scatter( + x=model_df['cumulative_time'], + y=model_df['cumulative_price'], + mode='lines+markers', + name=model, + line=dict(width=2, color=color), + customdata=model_df['hover_info'], + hovertemplate= + "Model: " + model + "
" + + "Cumulative Time: %{x}s
" + + "Cumulative Price: $%{y:.2f}
" + + "
Details:
%{customdata}" + )) + + # Layout con font elegante + fig.update_layout( + title=dict( + text="Cumulative Price Flow Chart πŸ’°", + font=dict( + family="Playfair Display, serif", + size=24, + color="white" + ), + x=0.5 + ), + xaxis=dict( + title=dict( + text="Cumulative Time (s)", + font=dict( + family="Playfair Display, serif", + size=16, + color="white" + ) + ), + tickfont=dict( + family="Playfair Display, serif", + color="white" + ) + ), + yaxis=dict( + title=dict( + text="Cumulative Price ($)", + font=dict( + family="Playfair Display, serif", + size=16, + color="white" + ) + ), + tickfont=dict( + family="Playfair Display, serif", + color="white" + ) + ), + legend=dict( + title=dict( + text="Models", + font=dict( + family="Playfair Display, serif", + size=14, + color="white" + ) + ), + font=dict( + family="Playfair Display, serif", + color="white" + ) + ), + template="plotly_dark" + ) + + return fig + + def update_query_rate(selected_models, max_points): + df = load_data_csv_es() + return plot_cumulative_flow(df, selected_models, max_points) + + - metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"] + ####################### + # PARAMETER SECTION # + ####################### + qatch_metrics_dict = { + "Cell Precision": "cell_precision", + "Cell Recall": "cell_recall", + "Tuple Order": "tuple_order", + "Tuple Cardinality": "tuple_cardinality", + "Tuple Constraint": "tuple_constraint" + } + + qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] + last_valid_qatch_metrics_selection = qatch_metrics.copy() # Per salvare l’ultima selezione valida + def enforce_qatch_metrics_selection(selected): + global last_valid_qatch_metrics_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_qatch_metrics_selection) + last_valid_qatch_metrics_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + external_metrics_dict = { + "Execution Accuracy": "execution_accuracy", + "Valid Efficiency Score": "valid_efficiency_score" + } + + external_metric = ["execution_accuracy", "valid_efficiency_score"] + last_valid_external_metric_selection = external_metric.copy() + def enforce_external_metric_selection(selected): + global last_valid_external_metric_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_external_metric_selection) + last_valid_external_metric_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + all_metrics = { + "Qatch": ["qatch"], + "External": ["external"] + } + group_options = { "Table": ["tbl_name", "model"], "Model": ["model"] } df_initial = load_data_csv_es() + models = df_initial['model'].unique().tolist() + last_valid_model_selection = models.copy() # Per salvare l’ultima selezione valida + def enforce_model_selection(selected): + global last_valid_model_selection + if not selected: # Se nessuna metrica Γ¨ selezionata + return gr.update(value=last_valid_model_selection) + last_valid_model_selection = selected # Altrimenti aggiorna la selezione valida + return gr.update(value=selected) + + all_categories = df_initial['sql_tag'].unique().tolist() + + principal_categories = df_initial['test_category'].unique().tolist() + last_valid_category_selection = principal_categories.copy() # Per salvare l’ultima selezione valida + def enforce_category_selection(selected): + global last_valid_category_selection + 
+    if not selected:  # If no category is selected
+        return gr.update(value=last_valid_category_selection)
+    last_valid_category_selection = selected  # Otherwise record the new valid selection
+    return gr.update(value=selected)
+
+all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories}
+
+all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories}
+all_categories_as_dic_ranking["All"] = principal_categories
+
+all_model_as_dic = {cat: [f"{cat}"] for cat in models}
+all_model_as_dic["All"] = models

 #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
-        gr.Markdown("""## 📊 Model Performance Analysis 📊
-        Select one or more metrics to calculate the average and visualize histograms and radar plots.
-        """)
-        # Options selection section
+
+
+        ###########################
+        #  VISUALIZATION SECTION  #
+        ###########################
+        gr.Markdown("""# Model Performance Analysis""")
+
+        #FOR BAR
+        gr.Markdown("""## Section 1: Model - Data""")
         with gr.Row():
+            choose_metrics_bar = gr.Radio(
+                choices=list(all_metrics.keys()),
+                label="Select the metrics group that you want to use:",
+                value="Qatch"
+            )
-            metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Select metrics", value=metrics)
-            model_multiselect = gr.CheckboxGroup(choices=models, label="Select models", value=models)
-            group_radio = gr.Radio(choices=list(group_options.keys()), label="Select grouping", value="Table")
+            qatch_metric_multiselect_bar = gr.CheckboxGroup(
+                choices=list(qatch_metrics_dict.keys()),
+                label="Select one or more Qatch metrics:",
+                value=list(qatch_metrics_dict.keys()),
+                visible=True
+            )
-        output_plot = gr.Plot(visible=False)
+            external_metric_select_bar = gr.CheckboxGroup(
+                choices=list(external_metrics_dict.keys()),
+                label="Select one or more External metrics:",
+                visible=False
+            )
+
+            if(input_data['input_method'] == 'default'):
+                model_radio_bar = gr.Radio(
+                    choices=list(all_model_as_dic.keys()),
+                    label="Select the model that you want to use:",
+                    value="All"
+                )
+            else:
+                model_multiselect_bar = gr.CheckboxGroup(
+                    choices=models,
+                    label="Select one or more models:",
+                    value=models
+                )
+
+            group_radio = gr.Radio(
+                choices=list(group_options.keys()),
+                label="Select the grouping view:",
+                value="Table"
+            )
-        query_rate_plot = gr.Plot(value=update_query_rate(models))
+        # Swap between the Qatch and External selectors, resetting the visible one to its full default selection
+        def toggle_metric_selector(selected_type):
+            if selected_type == "Qatch":
+                return gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
+            else:
+                return gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
+
+        output_plot = gr.Plot(visible=False)
+
+        if(input_data['input_method'] == 'default'):
+            with gr.Row():
+                lollipop_propietary()
+        #FOR RADAR
+        gr.Markdown("""## Section 2: Model - Category""")
+        with gr.Row():
+            all_metrics_radar = gr.Radio(
+                choices=list(all_metrics.keys()),
+                label="Select the metrics group that you want to use:",
+                value="Qatch"
+            )
+
+            model_multiselect_radar = gr.CheckboxGroup(
+                choices=models,
+                label="Select one or more models:",
+                value=models
+            )
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                category_multiselect_radar = gr.CheckboxGroup(
+                    choices=principal_categories,
+                    label="Select one or more categories:",
+                    value=principal_categories
+                )
+            with gr.Column(scale=1):
+                category_radio_radar = gr.Radio(
+                    choices=list(all_categories_as_dic.keys()),
+                    label="Select the category that you want to use:",
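+                    # Default to the first category so the radio always starts with a valid choice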
+                    value=list(all_categories_as_dic.keys())[0]
+                )
+
         with gr.Row():
             with gr.Column(scale=1):
-                radar_plot = gr.Plot(value=update_radar(models))
+                radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
             with gr.Column(scale=1):
-                ranking_type_radio = gr.Radio(
-                    ["valid_efficiency_score", "time", "metrics"],
-                    label="Choose ranking criteria",
-                    value="valid_efficiency_score"
-                )
-                ranking_text_display = gr.Markdown(value=update_ranking_text(models, "valid_efficiency_score"))
-                worst_cases_display = gr.Markdown(value=update_worst_cases_text(models))
-
-        # Callback functions for updating charts
-        def on_change(selected_metrics, selected_group, selected_models):
-            return update_plot(selected_metrics, group_options[selected_group], selected_models)
+                radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
-        def on_radar_change(selected_models):
-            return update_radar(selected_models)
+        #FOR RANKING
+        with gr.Row():
+            all_metrics_ranking = gr.Radio(
+                choices=list(all_metrics.keys()),
+                label="Select the metrics group that you want to use:",
+                value="Qatch"
+            )
+
+            model_radio_ranking = gr.Radio(
+                choices=list(all_model_as_dic.keys()),
+                label="Select the model that you want to use:",
+                value="All"
+            )
+
+            category_radio_ranking = gr.Radio(
+                choices=list(all_categories_as_dic_ranking.keys()),
+                label="Select the category that you want to use:",
+                value="All"
+            )
-        #metrics_df_out.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
-        proceed_to_metrics_button.click(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("## ❌ 3 Worst Cases\n")
+
+                worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
+
+                with gr.Row():
+                    first = gr.Markdown(worst_first)
+
+                with gr.Row():
+                    first_button = gr.Button("Show raw answer for 🥇")
+
+                with gr.Row():
+                    second = gr.Markdown(worst_second)
+
+                with gr.Row():
+                    second_button = gr.Button("Show raw answer for 🥈")
+
+                with gr.Row():
+                    third = gr.Markdown(worst_third)
+
+                with gr.Row():
+                    third_button = gr.Button("Show raw answer for 🥉")
+
+            with gr.Column(scale=1):
+                gr.Markdown("""## Raw Answer""")
+                row_answer_first = gr.Markdown(value=raw_first, visible=True)
+                row_answer_second = gr.Markdown(value=raw_second, visible=False)
+                row_answer_third = gr.Markdown(value=raw_third, visible=False)
+
+        #FOR RATE
+        gr.Markdown("""## Section 3: Time - Price""")
+        with gr.Row():
+            model_multiselect_rate = gr.CheckboxGroup(
+                choices=models,
+                label="Select one or more models:",
+                value=models
+            )
-        proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
+        with gr.Row():
+            slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances that you want to visualize")
+
+        query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
-        metric_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
-        group_radio.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
-        model_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
-        model_multiselect.change(update_radar, inputs=model_multiselect, outputs=radar_plot)
-        model_multiselect.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
-        ranking_type_radio.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
-        model_multiselect.change(update_worst_cases_text, inputs=model_multiselect, outputs=worst_cases_display)
-        model_multiselect.change(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
+        #FOR RESET
         reset_data = gr.Button("Back to upload data section")
-        reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+
+
+        ###############################
+        #  CALLBACK FUNCTION SECTION  #
+        ###############################
-        reset_data.click(
-            fn=lambda: gr.update(visible=False),
-            outputs=[download_metrics]
-        )
-        reset_data.click(
-            fn=lambda: gr.update(visible=False),
-            outputs=[download_metrics]
-        )
-        reset_data.click(
-            fn=enable_disable,
-            inputs=[gr.State(True)],
-            outputs=[
-                *model_checkboxes,
-                submit_models_button,
-                preview_output,
-                submit_button,
-                file_input,
-                default_checkbox,
-                table_selector,
-                *table_outputs,
-                open_model_selection
-            ]
-        )
+        #FOR BAR
+        def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
+            return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
+
+        # Variant for the preloaded proprietary dataset, which has no grouping selector
+        def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
+            return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
+
+        #FOR RADAR
+        def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
+            return update_radar(selected_models, selected_metrics, selected_categories)
+
+        def on_radar_radio_change(selected_models, selected_metrics, selected_category):
+            return update_radar_sub(selected_models, selected_metrics, selected_category)
+
+        #FOR RANKING
+        def on_ranking_change(selected_models, selected_metrics, selected_categories):
+            return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
+
+        # Show exactly one of the three raw-answer panels at a time
+        def show_first():
+            return (
+                gr.update(visible=True),
+                gr.update(visible=False),
+                gr.update(visible=False)
+            )
+
+        def show_second():
+            return (
+                gr.update(visible=False),
+                gr.update(visible=True),
+                gr.update(visible=False)
+            )
+
+        def show_third():
+            return (
+                gr.update(visible=False),
+                gr.update(visible=False),
+                gr.update(visible=True)
+            )
+
+
+        ######################
+        #  ON CLICK SECTION  #
+        ######################
+
+        #FOR BAR
+        if(input_data['input_method'] == 'default'):
+            proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+            qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+            external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+            model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
+            qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+            choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
+            external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+        else:
+            proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+            qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+            external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+            group_radio.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+            model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
+            qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
+            model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
+            choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
+            external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
+
+
+        #FOR RADAR MULTISELECT
+        model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+        all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+        category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
+        model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
+        category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
+
+        #FOR RADAR RADIO
+        model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+        all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+        category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
+
+        #FOR RANKING
+        model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
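+        # Whenever a ranking filter changes, jump the raw-answer panel back to the 🥇 entry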
+        model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+        all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+        all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+        category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
+        category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+        # The enforce_* guards are only wired to CheckboxGroups: a Radio always holds a value,
+        # and routing one through them would overwrite the list-typed "last valid selection"
+        # state shared with the checkbox groups.
+        first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
+        second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
+        third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
+
+        #FOR RATE
+        model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+        proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+        model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
+        slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
+
+        #FOR RESET
+        reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
+        reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
+        reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
+
 interface.launch()
\ No newline at end of file