simone-papicchio franceth commited on
Commit
aff05a7
·
verified ·
1 Parent(s): ac204fe

More stable version. Link all acc, but still miss prediction (#6)

Browse files

- More stable version. Link all acc, but still miss prediction (696b8fe88b9890ac834f2892408dc569ef849c2f)


Co-authored-by: Francesco Giannuzzo <[email protected]>

Files changed (5) hide show
  1. app.py +828 -559
  2. style.css +6 -1
  3. test_results.csv +101 -0
  4. utilities.py +8 -3
  5. utils_get_db_tables_info.py +94 -0
app.py CHANGED
@@ -1,560 +1,829 @@
1
-
2
- import os
3
- # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
4
- if os.environ.get("SPACES_ZERO_GPU") is not None:
5
- import spaces
6
- else:
7
- class spaces:
8
- @staticmethod
9
- def GPU(func):
10
- def wrapper(*args, **kwargs):
11
- return func(*args, **kwargs)
12
- return wrapper
13
-
14
- import gradio as gr
15
- import pandas as pd
16
- import os
17
- from qatch.connectors.sqlite_connector import SqliteConnector
18
- from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
19
- from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
20
- import utilities as us
21
- import plotly.express as px
22
- import plotly.graph_objects as go
23
-
24
- @spaces.GPU
25
- def model_prediction():
26
- pass
27
-
28
- with open('style.css', 'r') as file:
29
- css = file.read()
30
-
31
- # DataFrame di default
32
- df_default = pd.DataFrame({
33
- 'Name': ['Alice', 'Bob', 'Charlie'],
34
- 'Age': [25, 30, 35],
35
- 'City': ['New York', 'Los Angeles', 'Chicago']
36
- })
37
-
38
- models_path = "models.csv"
39
-
40
- # Variabile globale per tenere traccia dei dati correnti
41
- df_current = df_default.copy()
42
-
43
- input_data = {
44
- 'input_method': "",
45
- 'data_path': "",
46
- 'db_name': "",
47
- 'data': {
48
- 'data_frames': {}, # dictionary of dataframes
49
- 'db': None # SQLITE3 database object
50
- },
51
- 'models': []
52
- }
53
- def load_data(file, path, use_default):
54
- """Carica i dati da un file, un percorso o usa il DataFrame di default."""
55
- global df_current
56
- if use_default:
57
- input_data["input_method"] = 'default'
58
- input_data["data_path"] = os.path.join(".", "data", "datainterface", "mytable.sqlite")
59
- input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
60
- input_data["data"]['data_frames'] = {'MyTable': df_current}
61
-
62
- #TODO assegna il db a input_data["data"]['db']
63
-
64
- df_current = df_default.copy() # Ripristina i dati di default
65
- return input_data["data"]['data_frames']
66
-
67
- selected_inputs = sum([file is not None, bool(path), use_default])
68
- if selected_inputs > 1:
69
- return 'Errore: Selezionare solo un metodo di input alla volta.'
70
-
71
- if file is not None:
72
- try:
73
- input_data["input_method"] = 'uploaded_file'
74
- input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
75
- input_data["data_path"] = os.path.join(".", "data", f"data_interface{input_data['db_name']}.sqlite")
76
- input_data["data"] = us.load_data(input_data["data_path"], input_data["db_name"])
77
- df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
78
- print(df_current)
79
- print(input_data["data"])
80
- if( input_data["data"]['data_frames'] and not input_data["data"]['db']):
81
- table2primary_key = {}
82
- print("ok")
83
- for table_name, df in input_data["data"]['data_frames'].items():
84
- # Assign primary keys for each table
85
- table2primary_key[table_name] = 'id'
86
- print("ok2")
87
- input_data["data"]["db"] = SqliteConnector(
88
- relative_db_path=input_data["data_path"],
89
- db_name=input_data["db_name"],
90
- tables= input_data["data"]['data_frames'],
91
- table2primary_key=table2primary_key
92
- )
93
- print(input_data["data"]["db"])
94
- return input_data["data"]['data_frames']
95
- except Exception as e:
96
- return f'Errore nel caricamento del file: {e}'
97
-
98
- if path:
99
- if not os.path.exists(path):
100
- return 'Errore: Il percorso specificato non esiste.'
101
- try:
102
- input_data["input_method"] = 'uploaded_file'
103
- input_data["data_path"] = path
104
- input_data["db_name"] = os.path.splitext(os.path.basename(path))[0]
105
- input_data["data"] = us.load_data(input_data["data_path"], input_data["db_name"])
106
- df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
107
-
108
- return input_data["data"]['data_frames']
109
- except Exception as e:
110
- return f'Errore nel caricamento del file dal percorso: {e}'
111
-
112
- return input_data["data"]['data_frames']
113
-
114
- def preview_default(use_default):
115
- """Mostra il DataFrame di default se il checkbox è selezionato."""
116
- if use_default:
117
- return df_default # Mostra il DataFrame di default
118
- return df_current # Mostra il DataFrame corrente, che potrebbe essere stato modificato
119
-
120
- def update_df(new_df):
121
- """Aggiorna il DataFrame corrente."""
122
- global df_current # Usa la variabile globale per aggiornarla
123
- df_current = new_df
124
- return df_current
125
-
126
- def open_accordion(target):
127
- # Apre uno e chiude l'altro
128
- if target == "reset":
129
- return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False)
130
- elif target == "model_selection":
131
- return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
132
-
133
- # Interfaccia Gradio
134
- interface = gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css')
135
-
136
- with interface:
137
- gr.Markdown("# QATCH")
138
- data_state = gr.State(None) # Memorizza i dati caricati
139
- upload_acc = gr.Accordion("Upload your data section", open=True, visible=True)
140
- select_table_acc = gr.Accordion("Select tables", open=False, visible=False)
141
- select_model_acc = gr.Accordion("Select models", open=False, visible=False)
142
- qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
143
- metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
144
-
145
-
146
-
147
- #################################
148
- # PARTE DI INSERIMENTO DEL DB #
149
- #################################
150
- with upload_acc:
151
- gr.Markdown("## Caricamento dei Dati")
152
-
153
- file_input = gr.File(label="Trascina e rilascia un file", file_types=[".csv", ".xlsx", ".sqlite"])
154
- path_input = gr.Textbox(label="Oppure inserisci il percorso locale del file")
155
- with gr.Row():
156
- default_checkbox = gr.Checkbox(label="Usa DataFrame di default")
157
- preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
158
- submit_button = gr.Button("Carica Dati", interactive=False) # Disabilitato di default
159
- output = gr.JSON(visible=False) # Output dizionario
160
-
161
- # Funzione per abilitare il bottone se sono presenti dati da caricare
162
- def enable_submit(file, path, use_default):
163
- return gr.update(interactive=bool(file or path or use_default))
164
-
165
- # Abilita il bottone quando i campi di input sono valorizzati
166
- file_input.change(fn=enable_submit, inputs=[file_input, path_input, default_checkbox], outputs=[submit_button])
167
- path_input.change(fn=enable_submit, inputs=[file_input, path_input, default_checkbox], outputs=[submit_button])
168
- default_checkbox.change(fn=enable_submit, inputs=[file_input, path_input, default_checkbox], outputs=[submit_button])
169
-
170
- # Mostra l'anteprima del DataFrame di default quando il checkbox è selezionato
171
- default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output])
172
- preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
173
-
174
- def handle_output(file, path, use_default):
175
- """Gestisce l'output quando si preme il bottone 'Carica Dati'."""
176
- result = load_data(file, path, use_default)
177
-
178
- if isinstance(result, dict): # Se result è un dizionario di DataFrame
179
- if len(result) == 1: # Se c'è solo una tabella
180
- return (
181
- gr.update(visible=False), # Nasconde l'output JSON
182
- result, # Salva lo stato dei dati
183
- gr.update(visible=False), # Nasconde la selezione tabella
184
- result, # Mantiene lo stato dei dati
185
- gr.update(interactive=False), # Disabilita il pulsante di submit
186
- gr.update(visible=True, open=True), # Passa direttamente a select_model_acc
187
- gr.update(visible=True, open=False)
188
- )
189
- else:
190
- return (
191
- gr.update(visible=False),
192
- result,
193
- gr.update(open=True, visible=True),
194
- result,
195
- gr.update(interactive=False),
196
- gr.update(visible=False), # Mantiene il comportamento attuale
197
- gr.update(visible=True, open=True)
198
- )
199
- else:
200
- return (
201
- gr.update(visible=False),
202
- None,
203
- gr.update(open=False, visible=True),
204
- None,
205
- gr.update(interactive=True),
206
- gr.update(visible=False),
207
- gr.update(visible=True, open=True)
208
- )
209
-
210
- submit_button.click(
211
- fn=handle_output,
212
- inputs=[file_input, path_input, default_checkbox],
213
- outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
214
- )
215
-
216
-
217
-
218
- ######################################
219
- # PARTE DI SELEZIONE DELLE TABELLE #
220
- ######################################
221
- with select_table_acc:
222
- table_selector = gr.CheckboxGroup(choices=[], label="Seleziona le tabelle da visualizzare", value=[])
223
- table_outputs = [gr.DataFrame(label=f"Tabella {i+1}", interactive=True, visible=False) for i in range(5)]
224
- selected_table_names = gr.Textbox(label="Tabelle selezionate", visible=False, interactive=False)
225
-
226
- # Bottone di selezione modelli (inizialmente disabilitato)
227
- open_model_selection = gr.Button("Choose your models", interactive=False)
228
-
229
- def update_table_list(data):
230
- """Aggiorna dinamicamente la lista delle tabelle disponibili."""
231
- if isinstance(data, dict) and data:
232
- table_names = list(data.keys()) # Ritorna solo i nomi delle tabelle
233
- return gr.update(choices=table_names, value=[]) # Reset delle selezioni
234
- return gr.update(choices=[], value=[])
235
-
236
- def show_selected_tables(data, selected_tables):
237
- """Mostra solo le tabelle selezionate dall'utente e abilita il bottone."""
238
- updates = []
239
- if isinstance(data, dict) and data:
240
- available_tables = list(data.keys()) # Nomi effettivamente disponibili
241
- selected_tables = [t for t in selected_tables if t in available_tables] # Filtra selezioni valide
242
-
243
- tables = {name: data[name] for name in selected_tables} # Filtra i DataFrame
244
-
245
- for i, (name, df) in enumerate(tables.items()):
246
- updates.append(gr.update(value=df, label=f"Tabella: {name}", visible=True))
247
-
248
- # Se ci sono meno di 5 tabelle, nascondi gli altri DataFrame
249
- for _ in range(len(tables), 5):
250
- updates.append(gr.update(visible=False))
251
- else:
252
- updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(5)]
253
-
254
- # Abilitare/disabilitare il bottone in base alle selezioni
255
- button_state = bool(selected_tables) # True se almeno una tabella è selezionata, False altrimenti
256
- updates.append(gr.update(interactive=button_state)) # Aggiorna stato bottone
257
-
258
- return updates
259
-
260
- def show_selected_table_names(selected_tables):
261
- """Mostra i nomi delle tabelle selezionate quando si preme il bottone."""
262
- if selected_tables:
263
- return gr.update(value=", ".join(selected_tables), visible=False)
264
- return gr.update(value="", visible=False)
265
-
266
- # Aggiorna automaticamente la lista delle checkbox quando `data_state` cambia
267
- data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
268
-
269
- # Aggiorna le tabelle visibili e lo stato del bottone in base alle selezioni dell'utente
270
- table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
271
-
272
- # Mostra la lista delle tabelle selezionate quando si preme "Choose your models"
273
- open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
274
- open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
275
-
276
-
277
-
278
- ####################################
279
- # PARTE DI SELEZIONE DEL MODELLO #
280
- ####################################
281
- with select_model_acc:
282
- gr.Markdown("**Model Selection**")
283
-
284
- # Supponiamo che `us.read_models_csv` restituisca anche il percorso dell'immagine
285
- model_list_dict = us.read_models_csv(models_path)
286
- model_list = [model["name"] for model in model_list_dict]
287
- model_images = [model["image_path"] for model in model_list_dict]
288
-
289
- # Creazione dinamica di checkbox con immagini
290
- model_checkboxes = []
291
- for model, image_path in zip(model_list, model_images):
292
- with gr.Row():
293
- with gr.Column(scale=1):
294
-
295
- gr.Image(image_path, show_label=False)
296
- with gr.Column(scale=2):
297
- model_checkboxes.append(gr.Checkbox(label=model, value=False))
298
-
299
- selected_models_output = gr.JSON(visible = False)
300
-
301
- # Funzione per ottenere i modelli selezionati
302
- def get_selected_models(*model_selections):
303
- selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
304
- input_data['models'] = selected_models
305
- button_state = bool(selected_models) # True se almeno un modello è selezionato, False altrimenti
306
- return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
307
-
308
- # Bottone di submit (inizialmente disabilitato)
309
- submit_models_button = gr.Button("Submit Models", interactive=False)
310
-
311
- # Collegamento dei checkbox agli eventi di selezione
312
- for checkbox in model_checkboxes:
313
- checkbox.change(
314
- fn=get_selected_models,
315
- inputs=model_checkboxes,
316
- outputs=[selected_models_output, select_model_acc, submit_models_button]
317
- )
318
-
319
- submit_models_button.click(
320
- fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
321
- inputs=model_checkboxes,
322
- outputs=[selected_models_output, select_model_acc, qatch_acc]
323
- )
324
-
325
- reset_data = gr.Button("Open upload data section")
326
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
327
-
328
-
329
-
330
- ###############################
331
- # PARTE DI ESECUZIONE QATCH #
332
- ###############################
333
- with qatch_acc:
334
- selected_models_display = gr.JSON(label="Modelli selezionati")
335
- submit_models_button.click(
336
- fn=lambda: gr.update(value=input_data),
337
- outputs=[selected_models_display]
338
- )
339
-
340
- proceed_to_metrics_button = gr.Button("Proceed to Metrics")
341
- proceed_to_metrics_button.click(
342
- fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
343
- outputs=[qatch_acc, metrics_acc]
344
- )
345
-
346
- reset_data = gr.Button("Open upload data section")
347
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
348
-
349
-
350
- #######################################
351
- # PARTE DI VISUALIZZAZIONE METRICHE #
352
- #######################################
353
- with metrics_acc:
354
- confirmation_text = gr.Markdown("## Metrics successfully loaded")
355
-
356
- data_path = 'metrics_random2.csv'
357
-
358
- def load_data_csv_es():
359
- return pd.read_csv(data_path)
360
-
361
- def calculate_average_metrics(df, selected_metrics):
362
- df['avg_metric'] = df[selected_metrics].mean(axis=1)
363
- return df
364
-
365
- def plot_metric(df, selected_metrics, group_by, selected_models):
366
- df = df[df['model'].isin(selected_models)]
367
- df = calculate_average_metrics(df, selected_metrics)
368
- avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
369
- fig = px.bar(
370
- avg_metrics, x=group_by[0], y='avg_metric', color=group_by[-1], barmode='group',
371
- title=f'Media metrica per {group_by[0]}',
372
- labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Media Metrica'},
373
- template='plotly_dark'
374
- )
375
- return fig
376
-
377
- def plot_radar(df, selected_models):
378
- radar_data = []
379
- for model in selected_models:
380
- model_df = df[df['model'] == model]
381
- valid_efficiency = model_df['valid_efficiency_score'].mean()
382
- avg_time = model_df['time'].mean()
383
- avg_tuple_order = model_df['tuple_order'].dropna().mean()
384
-
385
- radar_data.append({
386
- 'model': model,
387
- 'valid_efficiency_score': valid_efficiency,
388
- 'time': avg_time,
389
- 'tuple_order': avg_tuple_order
390
- })
391
-
392
- radar_df = pd.DataFrame(radar_data)
393
- categories = ['valid_efficiency_score', 'time', 'tuple_order']
394
-
395
- # Calcola il range dinamico per il grafico
396
- min_val = radar_df[categories].min().min()
397
- max_val = radar_df[categories].max().max()
398
- radar_df[categories] = (radar_df[categories] - min_val) / (max_val - min_val)
399
-
400
- fig = go.Figure()
401
- for _, row in radar_df.iterrows():
402
- fig.add_trace(go.Scatterpolar(
403
- r=[row[cat] for cat in categories],
404
- theta=categories,
405
- fill='toself',
406
- name=row['model']
407
- ))
408
-
409
- fig.update_layout(
410
- polar=dict(radialaxis=dict(visible=True, range=[min_val, max_val])),
411
- title='Radar Plot delle Metriche per Modello',
412
- template='plotly_dark',
413
- width=700, height=700
414
- )
415
-
416
- return fig
417
-
418
- def plot_query_rate(df, selected_models, show_labels):
419
- df = df[df['model'].isin(selected_models)]
420
-
421
- fig = go.Figure()
422
-
423
- for model in selected_models:
424
- model_df = df[df['model'] == model].copy()
425
-
426
- model_df['cumulative_time'] = model_df['time'].cumsum()
427
- model_df['query_rate'] = 1 / model_df['time']
428
-
429
- fig.add_trace(go.Scatter(
430
- x=model_df['cumulative_time'],
431
- y=model_df['query_rate'],
432
- mode='lines+markers',
433
- name=model,
434
- line=dict(width=2)
435
- ))
436
-
437
- if show_labels:
438
- prev_category = None
439
- prev_time = -float('inf')
440
- y_positions = [1.1, 1.3]
441
- y_idx = 0
442
-
443
- for i, row in model_df.iterrows():
444
- current_category = row['test_category']
445
- if current_category != prev_category and row['cumulative_time'] - prev_time > 5:
446
- fig.add_vline(x=row['cumulative_time'], line_width=1, line_dash="dash", line_color="gray")
447
- fig.add_annotation(
448
- x=row['cumulative_time'],
449
- y=max(model_df['query_rate']) * y_positions[y_idx % 2],
450
- text=current_category,
451
- showarrow=False,
452
- font=dict(size=10, color="white"),
453
- textangle=45,
454
- yshift=10,
455
- bgcolor="rgba(0,0,0,0.6)"
456
- )
457
- prev_category = current_category
458
- prev_time = row['cumulative_time']
459
- y_idx += 1
460
-
461
- fig.update_layout(
462
- title="Rate di Generazione delle Query per Modello",
463
- xaxis_title="Tempo Cumulativo (s)",
464
- yaxis_title="Query al Secondo",
465
- template='plotly_dark',
466
- legend_title="Modelli"
467
- )
468
-
469
- return fig
470
-
471
- def update_plot(selected_metrics, group_by, selected_models):
472
- df = load_data_csv_es()
473
- return plot_metric(df, selected_metrics, group_by, selected_models)
474
-
475
- def update_radar(selected_models):
476
- df = load_data_csv_es()
477
- return plot_radar(df, selected_models)
478
-
479
- def update_query_rate(selected_models, show_labels):
480
- df = load_data_csv_es()
481
- return plot_query_rate(df, selected_models, show_labels)
482
-
483
- def plot_query_time_evolution(df, selected_models):
484
- # Filtriamo i dati per i modelli selezionati
485
- df = df[df['model'].isin(selected_models)]
486
-
487
- # Ordinare per modello e tempo per tracciare l'evoluzione
488
- df_sorted = df.sort_values(by=['model', 'time'])
489
-
490
- fig = go.Figure()
491
-
492
- # Aggiungiamo una traccia per ogni modello
493
- for model in selected_models:
494
- model_df = df_sorted[df_sorted['model'] == model]
495
- fig.add_trace(go.Scatter(
496
- x=model_df.index, y=model_df['time'], mode='lines+markers', name=model,
497
- line=dict(shape='linear'),
498
- text=model_df['model']
499
- ))
500
-
501
- fig.update_layout(
502
- title="Evoluzione del Tempo di Generazione per Modello",
503
- xaxis_title="Indice della Query",
504
- yaxis_title="Tempo (s)",
505
- template='plotly_dark'
506
- )
507
-
508
- return fig
509
-
510
-
511
- metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
512
- group_options = {
513
- "SQL Category": ["test_category", "model"],
514
- "Tabella": ["tbl_name", "model"],
515
- "Modello": ["model"]
516
- }
517
-
518
- df_initial = load_data_csv_es()
519
- models = df_initial['model'].unique().tolist()
520
-
521
- #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
522
- gr.Markdown("""## Analisi delle prestazioni dei modelli
523
- Seleziona una o più metriche per calcolare la media e visualizzare gli istogrammi e radar plots.
524
- """)
525
-
526
- # Sezione di selezione delle opzioni
527
- with gr.Row():
528
- metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Seleziona le metriche")
529
- model_multiselect = gr.CheckboxGroup(choices=models, label="Seleziona i modelli", value=models)
530
- group_radio = gr.Radio(choices=list(group_options.keys()), label="Seleziona il raggruppamento", value="SQL Category")
531
- #show_labels_checkbox = gr.Checkbox(label="Mostra etichette test category", value=True)
532
-
533
- with gr.Row():
534
- output_plot = gr.Plot()
535
- # Dividi la pagina in due colonne
536
- with gr.Row():
537
- with gr.Column(scale=1): # Imposta la colonna a occupare metà della larghezza
538
- radar_plot = gr.Plot(value=update_radar(models))
539
- with gr.Column(scale=2): # Imposta la seconda colonna a occupare l'altra metà
540
- show_labels_checkbox = gr.Checkbox(label="Mostra etichette test category", value=True)
541
- query_rate_plot = gr.Plot(value=update_query_rate(models, True))
542
-
543
- # Funzioni di callback per il cambiamento dei grafici
544
- def on_change(selected_metrics, selected_group, selected_models):
545
- return update_plot(selected_metrics, group_options[selected_group], selected_models)
546
-
547
- def on_radar_change(selected_models):
548
- return update_radar(selected_models)
549
-
550
- show_labels_checkbox.change(update_query_rate, inputs=[model_multiselect, show_labels_checkbox], outputs=query_rate_plot)
551
- metric_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
552
- group_radio.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
553
- model_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
554
- model_multiselect.change(on_radar_change, inputs=model_multiselect, outputs=radar_plot)
555
- model_multiselect.change(update_query_rate, inputs=[model_multiselect, show_labels_checkbox], outputs=query_rate_plot)
556
-
557
- reset_data = gr.Button("Open upload data section")
558
- reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
559
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  interface.launch()
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ import sys
5
+ from qatch.connectors.sqlite_connector import SqliteConnector
6
+ from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
7
+ from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
8
+ #from predictor.orchestrator_predictor import OrchestratorPredictor
9
+ import utils_get_db_tables_info
10
+ import utilities as us
11
+ import time
12
+ import plotly.express as px
13
+ import plotly.graph_objects as go
14
+ import plotly.colors as pc
15
+
16
+ with open('style.css', 'r') as file:
17
+ css = file.read()
18
+
19
+ # DataFrame di default
20
+ df_default = pd.DataFrame({
21
+ 'Name': ['Alice', 'Bob', 'Charlie'],
22
+ 'Age': [25, 30, 35],
23
+ 'City': ['New York', 'Los Angeles', 'Chicago']
24
+ })
25
+
26
+ models_path = "models.csv"
27
+
28
+ # Variabile globale per tenere traccia dei dati correnti
29
+ df_current = df_default.copy()
30
+
31
+ input_data = {
32
+ 'input_method': "",
33
+ 'data_path': "",
34
+ 'db_name': "",
35
+ 'data': {
36
+ 'data_frames': {}, # dictionary of dataframes
37
+ 'db': None # SQLITE3 database object
38
+ },
39
+ 'models': []
40
+ }
41
+
42
+ def load_data(file, path, use_default):
43
+ """Carica i dati da un file, un percorso o usa il DataFrame di default."""
44
+ global df_current
45
+ if use_default:
46
+ input_data["input_method"] = 'default'
47
+ input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable.sqlite")
48
+ input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
49
+ input_data["data"]['data_frames'] = {'MyTable': df_current}
50
+
51
+ if( input_data["data"]['data_frames']):
52
+ table2primary_key = {}
53
+ for table_name, df in input_data["data"]['data_frames'].items():
54
+ # Assign primary keys for each table
55
+ table2primary_key[table_name] = 'id'
56
+ input_data["data"]["db"] = SqliteConnector(
57
+ relative_db_path=input_data["data_path"],
58
+ db_name=input_data["db_name"],
59
+ tables= input_data["data"]['data_frames'],
60
+ table2primary_key=table2primary_key
61
+ )
62
+
63
+ df_current = df_default.copy() # Ripristina i dati di default
64
+ return input_data["data"]['data_frames']
65
+
66
+ selected_inputs = sum([file is not None, bool(path), use_default])
67
+ if selected_inputs > 1:
68
+ return 'Errore: Selezionare solo un metodo di input alla volta.'
69
+
70
+ if file is not None:
71
+ try:
72
+ input_data["input_method"] = 'uploaded_file'
73
+ input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
74
+ input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
75
+ input_data["data"] = us.load_data(file, input_data["db_name"])
76
+ df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
77
+ if( input_data["data"]['data_frames']):
78
+ table2primary_key = {}
79
+ for table_name, df in input_data["data"]['data_frames'].items():
80
+ # Assign primary keys for each table
81
+ table2primary_key[table_name] = 'id'
82
+ input_data["data"]["db"] = SqliteConnector(
83
+ relative_db_path=input_data["data_path"],
84
+ db_name=input_data["db_name"],
85
+ tables= input_data["data"]['data_frames'],
86
+ table2primary_key=table2primary_key
87
+ )
88
+ return input_data["data"]['data_frames']
89
+ except Exception as e:
90
+ return f'Errore nel caricamento del file: {e}'
91
+
92
+ """
93
+ if path:
94
+ if not os.path.exists(path):
95
+ return 'Errore: Il percorso specificato non esiste.'
96
+ try:
97
+ input_data["input_method"] = 'uploaded_file'
98
+ input_data["data_path"] = path
99
+ input_data["db_name"] = os.path.splitext(os.path.basename(path))[0]
100
+ input_data["data"] = us.load_data(input_data["data_path"], input_data["db_name"])
101
+ df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
102
+
103
+ return input_data["data"]['data_frames']
104
+ except Exception as e:
105
+ return f'Errore nel caricamento del file dal percorso: {e}'
106
+ """
107
+
108
+
109
+ return input_data["data"]['data_frames']
110
+
111
+ def preview_default(use_default):
112
+ """Mostra il DataFrame di default se il checkbox è selezionato."""
113
+ if use_default:
114
+ return df_default # Mostra il DataFrame di default
115
+ return df_current # Mostra il DataFrame corrente, che potrebbe essere stato modificato
116
+
117
+ def update_df(new_df):
118
+ """Aggiorna il DataFrame corrente."""
119
+ global df_current # Usa la variabile globale per aggiornarla
120
+ df_current = new_df
121
+ return df_current
122
+
123
+ def open_accordion(target):
124
+ # Apre uno e chiude l'altro
125
+ if target == "reset":
126
+ df_current = df_default.copy()
127
+ input_data['input_method'] = ""
128
+ input_data['data_path'] = ""
129
+ input_data['db_name'] = ""
130
+ input_data['data']['data_frames'] = {}
131
+ input_data['data']['db'] = None
132
+ input_data['models'] = []
133
+ return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value=False), gr.update(value=None)
134
+ elif target == "model_selection":
135
+ return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
136
+
137
+ # Interfaccia Gradio
138
+
139
+ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
140
+ gr.Markdown("# QATCH")
141
+ data_state = gr.State(None) # Memorizza i dati caricati
142
+ upload_acc = gr.Accordion("Upload your data section", open=True, visible=True)
143
+ select_table_acc = gr.Accordion("Select tables", open=False, visible=False)
144
+ select_model_acc = gr.Accordion("Select models", open=False, visible=False)
145
+ qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
146
+ metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
147
+ #metrics_acc = gr.Accordion("Metrics", open=False, visible=False, render=False)
148
+
149
+
150
+
151
+ #################################
152
+ # PARTE DI INSERIMENTO DEL DB #
153
+ #################################
154
+ with upload_acc:
155
+ gr.Markdown("## Caricamento dei Dati")
156
+
157
+ file_input = gr.File(label="Trascina e rilascia un file", file_types=[".csv", ".xlsx", ".sqlite"])
158
+ with gr.Row():
159
+ default_checkbox = gr.Checkbox(label="Usa DataFrame di default")
160
+ preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
161
+ submit_button = gr.Button("Carica Dati", interactive=False) # Disabilitato di default
162
+ output = gr.JSON(visible=False) # Output dizionario
163
+
164
+ # Funzione per abilitare il bottone se sono presenti dati da caricare
165
+ def enable_submit(file, use_default):
166
+ return gr.update(interactive=bool(file or use_default))
167
+
168
+ # Funzione per deselezionare il checkbox se viene caricato un file
169
+ def deselect_default(file):
170
+ if file:
171
+ return gr.update(value=False)
172
+ return gr.update()
173
+
174
+ # Abilita il bottone quando i campi di input sono valorizzati
175
+ file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
176
+ default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
177
+
178
+ # Mostra l'anteprima del DataFrame di default quando il checkbox è selezionato
179
+ default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output])
180
+ preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
181
+
182
+ # Deseleziona il checkbox quando viene caricato un file
183
+ file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
184
+
185
+ def handle_output(file, use_default):
186
+ """Gestisce l'output quando si preme il bottone 'Carica Dati'."""
187
+ result = load_data(file, None, use_default)
188
+
189
+ if isinstance(result, dict): # Se result è un dizionario di DataFrame
190
+ if len(result) == 1: # Se c'è solo una tabella
191
+ return (
192
+ gr.update(visible=False), # Nasconde l'output JSON
193
+ result, # Salva lo stato dei dati
194
+ gr.update(visible=False), # Nasconde la selezione tabella
195
+ result, # Mantiene lo stato dei dati
196
+ gr.update(interactive=False), # Disabilita il pulsante di submit
197
+ gr.update(visible=True, open=True), # Passa direttamente a select_model_acc
198
+ gr.update(visible=True, open=False)
199
+ )
200
+ else:
201
+ return (
202
+ gr.update(visible=False),
203
+ result,
204
+ gr.update(open=True, visible=True),
205
+ result,
206
+ gr.update(interactive=False),
207
+ gr.update(visible=False), # Mantiene il comportamento attuale
208
+ gr.update(visible=True, open=True)
209
+ )
210
+ else:
211
+ return (
212
+ gr.update(visible=False),
213
+ None,
214
+ gr.update(open=False, visible=True),
215
+ None,
216
+ gr.update(interactive=True),
217
+ gr.update(visible=False),
218
+ gr.update(visible=True, open=True)
219
+ )
220
+
221
+ submit_button.click(
222
+ fn=handle_output,
223
+ inputs=[file_input, default_checkbox],
224
+ outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
225
+ )
226
+
227
+
228
+
229
+ ######################################
230
+ # PARTE DI SELEZIONE DELLE TABELLE #
231
+ ######################################
232
+ with select_table_acc:
233
+
234
+ table_selector = gr.CheckboxGroup(choices=[], label="Seleziona le tabelle da visualizzare", value=[])
235
+ table_outputs = [gr.DataFrame(label=f"Tabella {i+1}", interactive=True, visible=False) for i in range(5)]
236
+ selected_table_names = gr.Textbox(label="Tabelle selezionate", visible=False, interactive=False)
237
+
238
+ # Bottone di selezione modelli (inizialmente disabilitato)
239
+ open_model_selection = gr.Button("Choose your models", interactive=False)
240
+
241
+ def update_table_list(data):
242
+ """Aggiorna dinamicamente la lista delle tabelle disponibili."""
243
+ if isinstance(data, dict) and data:
244
+ table_names = list(data.keys()) # Ritorna solo i nomi delle tabelle
245
+ return gr.update(choices=table_names, value=[]) # Reset delle selezioni
246
+ return gr.update(choices=[], value=[])
247
+
248
+ def show_selected_tables(data, selected_tables):
249
+ """Mostra solo le tabelle selezionate dall'utente e abilita il bottone."""
250
+ updates = []
251
+ if isinstance(data, dict) and data:
252
+ available_tables = list(data.keys()) # Nomi effettivamente disponibili
253
+ selected_tables = [t for t in selected_tables if t in available_tables] # Filtra selezioni valide
254
+
255
+ tables = {name: data[name] for name in selected_tables} # Filtra i DataFrame
256
+
257
+ for i, (name, df) in enumerate(tables.items()):
258
+ updates.append(gr.update(value=df, label=f"Tabella: {name}", visible=True))
259
+
260
+ # Se ci sono meno di 5 tabelle, nascondi gli altri DataFrame
261
+ for _ in range(len(tables), 5):
262
+ updates.append(gr.update(visible=False))
263
+ else:
264
+ updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(5)]
265
+
266
+ # Abilitare/disabilitare il bottone in base alle selezioni
267
+ button_state = bool(selected_tables) # True se almeno una tabella è selezionata, False altrimenti
268
+ updates.append(gr.update(interactive=button_state)) # Aggiorna stato bottone
269
+
270
+ return updates
271
+
272
+ def show_selected_table_names(selected_tables):
273
+ """Mostra i nomi delle tabelle selezionate quando si preme il bottone."""
274
+ if selected_tables:
275
+ return gr.update(value=", ".join(selected_tables), visible=False)
276
+ return gr.update(value="", visible=False)
277
+
278
+ # Aggiorna automaticamente la lista delle checkbox quando `data_state` cambia
279
+ data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
280
+
281
+ # Aggiorna le tabelle visibili e lo stato del bottone in base alle selezioni dell'utente
282
+ table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
283
+
284
+ # Mostra la lista delle tabelle selezionate quando si preme "Choose your models"
285
+ open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
286
+ open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
287
+
288
+
289
+
290
+ ####################################
291
+ # PARTE DI SELEZIONE DEL MODELLO #
292
+ ####################################
293
+ with select_model_acc:
294
+ gr.Markdown("**Model Selection**")
295
+
296
+ # Supponiamo che `us.read_models_csv` restituisca anche il percorso dell'immagine
297
+ model_list_dict = us.read_models_csv(models_path)
298
+ model_list = [model["code"] for model in model_list_dict]
299
+ model_images = [model["image_path"] for model in model_list_dict]
300
+
301
+ model_checkboxes = []
302
+ rows = []
303
+
304
+ # Creazione dinamica di checkbox con immagini (3 per riga)
305
+ for i in range(0, len(model_list), 3):
306
+ with gr.Row():
307
+ cols = []
308
+ for j in range(3):
309
+ if i + j < len(model_list):
310
+ model = model_list[i + j]
311
+ image_path = model_images[i + j]
312
+ with gr.Column():
313
+ gr.Image(image_path, show_label=False)
314
+ checkbox = gr.Checkbox(label=model, value=False)
315
+ model_checkboxes.append(checkbox)
316
+ cols.append(checkbox)
317
+ rows.append(cols)
318
+
319
+ selected_models_output = gr.JSON(visible=False)
320
+
321
+ # Funzione per ottenere i modelli selezionati
322
+ def get_selected_models(*model_selections):
323
+ selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
324
+ input_data['models'] = selected_models
325
+ button_state = bool(selected_models) # True se almeno un modello è selezionato, False altrimenti
326
+ return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
327
+
328
+ # Bottone di submit (inizialmente disabilitato)
329
+ submit_models_button = gr.Button("Submit Models", interactive=False)
330
+
331
+ # Collegamento dei checkbox agli eventi di selezione
332
+ for checkbox in model_checkboxes:
333
+ checkbox.change(
334
+ fn=get_selected_models,
335
+ inputs=model_checkboxes,
336
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
337
+ )
338
+
339
+ submit_models_button.click(
340
+ fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
341
+ inputs=model_checkboxes,
342
+ outputs=[selected_models_output, select_model_acc, qatch_acc]
343
+ )
344
+
345
+ reset_data = gr.Button("Back to upload data section")
346
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
347
+
348
+
349
+ ###############################
350
+ # PARTE DI ESECUZIONE QATCH #
351
+ ###############################
352
+ with qatch_acc:
353
+ def change_text(text):
354
+ return text
355
+ def qatch_flow():
356
+ orchestrator_generator = OrchestratorGenerator()
357
+ #TODO add to target_df column target_df["columns_used"], tables selection
358
+ #print(input_data['data']['db'])
359
+ target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'])
360
+
361
+ schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
362
+ db_id = input_data["db_name"],
363
+ base_path = input_data["data_path"],
364
+ normalize=False,
365
+ sql=None
366
+ )
367
+
368
+ # TODO QUERY PREDICTION
369
+ predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
370
+ metrics_conc = pd.DataFrame()
371
+ for model in input_data["models"]:
372
+ for index, row in target_df.iterrows():
373
+ if len(target_df) != 0: load_value = f"##Loading... {round((index + 1) / len(target_df) * 100, 2)}%"
374
+ else: load_value = "##Loading..."
375
+ question = row['query']
376
+ #yield gr.Textbox(question), gr.Textbox(), *[predictions_dict[model] for model in input_data["models"]], None
377
+ yield gr.Markdown(value=load_value), gr.Textbox(question), gr.Textbox(), metrics_conc, *[predictions_dict[model] for model in model_list]
378
+ start_time = time.time()
379
+
380
+ # Simulazione della predizione
381
+ time.sleep(0.03)
382
+ prediction = "Prediction_placeholder"
383
+
384
+ # Esegui la predizione reale qui
385
+ # prediction = predictor.run(model, schema_text, question)
386
+
387
+ end_time = time.time()
388
+ # Crea una nuova riga come dataframe
389
+ new_row = pd.DataFrame([{
390
+ 'id': index,
391
+ 'question': question,
392
+ 'predicted_sql': prediction,
393
+ 'time': end_time - start_time,
394
+ 'query': row["query"],
395
+ 'db_path': input_data["data_path"]
396
+ }]).dropna(how="all") # Rimuove solo righe completamente vuote
397
+ #TODO con un for
398
+ for col in target_df.columns:
399
+ if col not in new_row.columns:
400
+ new_row[col] = row[col]
401
+ # Aggiorna il dataframe corrispondente al modello man mano
402
+ if not new_row.empty:
403
+ predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
404
+ #yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
405
+ yield gr.Markdown(value=load_value), gr.Textbox(), gr.Textbox(prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
406
+
407
+ #END
408
+ evaluator = OrchestratorEvaluator()
409
+ for model in input_data["models"]:
410
+ metrics_df_model = evaluator.evaluate_df(
411
+ df=predictions_dict[model],
412
+ target_col_name="query", #'<target_column_name>',
413
+ prediction_col_name="predicted_sql", #'<prediction_column_name>',
414
+ db_path_name= "db_path", #'<db_path_column_name>'
415
+ )
416
+ metrics_df_model['model'] = model
417
+ metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
418
+
419
+ if 'valid_efficiency_score' not in metrics_conc.columns:
420
+ metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
421
+
422
+ yield gr.Markdown(), gr.Textbox(), gr.Textbox(), metrics_conc, *[predictions_dict[model] for model in model_list]
423
+
424
+ #Loading Bar
425
+ with gr.Row():
426
+ #progress = gr.Progress()
427
+ variable = gr.Markdown()
428
+
429
+ #NL -> MODEL -> Generated Quesy
430
+ with gr.Row():
431
+ with gr.Column():
432
+ question_display = gr.Textbox()
433
+ with gr.Column():
434
+ gr.Image()
435
+ with gr.Column():
436
+ prediction_display = gr.Textbox()
437
+
438
+ dataframe_per_model = {}
439
+
440
+ with gr.Tabs() as model_tabs:
441
+ #for model in input_data["models"]:
442
+ for model in model_list:
443
+ #TODO fix model tabs
444
+ with gr.TabItem(model):
445
+ gr.Markdown(f"**Results for {model}**")
446
+ dataframe_per_model[model] = gr.DataFrame()
447
+
448
+
449
+ #question_display.change(fn=change_text, inputs=[gr.State(question)], outputs=[question_display])
450
+ selected_models_display = gr.JSON(label="Modelli selezionati")
451
+ metrics_df = gr.DataFrame(visible=False)
452
+ metrics_df_out= gr.DataFrame(visible=False)
453
+
454
+ submit_models_button.click(
455
+ fn=qatch_flow,
456
+ inputs=[],
457
+ outputs=[variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
458
+ )
459
+
460
+ submit_models_button.click(
461
+ fn=lambda: gr.update(value=input_data),
462
+ outputs=[selected_models_display]
463
+ )
464
+ #Funziona per METRICS
465
+ metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
466
+
467
+ # def change_tab(selected_models_output, model_tabs):
468
+ # for model in model_list:
469
+ # if model in selected_models_output:
470
+ # pass#model_tabs[model].visible = True
471
+ # else:
472
+ # pass#model_tabs[model].visible = False
473
+ # return model_tabs
474
+
475
+ # selected_models_output.change(fn=change_tab, inputs=[selected_models_output, model_tabs], outputs=[])
476
+
477
+ proceed_to_metrics_button = gr.Button("Proceed to Metrics")
478
+ proceed_to_metrics_button.click(
479
+ fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
480
+ outputs=[qatch_acc, metrics_acc]
481
+ )
482
+
483
+ reset_data = gr.Button("Back to upload data section")
484
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
485
+
486
+
487
+
488
+ #######################################
489
+ # METRICS VISUALIZATION SECTION #
490
+ #######################################
491
+ with metrics_acc:
492
+ #confirmation_text = gr.Markdown("## Metrics successfully loaded")
493
+
494
+ data_path = 'test_results.csv'
495
+
496
+ @gr.render(inputs=metrics_df_out)
497
+ def function_metrics(metrics_df_out):
498
+ def load_data_csv_es():
499
+ return pd.read_csv(data_path)
500
+ #return metrics_df_out
501
+
502
+ def calculate_average_metrics(df, selected_metrics):
503
+ df['avg_metric'] = df[selected_metrics].mean(axis=1)
504
+ return df
505
+
506
+ def generate_model_colors():
507
+ """Generates a unique color map for models in the dataset."""
508
+ df = load_data_csv_es()
509
+ unique_models = df['model'].unique() # Extract unique models
510
+ num_models = len(unique_models)
511
+
512
+ # Use the Plotly color scale (you can change it if needed)
513
+ color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
514
+
515
+ # If there are more models than colors, cycle through them
516
+ colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
517
+
518
+ return colors
519
+
520
+ MODEL_COLORS = generate_model_colors()
521
+
522
+ # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
523
+ def plot_metric(df, selected_metrics, group_by, selected_models):
524
+ df = df[df['model'].isin(selected_models)]
525
+ df = calculate_average_metrics(df, selected_metrics)
526
+
527
+ # Ensure the group_by value is always valid
528
+ if group_by not in [["tbl_name", "model"], ["model"]]:
529
+ group_by = ["tbl_name", "model"] # Default
530
+
531
+ avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
532
+
533
+ fig = px.bar(
534
+ avg_metrics,
535
+ x=group_by[0],
536
+ y='avg_metric',
537
+ color='model',
538
+ color_discrete_map=MODEL_COLORS,
539
+ barmode='group',
540
+ title=f'Average metric per {group_by[0]} 📊',
541
+ labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
542
+ template='plotly_dark'
543
+ )
544
+
545
+ return fig
546
+
547
+ def update_plot(selected_metrics, group_by, selected_models):
548
+ df = load_data_csv_es()
549
+ return plot_metric(df, selected_metrics, group_by, selected_models)
550
+
551
+
552
+ # RADAR CHART FOR AVERAGE METRICS PER MODEL WITH UPDATE FUNCTION
553
+ def plot_radar(df, selected_models):
554
+ # Filter only selected models
555
+ df = df[df['model'].isin(selected_models)]
556
+
557
+ # Select relevant metrics
558
+ selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
559
+
560
+ # Compute average metrics per test_category and model
561
+ df = calculate_average_metrics(df, selected_metrics)
562
+ avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
563
+
564
+ # Check if data is available
565
+ if avg_metrics.empty:
566
+ print("Error: No data available to compute averages.")
567
+ return go.Figure()
568
+
569
+ fig = go.Figure()
570
+ categories = avg_metrics['test_category'].unique()
571
+
572
+ for model in selected_models:
573
+ model_data = avg_metrics[avg_metrics['model'] == model]
574
+
575
+ # Build a list of values for each category (if a value is missing, set it to 0)
576
+ values = [
577
+ model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
578
+ if cat in model_data['test_category'].values else 0
579
+ for cat in categories
580
+ ]
581
+
582
+ fig.add_trace(go.Scatterpolar(
583
+ r=values,
584
+ theta=categories,
585
+ fill='toself',
586
+ name=model,
587
+ line=dict(color=MODEL_COLORS.get(model, "gray"))
588
+ ))
589
+
590
+ fig.update_layout(
591
+ polar=dict(radialaxis=dict(visible=True, range=[0, max(avg_metrics['avg_metric'].max(), 0.5)])), # Set the radar range
592
+ title='❇️ Radar Plot of Metrics per Model (Average per Category) ❇️ ',
593
+ template='plotly_dark',
594
+ width=700, height=700
595
+ )
596
+
597
+ return fig
598
+
599
+ def update_radar(selected_models):
600
+ df = load_data_csv_es()
601
+ return plot_radar(df, selected_models)
602
+
603
+
604
+ # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
605
+ def plot_cumulative_flow(df, selected_models):
606
+ df = df[df['model'].isin(selected_models)]
607
+
608
+ fig = go.Figure()
609
+
610
+ for model in selected_models:
611
+ model_df = df[df['model'] == model].copy()
612
+
613
+ # Calculate cumulative time
614
+ model_df['cumulative_time'] = model_df['time'].cumsum()
615
+
616
+ # Calculate cumulative number of queries over time
617
+ model_df['cumulative_queries'] = range(1, len(model_df) + 1)
618
+
619
+ # Select a color for the model
620
+ color = MODEL_COLORS.get(model, "gray") # Assigned model color
621
+ fillcolor = color.replace("rgb", "rgba").replace(")", ", 0.2)") # 🔹 Makes the area semi-transparent
622
+
623
+ #color = f"rgba({hash(model) % 256}, {hash(model * 2) % 256}, {hash(model * 3) % 256}, 1)"
624
+
625
+ fig.add_trace(go.Scatter(
626
+ x=model_df['cumulative_time'],
627
+ y=model_df['cumulative_queries'],
628
+ mode='lines+markers',
629
+ name=model,
630
+ line=dict(width=2, color=color)
631
+ ))
632
+
633
+ # Adds the underlying colored area (same color but transparent)
634
+ """
635
+ fig.add_trace(go.Scatter(
636
+ x=model_df['cumulative_time'],
637
+ y=model_df['cumulative_queries'],
638
+ fill='tozeroy',
639
+ mode='none',
640
+ showlegend=False, # Hides the area in the legend
641
+ fillcolor=fillcolor
642
+ ))
643
+ """
644
+
645
+ fig.update_layout(
646
+ title="Cumulative Query Flow Chart 📈",
647
+ xaxis_title="Cumulative Time (s)",
648
+ yaxis_title="Number of Queries Completed",
649
+ template='plotly_dark',
650
+ legend_title="Models"
651
+ )
652
+
653
+ return fig
654
+
655
+ def update_query_rate(selected_models):
656
+ df = load_data_csv_es()
657
+ return plot_cumulative_flow(df, selected_models)
658
+
659
+
660
+ # RANKING FOR THE TOP 3 MODELS WITH UPDATE FUNCTION
661
+ def ranking_text(df, selected_models, ranking_type):
662
+ #df = load_data_csv_es()
663
+ df = df[df['model'].isin(selected_models)]
664
+ df['valid_efficiency_score'] = pd.to_numeric(df['valid_efficiency_score'], errors='coerce')
665
+ if ranking_type == "valid_efficiency_score":
666
+ rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
667
+ #rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
668
+ ascending_order = False # Higher is better
669
+ elif ranking_type == "time":
670
+ rank_df = df.groupby('model')['time'].sum().reset_index()
671
+ rank_df["Ranking Value"] = rank_df["time"].round(2).astype(str) + " s" # Adds "s" for seconds
672
+ ascending_order = True # For time, lower is better
673
+ elif ranking_type == "metrics":
674
+ selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
675
+ df = calculate_average_metrics(df, selected_metrics)
676
+ rank_df = df.groupby('model')['avg_metric'].mean().reset_index()
677
+ ascending_order = False # Higher is better
678
+
679
+ if ranking_type != "time":
680
+ rank_df.rename(columns={rank_df.columns[1]: "Ranking Value"}, inplace=True)
681
+ rank_df["Ranking Value"] = rank_df["Ranking Value"].round(2) # Round values except for time
682
+
683
+ # Sort based on the selected criterion
684
+ rank_df = rank_df.sort_values(by="Ranking Value", ascending=ascending_order).reset_index(drop=True)
685
+
686
+ # Select only the top 3 models
687
+ rank_df = rank_df.head(3)
688
+
689
+ # Add medal icons for the top 3
690
+ medals = ["🥇", "🥈", "🥉"]
691
+ rank_df.insert(0, "Rank", medals[:len(rank_df)])
692
+
693
+ # Build the formatted ranking string
694
+ ranking_str = "## 🏆 Model Ranking\n"
695
+ for _, row in rank_df.iterrows():
696
+ ranking_str += f"<span style='font-size:18px;'>{row['Rank']} {row['model']} ({row['Ranking Value']})</span><br>\n"
697
+
698
+ return ranking_str
699
+
700
+ def update_ranking_text(selected_models, ranking_type):
701
+ df = load_data_csv_es()
702
+ return ranking_text(df, selected_models, ranking_type)
703
+
704
+
705
+ # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
706
+ def worst_cases_text(df, selected_models):
707
+ df = df[df['model'].isin(selected_models)]
708
+
709
+ selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
710
+ df = calculate_average_metrics(df, selected_metrics)
711
+
712
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql'])['avg_metric'].mean().reset_index()
713
+
714
+ worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
715
+
716
+ worst_cases_top_3 = worst_cases_df.head(3)
717
+
718
+ worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
719
+
720
+ worst_str = "## ❌ Top 3 Worst Cases\n"
721
+ medals = ["🥇", "🥈", "🥉"]
722
+
723
+ for i, row in worst_cases_top_3.iterrows():
724
+ worst_str += (
725
+ f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']}</b> ({row['avg_metric']})</span> \n"
726
+ f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
727
+ f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
728
+ f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
729
+ )
730
+
731
+ return worst_str
732
+
733
+ def update_worst_cases_text(selected_models):
734
+ df = load_data_csv_es()
735
+ return worst_cases_text(df, selected_models)
736
+
737
+
738
+ metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
739
+ group_options = {
740
+ "Table": ["tbl_name", "model"],
741
+ "Model": ["model"]
742
+ }
743
+
744
+ df_initial = load_data_csv_es()
745
+ models = df_initial['model'].unique().tolist()
746
+
747
+ #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
748
+ gr.Markdown("""## 📊 Model Performance Analysis 📊
749
+ Select one or more metrics to calculate the average and visualize histograms and radar plots.
750
+ """)
751
+
752
+ # Options selection section
753
+ with gr.Row():
754
+
755
+ metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Select metrics", value=metrics)
756
+ model_multiselect = gr.CheckboxGroup(choices=models, label="Select models", value=models)
757
+ group_radio = gr.Radio(choices=list(group_options.keys()), label="Select grouping", value="Model")
758
+
759
+ output_plot = gr.Plot()
760
+
761
+ query_rate_plot = gr.Plot(value=update_query_rate(models))
762
+
763
+ with gr.Row():
764
+ with gr.Column(scale=1):
765
+ radar_plot = gr.Plot(value=update_radar(models))
766
+
767
+ with gr.Column(scale=1):
768
+ ranking_type_radio = gr.Radio(
769
+ ["valid_efficiency_score", "time", "metrics"],
770
+ label="Choose ranking criteria",
771
+ value="valid_efficiency_score"
772
+ )
773
+ ranking_text_display = gr.Markdown(value=update_ranking_text(models, "valid_efficiency_score"))
774
+ worst_cases_display = gr.Markdown(value=update_worst_cases_text(models))
775
+
776
+ # Callback functions for updating charts
777
+ def on_change(selected_metrics, selected_group, selected_models):
778
+ return update_plot(selected_metrics, group_options[selected_group], selected_models)
779
+
780
+ def on_radar_change(selected_models):
781
+ return update_radar(selected_models)
782
+
783
+ #metrics_df_out.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
784
+ metric_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
785
+ group_radio.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
786
+ model_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
787
+ model_multiselect.change(update_radar, inputs=model_multiselect, outputs=radar_plot)
788
+ model_multiselect.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
789
+ ranking_type_radio.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
790
+ model_multiselect.change(update_worst_cases_text, inputs=model_multiselect, outputs=worst_cases_display)
791
+ model_multiselect.change(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
792
+
793
+ reset_data = gr.Button("Back to upload data section")
794
+ reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
795
+
796
+ # Hidden button to force UI refresh on load
797
+ force_update_button = gr.Button("", visible=False)
798
+
799
+ # State variable to track first load
800
+ load_trigger = gr.State(value=True)
801
+
802
+ # Function to force initial load
803
+ def force_update(is_first_load):
804
+ if is_first_load:
805
+ return (
806
+ update_plot(metrics, group_options["Model"], models),
807
+ update_query_rate(models),
808
+ update_radar(models),
809
+ update_ranking_text(models, "valid_efficiency_score"),
810
+ update_worst_cases_text(models),
811
+ False # Change state to prevent continuous reloads
812
+ )
813
+ return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), False
814
+
815
+ # The invisible button forces chart loading only the first time
816
+ force_update_button.click(
817
+ fn=force_update,
818
+ inputs=[load_trigger],
819
+ outputs=[output_plot, query_rate_plot, radar_plot, ranking_text_display, worst_cases_display, load_trigger]
820
+ )
821
+
822
+ # Simulate button click when UI loads
823
+ with gr.Blocks() as demo:
824
+ demo.load(
825
+ lambda: force_update(True),
826
+ outputs=[output_plot, query_rate_plot, radar_plot, ranking_text_display, worst_cases_display, load_trigger]
827
+ )
828
+
829
  interface.launch()
style.css CHANGED
@@ -22,4 +22,9 @@
22
  display: block;
23
  margin: 10px auto 0;
24
  border-radius: 2px;
25
- }
 
 
 
 
 
 
22
  display: block;
23
  margin: 10px auto 0;
24
  border-radius: 2px;
25
+ }
26
+
27
+ #bar_plot, #line_plot {
28
+ width: 100% !important;
29
+ max-width: none !important;
30
+ }
test_results.csv ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model,tbl_name,test_category,question,query,predicted_sql,cell_precision,cell_recall,execution_accuracy,tuple_cardinality,tuple_constraint,time,valid_efficiency_score,tuple_order
2
+ Model_B,Table_2,WHERE,Mostra il prodotto più venduto,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.15,0.89,0.84,0.92,0.97,13,10,
3
+ Model_C,Table_3,GROUPBY,Quali clienti hanno speso di più?,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.09,0.1,0.02,0.22,0.42,3,4,
4
+ Model_A,Table_3,ORDERBY,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.9,0.81,0.18,0.56,0.52,11,11,0.39
5
+ Model_C,Table_3,WHERE,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.87,0.26,0.83,0.37,0.83,2,1,
6
+ Model_B,Table_3,ORDERBY,Quali clienti hanno speso di più?,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.62,0.49,0.66,0.95,0.68,13,2,0.38
7
+ Model_C,Table_3,JOIN,Qual è la media dei prezzi?,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.54,0.2,0.83,0.43,0.64,11,14,
8
+ Model_B,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.35,0.12,0.51,0.78,0.94,9,12,0.06
9
+ Model_C,Table_2,SELECT,Qual è la media dei prezzi?,SELECT AVG(price) FROM products;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.96,0.23,0.38,0.26,0.13,11,8,
10
+ Model_C,Table_2,JOIN,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.5,0.07,0.02,0.82,0.75,2,5,
11
+ Model_A,Table_1,JOIN,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.03,0.81,0.15,0.64,0.98,12,8,
12
+ Model_B,Table_3,JOIN,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.95,0.63,0.94,0.31,0.94,12,15,
13
+ Model_A,Table_3,SELECT,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.26,0.81,0.45,0.7,0.77,9,0,
14
+ Model_C,Table_3,WHERE,Ordina i clienti per spesa,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.65,0.22,0.61,0.56,0.84,5,11,
15
+ Model_C,Table_1,WHERE,Trova il totale delle vendite,SELECT * FROM users;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.3,0.83,0.14,0.22,0.7,14,2,
16
+ Model_A,Table_3,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.66,0.57,0.07,0.93,0.16,15,11,
17
+ Model_C,Table_1,WHERE,Elenca tutti gli utenti,SELECT SUM(sales) FROM orders;,SELECT SUM(sales) FROM orders;,0.39,0.97,0.87,0.99,0.27,13,4,
18
+ Model_A,Table_2,SELECT,Elenca tutti gli utenti,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.64,0.44,0.03,0.81,0.43,14,6,
19
+ Model_A,Table_3,GROUPBY,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.55,0.32,0.28,0.12,0.79,7,3,
20
+ Model_B,Table_3,WHERE,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.56,0.27,0.34,0.59,0.59,10,15,
21
+ Model_A,Table_2,SELECT,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.6,0.83,0.41,0.28,0.02,10,8,
22
+ Model_C,Table_2,SELECT,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT AVG(price) FROM products;,0.77,0.5,0.13,0.74,0.36,2,4,
23
+ Model_C,Table_1,ORDERBY,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.5,0.33,0.91,0.05,0.71,3,11,0.83
24
+ Model_C,Table_2,WHERE,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.74,0.62,0.64,0.26,0.05,1,5,
25
+ Model_B,Table_3,JOIN,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.98,0.73,0.52,0.12,0.72,1,6,
26
+ Model_A,Table_2,WHERE,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.9,0.84,0.57,0.86,0.66,10,1,
27
+ Model_A,Table_1,WHERE,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.21,0.73,0.97,0.58,0.04,5,4,
28
+ Model_C,Table_3,GROUPBY,Qual è la media dei prezzi?,SELECT * FROM users;,SELECT * FROM users;,0.63,0.51,0.14,0.24,0.62,3,6,
29
+ Model_A,Table_1,ORDERBY,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.39,0.42,0.88,0.45,0.24,9,1,0.95
30
+ Model_B,Table_3,JOIN,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.86,0.59,0.53,0.91,0.9,4,8,
31
+ Model_C,Table_3,JOIN,Ordina i clienti per spesa,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.86,0.57,0.26,0.47,0.5,13,13,
32
+ Model_B,Table_2,GROUPBY,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.67,0.7,0.1,0.42,0.31,9,14,
33
+ Model_C,Table_3,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.74,0.26,0.64,0.33,0.09,9,7,
34
+ Model_C,Table_2,GROUPBY,Mostra il prodotto più venduto,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.31,0.75,0.73,0.3,0.25,8,10,
35
+ Model_B,Table_2,JOIN,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.83,0.63,0.43,0.0,0.1,13,9,
36
+ Model_C,Table_1,ORDERBY,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM customers ORDER BY total_spent DESC;,1.0,0.48,0.96,0.45,0.66,4,5,0.8
37
+ Model_A,Table_3,JOIN,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM customers ORDER BY total_spent DESC;,0.97,0.46,0.34,0.57,0.21,15,4,
38
+ Model_C,Table_3,SELECT,Quali clienti hanno speso di più?,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.72,0.02,0.64,0.62,0.83,11,7,
39
+ Model_A,Table_2,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.16,0.67,0.2,0.62,0.82,9,15,
40
+ Model_C,Table_2,ORDERBY,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.38,0.13,0.96,0.36,0.9,14,5,0.01
41
+ Model_A,Table_2,GROUPBY,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.73,0.77,0.24,0.35,0.77,1,2,
42
+ Model_C,Table_1,JOIN,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.46,0.51,0.79,0.1,0.87,7,5,
43
+ Model_C,Table_3,WHERE,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.03,0.53,0.5,0.69,0.45,7,3,
44
+ Model_C,Table_3,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT AVG(price) FROM products;,0.27,0.94,0.41,0.07,0.61,7,14,
45
+ Model_A,Table_2,WHERE,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT SUM(sales) FROM orders;,0.04,0.8,0.59,0.06,0.18,9,10,
46
+ Model_C,Table_3,WHERE,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.99,0.17,0.2,0.74,0.59,3,14,
47
+ Model_B,Table_1,WHERE,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.0,0.39,0.77,0.04,0.5,1,2,
48
+ Model_C,Table_3,SELECT,Ordina i clienti per spesa,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.69,0.77,0.07,0.97,0.21,7,4,
49
+ Model_C,Table_2,JOIN,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT SUM(sales) FROM orders;,0.05,0.6,0.47,0.08,0.83,10,11,
50
+ Model_B,Table_2,WHERE,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.18,0.8,0.01,0.26,0.79,4,9,
51
+ Model_A,Table_2,WHERE,Quali clienti hanno speso di più?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.83,0.69,0.25,0.27,0.73,6,15,
52
+ Model_C,Table_2,GROUPBY,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.8,0.0,0.2,0.11,0.09,6,10,
53
+ Model_A,Table_1,JOIN,Qual è la media dei prezzi?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.33,0.0,0.01,0.65,0.38,12,2,
54
+ Model_B,Table_2,GROUPBY,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.3,0.79,0.37,0.66,0.07,2,14,
55
+ Model_A,Table_1,GROUPBY,Mostra il prodotto più venduto,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.8,0.21,0.4,0.93,0.61,3,8,
56
+ Model_B,Table_2,GROUPBY,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.48,0.8,0.62,0.72,0.64,8,8,
57
+ Model_B,Table_3,ORDERBY,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.94,0.61,0.07,0.42,0.74,9,6,0.63
58
+ Model_B,Table_2,WHERE,Ordina i clienti per spesa,SELECT * FROM users;,SELECT AVG(price) FROM products;,0.02,0.63,0.97,0.62,0.34,2,6,
59
+ Model_C,Table_2,WHERE,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT SUM(sales) FROM orders;,0.61,0.87,0.56,0.55,0.11,6,4,
60
+ Model_B,Table_2,WHERE,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT AVG(price) FROM products;,0.4,0.44,0.72,0.01,0.78,2,1,
61
+ Model_A,Table_3,ORDERBY,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.89,0.48,0.37,0.8,0.57,2,11,0.83
62
+ Model_A,Table_2,GROUPBY,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.95,0.51,0.17,0.55,0.06,10,4,
63
+ Model_A,Table_3,JOIN,Qual è la media dei prezzi?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.47,0.56,0.89,0.86,0.06,1,2,
64
+ Model_A,Table_2,ORDERBY,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.3,0.5,0.69,0.51,0.07,13,1,0.59
65
+ Model_C,Table_2,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.81,0.37,0.51,0.6,0.45,3,9,
66
+ Model_B,Table_3,WHERE,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.62,0.65,0.26,0.52,0.05,4,3,
67
+ Model_A,Table_1,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.17,0.28,0.87,0.95,0.81,10,7,
68
+ Model_B,Table_1,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.2,0.16,0.28,0.95,0.64,9,2,
69
+ Model_C,Table_2,JOIN,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT SUM(sales) FROM orders;,0.89,0.78,0.56,0.84,0.13,3,11,
70
+ Model_B,Table_1,ORDERBY,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.91,0.35,0.97,0.99,0.97,12,5,0.9
71
+ Model_A,Table_1,WHERE,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.59,0.72,0.77,0.64,0.75,14,2,
72
+ Model_B,Table_2,JOIN,Ordina i clienti per spesa,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.05,0.3,0.37,0.22,0.31,6,6,
73
+ Model_C,Table_3,JOIN,Mostra il prodotto più venduto,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.69,0.34,0.94,0.22,0.94,11,7,
74
+ Model_B,Table_2,JOIN,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.4,0.79,0.72,0.82,0.98,6,5,
75
+ Model_C,Table_1,ORDERBY,Elenca tutti gli utenti,SELECT * FROM users;,SELECT SUM(sales) FROM orders;,0.57,0.71,0.04,0.32,0.55,12,0,0.43
76
+ Model_A,Table_1,SELECT,Quali clienti hanno speso di più?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.84,0.62,0.32,0.28,0.16,2,5,
77
+ Model_A,Table_3,ORDERBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.77,0.81,0.47,0.46,0.82,2,10,0.66
78
+ Model_C,Table_2,ORDERBY,Trova il totale delle vendite,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.86,0.28,0.14,0.75,0.37,13,14,0.25
79
+ Model_A,Table_1,ORDERBY,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.9,0.93,0.89,0.88,0.25,13,3,0.92
80
+ Model_C,Table_3,JOIN,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.59,0.9,0.12,0.58,0.69,14,2,
81
+ Model_C,Table_1,JOIN,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.72,0.2,0.67,0.36,0.42,5,14,
82
+ Model_A,Table_2,JOIN,Trova il totale delle vendite,SELECT * FROM users;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.17,0.74,0.12,0.4,0.27,11,15,
83
+ Model_A,Table_1,SELECT,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.34,0.59,0.8,0.15,0.58,11,3,
84
+ Model_A,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.4,0.64,0.18,0.22,0.03,15,8,0.33
85
+ Model_C,Table_2,ORDERBY,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM customers ORDER BY total_spent DESC;,0.32,0.31,0.95,0.97,0.21,7,6,0.02
86
+ Model_B,Table_2,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.37,0.41,0.36,0.3,0.6,11,10,
87
+ Model_A,Table_1,SELECT,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.5,0.88,0.64,0.54,0.63,8,11,
88
+ Model_C,Table_1,GROUPBY,Trova il totale delle vendite,SELECT * FROM users;,SELECT * FROM users;,0.19,0.6,0.4,0.49,0.21,13,3,
89
+ Model_C,Table_3,JOIN,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.4,0.92,0.33,0.19,0.67,13,3,
90
+ Model_B,Table_3,JOIN,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.76,0.67,0.51,1.0,0.22,10,14,
91
+ Model_C,Table_1,ORDERBY,Quali clienti hanno speso di più?,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.07,0.72,0.69,0.21,0.15,15,3,0.63
92
+ Model_C,Table_1,GROUPBY,Ordina i clienti per spesa,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.71,0.54,0.24,0.35,0.23,7,2,
93
+ Model_C,Table_1,WHERE,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.83,0.36,0.52,0.18,0.06,10,7,
94
+ Model_C,Table_2,JOIN,Elenca tutti gli utenti,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.19,0.01,0.34,0.79,0.29,14,7,
95
+ Model_C,Table_1,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.14,0.74,0.76,0.88,0.01,5,9,
96
+ Model_A,Table_2,JOIN,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.66,0.36,0.97,0.46,0.85,5,1,
97
+ Model_A,Table_2,ORDERBY,Mostra il prodotto più venduto,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.39,0.8,0.88,0.06,0.84,4,14,0.96
98
+ Model_C,Table_3,GROUPBY,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.17,0.93,0.64,0.84,0.72,6,8,
99
+ Model_A,Table_3,ORDERBY,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.55,0.16,0.54,0.87,0.55,8,1,0.94
100
+ Model_B,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT * FROM customers ORDER BY total_spent DESC;,0.92,0.59,0.25,0.57,0.08,2,5,0.47
101
+ Model_C,Table_1,JOIN,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.7,0.88,0.65,0.64,0.92,13,13,
utilities.py CHANGED
@@ -3,8 +3,9 @@ import pandas as pd
3
  import sqlite3
4
  import gradio as gr
5
  import os
 
6
 
7
- def carica_sqlite(file_path):
8
  conn = sqlite3.connect(file_path)
9
  cursor = conn.cursor()
10
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
@@ -17,7 +18,11 @@ def carica_sqlite(file_path):
17
  df = pd.read_sql_query(f"SELECT * FROM {nome_tabella}", conn)
18
  dfs[nome_tabella] = df
19
  conn.close()
20
- data_output = {'data_frames': dfs,'db': conn}
 
 
 
 
21
  return data_output
22
 
23
  # Funzione per leggere un file CSV
@@ -37,7 +42,7 @@ def load_data(data_path : str, db_name : str):
37
  data_output = {'data_frames': {} ,'db': None}
38
  table_name = os.path.splitext(os.path.basename(data_path))[0]
39
  if data_path.endswith(".sqlite") :
40
- data_output = carica_sqlite(data_path)
41
  elif data_path.endswith(".csv"):
42
  data_output['data_frames'] = {f"{table_name}_table" : carica_csv(data_path)}
43
  elif data_path.endswith(".xlsx"):
 
3
  import sqlite3
4
  import gradio as gr
5
  import os
6
+ from qatch.connectors.sqlite_connector import SqliteConnector
7
 
8
+ def carica_sqlite(file_path, db_id):
9
  conn = sqlite3.connect(file_path)
10
  cursor = conn.cursor()
11
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
 
18
  df = pd.read_sql_query(f"SELECT * FROM {nome_tabella}", conn)
19
  dfs[nome_tabella] = df
20
  conn.close()
21
+ data_output = {'data_frames': dfs,'db': None}
22
+ # data_output['db'] = SqliteConnector(
23
+ # relative_db_path=file_path,
24
+ # db_name=db_id,
25
+ # )
26
  return data_output
27
 
28
  # Funzione per leggere un file CSV
 
42
  data_output = {'data_frames': {} ,'db': None}
43
  table_name = os.path.splitext(os.path.basename(data_path))[0]
44
  if data_path.endswith(".sqlite") :
45
+ data_output = carica_sqlite(data_path, db_name)
46
  elif data_path.endswith(".csv"):
47
  data_output['data_frames'] = {f"{table_name}_table" : carica_csv(data_path)}
48
  elif data_path.endswith(".xlsx"):
utils_get_db_tables_info.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sqlite3
3
+ import re
4
+
5
+
6
+ def utils_extract_db_schema_as_string(
7
+ db_id, base_path, normalize=False, sql: str | None = None
8
+ ):
9
+ """
10
+ Extracts the full schema of an SQLite database into a single string.
11
+
12
+ :param base_path: Base path where the database is located.
13
+ :param db_id: Path to the SQLite database file.
14
+ :param normalize: Whether to normalize the schema string.
15
+ :param sql: Optional SQL query to filter specific tables.
16
+ :return: Schema of the database as a single string.
17
+ """
18
+ #db_path = os.path.join(base_path, db_id, f"{db_id}.sqlite")
19
+
20
+ # Connect to the SQLite database
21
+
22
+ #if not os.path.exists(db_path):
23
+ # raise FileNotFoundError(f"Database file not found at: {db_path}")
24
+
25
+ connection = sqlite3.connect(base_path)
26
+ cursor = connection.cursor()
27
+
28
+ # Get the schema entries based on the provided SQL query
29
+ schema_entries = _get_schema_entries(cursor, sql)
30
+
31
+ # Combine all schema definitions into a single string
32
+ schema_string = _combine_schema_entries(schema_entries, normalize)
33
+
34
+ return schema_string
35
+
36
+
37
+ def _get_schema_entries(cursor, sql):
38
+ """
39
+ Retrieves schema entries from the SQLite database.
40
+
41
+ :param cursor: SQLite cursor object.
42
+ :param sql: Optional SQL query to filter specific tables.
43
+ :return: List of schema entries.
44
+ """
45
+ if sql:
46
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
47
+ tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
48
+ if tables:
49
+ tbl_names = ", ".join(f"'{tbl}'" for tbl in tables)
50
+ query = f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({tbl_names}) AND sql IS NOT NULL;"
51
+ else:
52
+ query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"
53
+ else:
54
+ query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"
55
+
56
+ cursor.execute(query)
57
+ return cursor.fetchall()
58
+
59
+
60
+ def _combine_schema_entries(schema_entries, normalize):
61
+ """
62
+ Combines schema entries into a single string.
63
+
64
+ :param schema_entries: List of schema entries.
65
+ :param normalize: Whether to normalize the schema string.
66
+ :return: Combined schema string.
67
+ """
68
+ if not normalize:
69
+ return "\n".join(entry[0] for entry in schema_entries)
70
+
71
+ return "\n".join(
72
+ re.sub(
73
+ r"\s*\)",
74
+ ")",
75
+ re.sub(
76
+ r"\(\s*",
77
+ "(",
78
+ re.sub(
79
+ r"(`\w+`)\s+\(",
80
+ r"\1(",
81
+ re.sub(
82
+ r"^\s*([^\s(]+)",
83
+ r"`\1`",
84
+ re.sub(
85
+ r"\s+",
86
+ " ",
87
+ entry[0].replace("CREATE TABLE", "").replace("\t", " "),
88
+ ).strip(),
89
+ ),
90
+ ),
91
+ ),
92
+ )
93
+ for entry in schema_entries
94
+ )