Spaces:
Runtime error
Runtime error
More stable version. Link all acc, but still miss prediction (#6)
Browse files- More stable version. Link all acc, but still miss prediction (696b8fe88b9890ac834f2892408dc569ef849c2f)
Co-authored-by: Francesco Giannuzzo <[email protected]>
- app.py +828 -559
- style.css +6 -1
- test_results.csv +101 -0
- utilities.py +8 -3
- utils_get_db_tables_info.py +94 -0
app.py
CHANGED
@@ -1,560 +1,829 @@
|
|
1 |
-
|
2 |
-
import
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
import
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
'
|
34 |
-
'
|
35 |
-
'
|
36 |
-
}
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
'
|
50 |
-
|
51 |
-
|
52 |
-
}
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
selected_inputs
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
input_data["
|
74 |
-
input_data["
|
75 |
-
input_data["
|
76 |
-
input_data["data"]
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
input_data["
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
gr.update(
|
197 |
-
gr.update(visible=True, open=True)
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
#
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
reset_data =
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
)
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
)
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
560 |
interface.launch()
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from qatch.connectors.sqlite_connector import SqliteConnector
|
6 |
+
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
|
7 |
+
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
|
8 |
+
#from predictor.orchestrator_predictor import OrchestratorPredictor
|
9 |
+
import utils_get_db_tables_info
|
10 |
+
import utilities as us
|
11 |
+
import time
|
12 |
+
import plotly.express as px
|
13 |
+
import plotly.graph_objects as go
|
14 |
+
import plotly.colors as pc
|
15 |
+
|
16 |
+
# Read the custom stylesheet once at import time.
with open('style.css', 'r') as file:
    css = file.read()

# Default DataFrame offered to the user when no file is uploaded.
df_default = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [25, 30, 35],
    'City': ['New York', 'Los Angeles', 'Chicago']
})

# CSV file describing the selectable models.
models_path = "models.csv"

# Module-level variable tracking the DataFrame currently being edited.
df_current = df_default.copy()

# Shared state describing the loaded data and the selected models.
input_data = {
    'input_method': "",
    'data_path': "",
    'db_name': "",
    'data': {
        'data_frames': {},  # dictionary of dataframes
        'db': None  # SQLITE3 database object
    },
    'models': []
}
|
41 |
+
|
42 |
+
def _build_sqlite_connector():
    """Create a SqliteConnector for the tables currently stored in ``input_data``.

    Every table is registered with ``'id'`` as its primary key name.
    """
    tables = input_data["data"]['data_frames']
    table2primary_key = {table_name: 'id' for table_name in tables}
    return SqliteConnector(
        relative_db_path=input_data["data_path"],
        db_name=input_data["db_name"],
        tables=tables,
        table2primary_key=table2primary_key
    )


def load_data(file, path, use_default):
    """Load data from an uploaded file or from the default DataFrame.

    Updates the module-level ``input_data`` dictionary (and ``df_current``).

    Args:
        file: Path of the uploaded file, or ``None`` when nothing was uploaded.
        path: Alternative filesystem path. Loading from it is not implemented;
            it only takes part in the "one input method at a time" check.
        use_default: When truthy, the current DataFrame is used as table
            ``'MyTable'``; this takes precedence over ``file`` and ``path``.

    Returns:
        The dictionary of DataFrames stored in ``input_data`` on success, or
        an error message string when the inputs conflict or loading fails.
    """
    global df_current
    if use_default:
        input_data["input_method"] = 'default'
        input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable.sqlite")
        input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
        input_data["data"]['data_frames'] = {'MyTable': df_current}

        if input_data["data"]['data_frames']:
            input_data["data"]["db"] = _build_sqlite_connector()

        # Restore the default data for the next preview.
        df_current = df_default.copy()
        return input_data["data"]['data_frames']

    # use_default is falsy from here on, so only file and path can conflict.
    if file is not None and path:
        return 'Errore: Selezionare solo un metodo di input alla volta.'

    if file is not None:
        # Broad catch is deliberate: any loading failure is reported to the
        # caller as a message string instead of crashing the UI callback.
        try:
            input_data["input_method"] = 'uploaded_file'
            input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
            input_data["data_path"] = os.path.join(".", "data", "data_interface", f"{input_data['db_name']}.sqlite")
            input_data["data"] = us.load_data(file, input_data["db_name"])
            # Fall back to the default frame when the file has no 'MyTable'.
            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)
            if input_data["data"]['data_frames']:
                input_data["data"]["db"] = _build_sqlite_connector()
            return input_data["data"]['data_frames']
        except Exception as e:
            return f'Errore nel caricamento del file: {e}'

    # NOTE: a draft branch loading from `path` existed only as a disabled
    # string literal; until it is implemented, `path` alone loads nothing.
    return input_data["data"]['data_frames']
|
110 |
+
|
111 |
+
def preview_default(use_default):
    """Return the DataFrame to preview.

    Args:
        use_default: State of the "use default DataFrame" checkbox.

    Returns:
        ``df_default`` when the checkbox is selected, otherwise the
        module-level ``df_current`` (which may have been edited).
    """
    if use_default:
        return df_default
    return df_current
|
116 |
+
|
117 |
+
def update_df(new_df):
    """Replace the module-level ``df_current`` with ``new_df`` and return it."""
    global df_current  # rebinding the module-level name requires `global`
    df_current = new_df
    return df_current
|
122 |
+
|
123 |
+
def open_accordion(target):
    """Open one accordion section and close the others.

    Args:
        target: ``"reset"`` to clear the shared state and go back to the
            upload section, or ``"model_selection"`` to open the model
            selection section.

    Returns:
        For ``"reset"``: seven updates (five accordions, the default
        checkbox, the file input). For ``"model_selection"``: five accordion
        updates. Any other target returns ``None``.
    """
    # Without this declaration the assignment below would only create a
    # local variable and the module-level DataFrame would never be reset.
    global df_current
    if target == "reset":
        df_current = df_default.copy()
        input_data['input_method'] = ""
        input_data['data_path'] = ""
        input_data['db_name'] = ""
        input_data['data']['data_frames'] = {}
        input_data['data']['db'] = None
        input_data['models'] = []
        return (
            gr.update(open=True),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(value=False),
            gr.update(value=None),
        )
    if target == "model_selection":
        return (
            gr.update(open=False),
            gr.update(open=False),
            gr.update(open=True, visible=True),
            gr.update(open=False),
            gr.update(open=False),
        )
|
136 |
+
|
137 |
+
# Interfaccia Gradio
|
138 |
+
|
139 |
+
with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
|
140 |
+
gr.Markdown("# QATCH")
|
141 |
+
data_state = gr.State(None) # Memorizza i dati caricati
|
142 |
+
upload_acc = gr.Accordion("Upload your data section", open=True, visible=True)
|
143 |
+
select_table_acc = gr.Accordion("Select tables", open=False, visible=False)
|
144 |
+
select_model_acc = gr.Accordion("Select models", open=False, visible=False)
|
145 |
+
qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
|
146 |
+
metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
|
147 |
+
#metrics_acc = gr.Accordion("Metrics", open=False, visible=False, render=False)
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
#################################
|
152 |
+
# PARTE DI INSERIMENTO DEL DB #
|
153 |
+
#################################
|
154 |
+
with upload_acc:
|
155 |
+
gr.Markdown("## Caricamento dei Dati")
|
156 |
+
|
157 |
+
file_input = gr.File(label="Trascina e rilascia un file", file_types=[".csv", ".xlsx", ".sqlite"])
|
158 |
+
with gr.Row():
|
159 |
+
default_checkbox = gr.Checkbox(label="Usa DataFrame di default")
|
160 |
+
preview_output = gr.DataFrame(interactive=True, visible=True, value=df_default)
|
161 |
+
submit_button = gr.Button("Carica Dati", interactive=False) # Disabilitato di default
|
162 |
+
output = gr.JSON(visible=False) # Output dizionario
|
163 |
+
|
164 |
+
# Funzione per abilitare il bottone se sono presenti dati da caricare
|
165 |
+
def enable_submit(file, use_default):
    """Return an update enabling the submit button when there is data to load.

    The button is interactive when a file is present or the default
    checkbox is selected.
    """
    return gr.update(interactive=bool(file or use_default))
|
167 |
+
|
168 |
+
# Funzione per deselezionare il checkbox se viene caricato un file
|
169 |
+
def deselect_default(file):
    """Return an update that unchecks the default checkbox once a file is present.

    When no file is present an empty update is returned, leaving the
    checkbox as it is.
    """
    if file:
        return gr.update(value=False)
    return gr.update()
|
173 |
+
|
174 |
+
# Abilita il bottone quando i campi di input sono valorizzati
|
175 |
+
file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
|
176 |
+
default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
|
177 |
+
|
178 |
+
# Mostra l'anteprima del DataFrame di default quando il checkbox è selezionato
|
179 |
+
default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output])
|
180 |
+
preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
|
181 |
+
|
182 |
+
# Deseleziona il checkbox quando viene caricato un file
|
183 |
+
file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
|
184 |
+
|
185 |
+
def handle_output(file, use_default):
    """Handle a click on the load-data button.

    Calls ``load_data`` and returns seven values, in the order of the click
    handler's outputs: JSON output, JSON output, table-selection accordion,
    data state, submit button, model-selection accordion, upload accordion.
    """
    result = load_data(file, None, use_default)

    if not isinstance(result, dict):
        # load_data returned an error message: clear the state and let the
        # user submit again.
        return (
            gr.update(visible=False),
            None,
            gr.update(open=False, visible=True),
            None,
            gr.update(interactive=True),
            gr.update(visible=False),
            gr.update(visible=True, open=True)
        )

    if len(result) == 1:
        # A single table needs no selection step: go straight to the models.
        return (
            gr.update(visible=False),  # hide the JSON output
            result,  # store the loaded data
            gr.update(visible=False),  # hide the table selection
            result,  # keep the data state
            gr.update(interactive=False),  # disable the submit button
            gr.update(visible=True, open=True),  # open the model selection
            gr.update(visible=True, open=False)
        )

    # Otherwise let the user choose which tables to use.
    return (
        gr.update(visible=False),
        result,
        gr.update(open=True, visible=True),
        result,
        gr.update(interactive=False),
        gr.update(visible=False),
        gr.update(visible=True, open=True)
    )
|
220 |
+
|
221 |
+
submit_button.click(
|
222 |
+
fn=handle_output,
|
223 |
+
inputs=[file_input, default_checkbox],
|
224 |
+
outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
######################################
|
230 |
+
# PARTE DI SELEZIONE DELLE TABELLE #
|
231 |
+
######################################
|
232 |
+
with select_table_acc:
|
233 |
+
|
234 |
+
table_selector = gr.CheckboxGroup(choices=[], label="Seleziona le tabelle da visualizzare", value=[])
|
235 |
+
table_outputs = [gr.DataFrame(label=f"Tabella {i+1}", interactive=True, visible=False) for i in range(5)]
|
236 |
+
selected_table_names = gr.Textbox(label="Tabelle selezionate", visible=False, interactive=False)
|
237 |
+
|
238 |
+
# Bottone di selezione modelli (inizialmente disabilitato)
|
239 |
+
open_model_selection = gr.Button("Choose your models", interactive=False)
|
240 |
+
|
241 |
+
def update_table_list(data):
    """Return an update listing the available table names, with nothing selected.

    Args:
        data: Dictionary keyed by table name; any other value (or an empty
            dictionary) yields an empty choice list.
    """
    table_names = list(data.keys()) if isinstance(data, dict) else []
    return gr.update(choices=table_names, value=[])
|
247 |
+
|
248 |
+
def show_selected_tables(data, selected_tables):
    """Show the tables chosen by the user and toggle the models button.

    Args:
        data: Dictionary mapping table names to DataFrames.
        selected_tables: Table names ticked in the checkbox group.

    Returns:
        A list of six updates: one per table slot (five slots), followed by
        the update for the "Choose your models" button, which is enabled
        only when at least one table is selected.
    """
    max_slots = 5  # number of DataFrame components created for table_outputs
    selected_tables = selected_tables or []

    if isinstance(data, dict) and data:
        # Keep only selections that exist in the loaded data.
        selected_tables = [t for t in selected_tables if t in data]
        # Never emit more table updates than there are slots; otherwise the
        # returned list no longer matches the number of wired outputs.
        shown = selected_tables[:max_slots]
        updates = [
            gr.update(value=data[name], label=f"Tabella: {name}", visible=True)
            for name in shown
        ]
        # Hide the slots that are not needed.
        updates.extend(gr.update(visible=False) for _ in range(max_slots - len(shown)))
    else:
        updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(max_slots)]

    updates.append(gr.update(interactive=bool(selected_tables)))
    return updates
|
271 |
+
|
272 |
+
def show_selected_table_names(selected_tables):
    """Return an update holding the selected table names as a comma-separated string.

    The value is an empty string when nothing is selected; the target
    component stays hidden in both cases.
    """
    names = ", ".join(selected_tables) if selected_tables else ""
    return gr.update(value=names, visible=False)
|
277 |
+
|
278 |
+
# Aggiorna automaticamente la lista delle checkbox quando `data_state` cambia
|
279 |
+
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
|
280 |
+
|
281 |
+
# Aggiorna le tabelle visibili e lo stato del bottone in base alle selezioni dell'utente
|
282 |
+
table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
|
283 |
+
|
284 |
+
# Mostra la lista delle tabelle selezionate quando si preme "Choose your models"
|
285 |
+
open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
|
286 |
+
open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
|
287 |
+
|
288 |
+
|
289 |
+
|
290 |
+
####################################
|
291 |
+
# PARTE DI SELEZIONE DEL MODELLO #
|
292 |
+
####################################
|
293 |
+
with select_model_acc:
|
294 |
+
gr.Markdown("**Model Selection**")
|
295 |
+
|
296 |
+
# Supponiamo che `us.read_models_csv` restituisca anche il percorso dell'immagine
|
297 |
+
model_list_dict = us.read_models_csv(models_path)
|
298 |
+
model_list = [model["code"] for model in model_list_dict]
|
299 |
+
model_images = [model["image_path"] for model in model_list_dict]
|
300 |
+
|
301 |
+
model_checkboxes = []
|
302 |
+
rows = []
|
303 |
+
|
304 |
+
# Creazione dinamica di checkbox con immagini (3 per riga)
|
305 |
+
for i in range(0, len(model_list), 3):
|
306 |
+
with gr.Row():
|
307 |
+
cols = []
|
308 |
+
for j in range(3):
|
309 |
+
if i + j < len(model_list):
|
310 |
+
model = model_list[i + j]
|
311 |
+
image_path = model_images[i + j]
|
312 |
+
with gr.Column():
|
313 |
+
gr.Image(image_path, show_label=False)
|
314 |
+
checkbox = gr.Checkbox(label=model, value=False)
|
315 |
+
model_checkboxes.append(checkbox)
|
316 |
+
cols.append(checkbox)
|
317 |
+
rows.append(cols)
|
318 |
+
|
319 |
+
selected_models_output = gr.JSON(visible=False)
|
320 |
+
|
321 |
+
# Funzione per ottenere i modelli selezionati
|
322 |
+
def get_selected_models(*model_selections):
    """Collect the models whose checkbox is ticked and store them in ``input_data``.

    Args:
        *model_selections: One boolean per model checkbox, in the same order
            as ``model_list``.

    Returns:
        A 3-tuple: the list of selected model codes, an update keeping the
        model accordion open and visible, and an update enabling the submit
        button only when at least one model is selected.
    """
    selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
    input_data['models'] = selected_models
    has_selection = bool(selected_models)
    return selected_models, gr.update(open=True, visible=True), gr.update(interactive=has_selection)
|
327 |
+
|
328 |
+
# Bottone di submit (inizialmente disabilitato)
|
329 |
+
submit_models_button = gr.Button("Submit Models", interactive=False)
|
330 |
+
|
331 |
+
# Collegamento dei checkbox agli eventi di selezione
|
332 |
+
for checkbox in model_checkboxes:
|
333 |
+
checkbox.change(
|
334 |
+
fn=get_selected_models,
|
335 |
+
inputs=model_checkboxes,
|
336 |
+
outputs=[selected_models_output, select_model_acc, submit_models_button]
|
337 |
+
)
|
338 |
+
|
339 |
+
submit_models_button.click(
    # get_selected_models returns (models, accordion update, button update);
    # only the model list belongs in selected_models_output, so take [0]
    # instead of passing the whole tuple as the first output.
    fn=lambda *args: (get_selected_models(*args)[0], gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
    inputs=model_checkboxes,
    outputs=[selected_models_output, select_model_acc, qatch_acc]
)
|
344 |
+
|
345 |
+
reset_data = gr.Button("Back to upload data section")
|
346 |
+
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
347 |
+
|
348 |
+
|
349 |
+
###############################
|
350 |
+
# PARTE DI ESECUZIONE QATCH #
|
351 |
+
###############################
|
352 |
+
with qatch_acc:
|
353 |
+
def change_text(text):
    """Return ``text`` unchanged (used to copy a value from one component to another)."""
    return text
|
355 |
+
def qatch_flow():
    """Run test generation, (placeholder) prediction and evaluation, streaming UI updates.

    This is a generator. Each yielded tuple contains, in order: the progress
    Markdown, the question textbox, the prediction textbox, the concatenated
    metrics DataFrame, and one predictions DataFrame per entry of
    ``model_list``.
    """
    orchestrator_generator = OrchestratorGenerator()
    # TODO add to target_df column target_df["columns_used"], tables selection
    target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'])

    # Only referenced by the commented-out predictor call below.
    schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
        db_id=input_data["db_name"],
        base_path=input_data["data_path"],
        normalize=False,
        sql=None
    )

    # TODO QUERY PREDICTION
    prediction_columns = ['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']
    predictions_dict = {model: pd.DataFrame(columns=prediction_columns) for model in model_list}
    metrics_conc = pd.DataFrame()
    total_rows = len(target_df)  # loop-invariant, non-zero whenever the row loop runs

    def model_frames():
        # One frame per tab, in tab order (tabs are built from model_list).
        return [predictions_dict[name] for name in model_list]

    for model in input_data["models"]:
        # Progress uses the row position, not the index label, so it stays
        # correct when target_df does not have a 0..n-1 index.
        for position, (index, row) in enumerate(target_df.iterrows(), start=1):
            # NOTE(review): "##" without a following space is not an ATX
            # heading in CommonMark - confirm whether a heading was intended.
            load_value = f"##Loading... {round(position / total_rows * 100, 2)}%"
            question = row['query']
            yield gr.Markdown(value=load_value), gr.Textbox(question), gr.Textbox(), metrics_conc, *model_frames()
            start_time = time.perf_counter()

            # Simulated prediction.
            time.sleep(0.03)
            prediction = "Prediction_placeholder"

            # Run the real prediction here:
            # prediction = predictor.run(model, schema_text, question)

            end_time = time.perf_counter()
            # Build the new row as a one-row DataFrame; dropna only removes
            # a row that is entirely empty.
            new_row = pd.DataFrame([{
                'id': index,
                'question': question,
                'predicted_sql': prediction,
                'time': end_time - start_time,
                'query': row["query"],
                'db_path': input_data["data_path"]
            }]).dropna(how="all")
            # Carry over every remaining column of the generated test row.
            for col in target_df.columns:
                if col not in new_row.columns:
                    new_row[col] = row[col]
            # Grow this model's frame incrementally so the UI can show it.
            if not new_row.empty:
                predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
            yield gr.Markdown(value=load_value), gr.Textbox(), gr.Textbox(prediction), metrics_conc, *model_frames()

    evaluator = OrchestratorEvaluator()
    for model in input_data["models"]:
        metrics_df_model = evaluator.evaluate_df(
            df=predictions_dict[model],
            target_col_name="query",
            prediction_col_name="predicted_sql",
            db_path_name="db_path",
        )
        metrics_df_model['model'] = model
        metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)

    # Alias 'VES' only when it exists; with no evaluated model the frame has
    # no columns and an unconditional lookup would raise KeyError.
    if 'valid_efficiency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
        metrics_conc['valid_efficiency_score'] = metrics_conc['VES']

    yield gr.Markdown(), gr.Textbox(), gr.Textbox(), metrics_conc, *model_frames()
|
423 |
+
|
424 |
+
#Loading Bar
|
425 |
+
with gr.Row():
|
426 |
+
#progress = gr.Progress()
|
427 |
+
variable = gr.Markdown()
|
428 |
+
|
429 |
+
#NL -> MODEL -> Generated Quesy
|
430 |
+
with gr.Row():
|
431 |
+
with gr.Column():
|
432 |
+
question_display = gr.Textbox()
|
433 |
+
with gr.Column():
|
434 |
+
gr.Image()
|
435 |
+
with gr.Column():
|
436 |
+
prediction_display = gr.Textbox()
|
437 |
+
|
438 |
+
dataframe_per_model = {}
|
439 |
+
|
440 |
+
with gr.Tabs() as model_tabs:
|
441 |
+
#for model in input_data["models"]:
|
442 |
+
for model in model_list:
|
443 |
+
#TODO fix model tabs
|
444 |
+
with gr.TabItem(model):
|
445 |
+
gr.Markdown(f"**Results for {model}**")
|
446 |
+
dataframe_per_model[model] = gr.DataFrame()
|
447 |
+
|
448 |
+
|
449 |
+
#question_display.change(fn=change_text, inputs=[gr.State(question)], outputs=[question_display])
|
450 |
+
selected_models_display = gr.JSON(label="Modelli selezionati")
|
451 |
+
metrics_df = gr.DataFrame(visible=False)
|
452 |
+
metrics_df_out= gr.DataFrame(visible=False)
|
453 |
+
|
454 |
+
submit_models_button.click(
|
455 |
+
fn=qatch_flow,
|
456 |
+
inputs=[],
|
457 |
+
outputs=[variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
458 |
+
)
|
459 |
+
|
460 |
+
submit_models_button.click(
|
461 |
+
fn=lambda: gr.update(value=input_data),
|
462 |
+
outputs=[selected_models_display]
|
463 |
+
)
|
464 |
+
#Funziona per METRICS
|
465 |
+
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
|
466 |
+
|
467 |
+
# def change_tab(selected_models_output, model_tabs):
|
468 |
+
# for model in model_list:
|
469 |
+
# if model in selected_models_output:
|
470 |
+
# pass#model_tabs[model].visible = True
|
471 |
+
# else:
|
472 |
+
# pass#model_tabs[model].visible = False
|
473 |
+
# return model_tabs
|
474 |
+
|
475 |
+
# selected_models_output.change(fn=change_tab, inputs=[selected_models_output, model_tabs], outputs=[])
|
476 |
+
|
477 |
+
proceed_to_metrics_button = gr.Button("Proceed to Metrics")
|
478 |
+
proceed_to_metrics_button.click(
|
479 |
+
fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
|
480 |
+
outputs=[qatch_acc, metrics_acc]
|
481 |
+
)
|
482 |
+
|
483 |
+
reset_data = gr.Button("Back to upload data section")
|
484 |
+
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
485 |
+
|
486 |
+
|
487 |
+
|
488 |
+
#######################################
|
489 |
+
# METRICS VISUALIZATION SECTION #
|
490 |
+
#######################################
|
491 |
+
with metrics_acc:
|
492 |
+
#confirmation_text = gr.Markdown("## Metrics successfully loaded")
|
493 |
+
|
494 |
+
data_path = 'test_results.csv'
|
495 |
+
|
496 |
+
@gr.render(inputs=metrics_df_out)
|
497 |
+
def function_metrics(metrics_df_out):
|
498 |
+
def load_data_csv_es():
    """Read the metrics table from the CSV file at ``data_path``.

    NOTE(review): the enclosing render function receives ``metrics_df_out``
    but this loader reads the static CSV instead; a disabled
    ``return metrics_df_out`` line used to follow the return - confirm
    which source is intended.
    """
    return pd.read_csv(data_path)
|
501 |
+
|
502 |
+
def calculate_average_metrics(df, selected_metrics):
    """Return ``df`` with an ``avg_metric`` column holding the row-wise mean of the metrics.

    Args:
        df: DataFrame containing the metric columns.
        selected_metrics: List of column names to average.

    Returns:
        A new DataFrame; the input is left untouched, so passing a filtered
        slice does not write through to (or warn about) the original frame.
    """
    return df.assign(avg_metric=df[selected_metrics].mean(axis=1))
|
505 |
+
|
506 |
+
def generate_model_colors():
    """Generate a color map with one entry per distinct model in the dataset.

    Returns:
        A dict mapping each value of the ``'model'`` column to a color from
        Plotly's qualitative palette; colors repeat cyclically when there
        are more models than palette entries.
    """
    unique_models = load_data_csv_es()['model'].unique()

    # Use the Plotly color scale (you can change it if needed)
    color_palette = pc.qualitative.Plotly  # ['#636EFA', '#EF553B', '#00CC96', ...]

    return {
        model: color_palette[i % len(color_palette)]
        for i, model in enumerate(unique_models)
    }
|
519 |
+
|
520 |
+
MODEL_COLORS = generate_model_colors()
|
521 |
+
|
522 |
+
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
|
523 |
+
def plot_metric(df, selected_metrics, group_by, selected_models):
    """Build a grouped bar chart of the average metric.

    Args:
        df: Metrics DataFrame with a ``'model'`` column.
        selected_metrics: Metric columns averaged into ``avg_metric``.
        group_by: ``["tbl_name", "model"]`` or ``["model"]``; any other
            value falls back to ``["tbl_name", "model"]``.
        selected_models: Models to keep in the chart.

    Returns:
        The Plotly bar figure.
    """
    # Copy the filtered rows so adding avg_metric never writes to a slice
    # of the caller's frame.
    df = df[df['model'].isin(selected_models)].copy()
    df = calculate_average_metrics(df, selected_metrics)

    # Ensure the group_by value is always valid
    if group_by not in [["tbl_name", "model"], ["model"]]:
        group_by = ["tbl_name", "model"]  # Default

    avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()

    fig = px.bar(
        avg_metrics,
        x=group_by[0],
        y='avg_metric',
        color='model',
        color_discrete_map=MODEL_COLORS,
        barmode='group',
        title=f'Average metric per {group_by[0]} 📊',
        labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
        template='plotly_dark'
    )

    return fig
|
546 |
+
|
547 |
+
def update_plot(selected_metrics, group_by, selected_models):
    """Reload the results CSV and rebuild the average-metric bar chart."""
    return plot_metric(load_data_csv_es(), selected_metrics, group_by, selected_models)
|
550 |
+
|
551 |
+
|
552 |
+
# RADAR CHART FOR AVERAGE METRICS PER MODEL WITH UPDATE FUNCTION
|
553 |
+
def plot_radar(df, selected_models):
    """Draw a radar chart of the mean score per test category for each model.

    The score of a row is the mean of the five metric columns listed below;
    scores are then averaged per (model, test_category) pair.

    Args:
        df: Results DataFrame with ``model``, ``test_category`` and the five
            metric columns.
        selected_models: Model names to draw, one trace per model.

    Returns:
        A ``go.Figure``; it is empty when no rows remain after filtering.
    """
    metric_columns = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]

    chosen_rows = df[df['model'].isin(selected_models)]
    scored = calculate_average_metrics(chosen_rows, metric_columns)
    avg_metrics = scored.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()

    # Nothing to plot: report it and hand back an empty figure.
    if avg_metrics.empty:
        print("Error: No data available to compute averages.")
        return go.Figure()

    fig = go.Figure()
    categories = avg_metrics['test_category'].unique()

    for model in selected_models:
        model_data = avg_metrics[avg_metrics['model'] == model]

        # One averaged value per category; categories the model has no rows
        # for are plotted as 0.
        score_by_category = dict(zip(model_data['test_category'], model_data['avg_metric']))
        values = [score_by_category.get(cat, 0) for cat in categories]

        trace = go.Scatterpolar(
            r=values,
            theta=categories,
            fill='toself',
            name=model,
            line=dict(color=MODEL_COLORS.get(model, "gray"))
        )
        fig.add_trace(trace)

    # The radial axis always reaches at least 0.5 so small scores stay readable.
    radial_max = max(avg_metrics['avg_metric'].max(), 0.5)
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, radial_max])),
        title='❇️ Radar Plot of Metrics per Model (Average per Category) ❇️ ',
        template='plotly_dark',
        width=700, height=700
    )

    return fig
|
598 |
+
|
599 |
+
def update_radar(selected_models):
    """Reload the results CSV and rebuild the radar chart for the given models."""
    return plot_radar(load_data_csv_es(), selected_models)
|
602 |
+
|
603 |
+
|
604 |
+
# LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
|
605 |
+
def plot_cumulative_flow(df, selected_models):
    """Plot, for each model, the number of queries completed over cumulative time.

    Args:
        df: Results DataFrame with ``model`` and ``time`` columns; rows are
            used in their existing order.
        selected_models: Model names to draw, one line per model.

    Returns:
        A ``go.Figure`` with one ``lines+markers`` trace per selected model:
        x is the running sum of ``time`` and y is the 1-based row count.
    """
    df = df[df['model'].isin(selected_models)]

    fig = go.Figure()

    for model in selected_models:
        # Copy so the helper columns below are not written into a slice of df.
        model_df = df[df['model'] == model].copy()

        # Running total of time spent and running count of queries answered.
        model_df['cumulative_time'] = model_df['time'].cumsum()
        model_df['cumulative_queries'] = range(1, len(model_df) + 1)

        color = MODEL_COLORS.get(model, "gray")

        fig.add_trace(go.Scatter(
            x=model_df['cumulative_time'],
            y=model_df['cumulative_queries'],
            mode='lines+markers',
            name=model,
            line=dict(width=2, color=color)
        ))

        # NOTE: a translucent filled-area trace under each line used to be
        # sketched here but was never active. Its fill color was derived by
        # rewriting "rgb" to "rgba", which only works for rgb(...) color
        # strings, so that unused computation has been removed.

    fig.update_layout(
        title="Cumulative Query Flow Chart 📈",
        xaxis_title="Cumulative Time (s)",
        yaxis_title="Number of Queries Completed",
        template='plotly_dark',
        legend_title="Models"
    )

    return fig
|
654 |
+
|
655 |
+
def update_query_rate(selected_models):
    """Reload the results CSV and rebuild the cumulative query flow chart."""
    return plot_cumulative_flow(load_data_csv_es(), selected_models)
|
658 |
+
|
659 |
+
|
660 |
+
# RANKING FOR THE TOP 3 MODELS WITH UPDATE FUNCTION
|
661 |
+
def ranking_text(df, selected_models, ranking_type):
    """Build the Markdown/HTML podium text for the top three models.

    Args:
        df: Results DataFrame with a ``model`` column plus the column(s) the
            chosen ranking needs (``valid_efficiency_score``, ``time`` or the
            five metric columns).
        selected_models: Model names to include in the ranking.
        ranking_type: ``"valid_efficiency_score"`` (mean score per model,
            higher is better), ``"time"`` (total time per model, lower is
            better) or ``"metrics"`` (mean of the per-row metric average,
            higher is better).

    Returns:
        A string that starts with the ``## 🏆 Model Ranking`` heading and is
        followed by one ``<span>`` line per ranked model (at most three).

    Raises:
        ValueError: If *ranking_type* is not one of the supported values.
    """
    # Work on a copy so column assignments never write into a slice of the
    # caller's frame (avoids SettingWithCopyWarning).
    df = df[df['model'].isin(selected_models)].copy()

    if ranking_type == "valid_efficiency_score":
        # Non-numeric scores become NaN and are ignored by the mean.
        df['valid_efficiency_score'] = pd.to_numeric(df['valid_efficiency_score'], errors='coerce')
        rank_df = df.groupby('model')['valid_efficiency_score'].mean().reset_index()
    elif ranking_type == "time":
        rank_df = df.groupby('model')['time'].sum().reset_index()
    elif ranking_type == "metrics":
        selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
        df = calculate_average_metrics(df, selected_metrics)
        rank_df = df.groupby('model')['avg_metric'].mean().reset_index()
    else:
        raise ValueError(f"Unsupported ranking_type: {ranking_type!r}")

    if ranking_type == "time":
        # Sort on the numeric seconds BEFORE formatting: sorting the formatted
        # strings would be lexicographic ("100 s" < "13 s" < "9 s").
        rank_df = rank_df.sort_values(by="time", ascending=True).reset_index(drop=True)
        rank_df["Ranking Value"] = rank_df["time"].round(2).astype(str) + " s"  # Adds "s" for seconds
    else:
        rank_df = rank_df.rename(columns={rank_df.columns[1]: "Ranking Value"})
        rank_df["Ranking Value"] = rank_df["Ranking Value"].round(2)
        # Higher is better for both score-based rankings.
        rank_df = rank_df.sort_values(by="Ranking Value", ascending=False).reset_index(drop=True)

    # Keep only the podium and pair each place with its medal.
    medals = ["🥇", "🥈", "🥉"]
    podium = rank_df.head(len(medals))

    podium_lines = [
        f"<span style='font-size:18px;'>{medal} {model} ({value})</span><br>\n"
        for medal, model, value in zip(medals, podium["model"], podium["Ranking Value"])
    ]
    return "## 🏆 Model Ranking\n" + "".join(podium_lines)
|
699 |
+
|
700 |
+
def update_ranking_text(selected_models, ranking_type):
    """Reload the results CSV and rebuild the top-three ranking text."""
    return ranking_text(load_data_csv_es(), selected_models, ranking_type)
|
703 |
+
|
704 |
+
|
705 |
+
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
|
706 |
+
def worst_cases_text(df, selected_models):
    """Build the Markdown/HTML text listing the three lowest-scoring cases.

    A case is one distinct combination of model, table, test category,
    question, reference query and predicted SQL. Its score is the mean of the
    per-row average of the five metric columns listed below.

    Args:
        df: Results DataFrame with the case columns and the metric columns.
        selected_models: Model names whose cases are considered.

    Returns:
        A string that starts with the ``## ❌ Top 3 Worst Cases`` heading and
        is followed by one formatted entry per case (at most three), lowest
        score first.
    """
    selected_metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
    case_columns = ['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql']

    # Copy so nothing downstream writes into a slice of the caller's frame.
    df = df[df['model'].isin(selected_models)].copy()
    df = calculate_average_metrics(df, selected_metrics)

    # NOTE(review): groupby drops rows whose key columns contain NaN by
    # default, so cases with e.g. a missing predicted_sql are not listed —
    # confirm this is intended.
    worst_cases_df = df.groupby(case_columns)['avg_metric'].mean().reset_index()
    worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)

    # `.head()` returns a slice; copy it before rounding to avoid
    # SettingWithCopyWarning.
    worst_cases_top_3 = worst_cases_df.head(3).copy()
    worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)

    worst_str = "## ❌ Top 3 Worst Cases\n"
    medals = ["🥇", "🥈", "🥉"]

    # Pair medals with rows by position rather than by DataFrame index label.
    for medal, (_, row) in zip(medals, worst_cases_top_3.iterrows()):
        worst_str += (
            f"<span style='font-size:18px;'><b>{medal} {row['model']} - {row['tbl_name']} - {row['test_category']}</b> ({row['avg_metric']})</span> \n"
            f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
            f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
            f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
        )

    return worst_str
|
732 |
+
|
733 |
+
def update_worst_cases_text(selected_models):
    """Reload the results CSV and rebuild the worst-cases text."""
    return worst_cases_text(load_data_csv_es(), selected_models)
|
736 |
+
|
737 |
+
|
738 |
+
metrics = ["cell_precision", "cell_recall", "execution_accuracy", "tuple_cardinality", "tuple_constraint"]
|
739 |
+
group_options = {
|
740 |
+
"Table": ["tbl_name", "model"],
|
741 |
+
"Model": ["model"]
|
742 |
+
}
|
743 |
+
|
744 |
+
df_initial = load_data_csv_es()
|
745 |
+
models = df_initial['model'].unique().tolist()
|
746 |
+
|
747 |
+
#with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
|
748 |
+
gr.Markdown("""## 📊 Model Performance Analysis 📊
|
749 |
+
Select one or more metrics to calculate the average and visualize histograms and radar plots.
|
750 |
+
""")
|
751 |
+
|
752 |
+
# Options selection section
|
753 |
+
with gr.Row():
|
754 |
+
|
755 |
+
metric_multiselect = gr.CheckboxGroup(choices=metrics, label="Select metrics", value=metrics)
|
756 |
+
model_multiselect = gr.CheckboxGroup(choices=models, label="Select models", value=models)
|
757 |
+
group_radio = gr.Radio(choices=list(group_options.keys()), label="Select grouping", value="Model")
|
758 |
+
|
759 |
+
output_plot = gr.Plot()
|
760 |
+
|
761 |
+
query_rate_plot = gr.Plot(value=update_query_rate(models))
|
762 |
+
|
763 |
+
with gr.Row():
|
764 |
+
with gr.Column(scale=1):
|
765 |
+
radar_plot = gr.Plot(value=update_radar(models))
|
766 |
+
|
767 |
+
with gr.Column(scale=1):
|
768 |
+
ranking_type_radio = gr.Radio(
|
769 |
+
["valid_efficiency_score", "time", "metrics"],
|
770 |
+
label="Choose ranking criteria",
|
771 |
+
value="valid_efficiency_score"
|
772 |
+
)
|
773 |
+
ranking_text_display = gr.Markdown(value=update_ranking_text(models, "valid_efficiency_score"))
|
774 |
+
worst_cases_display = gr.Markdown(value=update_worst_cases_text(models))
|
775 |
+
|
776 |
+
# Callback functions for updating charts
|
777 |
+
def on_change(selected_metrics, selected_group, selected_models):
    """Rebuild the bar chart after a change in metrics, grouping or models.

    *selected_group* is the radio label ("Table" or "Model"); it is
    translated into its list of grouping columns through ``group_options``.
    """
    grouping_columns = group_options[selected_group]
    return update_plot(selected_metrics, grouping_columns, selected_models)
|
779 |
+
|
780 |
+
def on_radar_change(selected_models):
    """Rebuild the radar chart for the given model selection."""
    # NOTE(review): the event wiring visible in this section registers
    # `update_radar` directly; this wrapper does not appear to be referenced
    # there.
    refreshed_radar = update_radar(selected_models)
    return refreshed_radar
|
782 |
+
|
783 |
+
#metrics_df_out.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
|
784 |
+
metric_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
|
785 |
+
group_radio.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
|
786 |
+
model_multiselect.change(on_change, inputs=[metric_multiselect, group_radio, model_multiselect], outputs=output_plot)
|
787 |
+
model_multiselect.change(update_radar, inputs=model_multiselect, outputs=radar_plot)
|
788 |
+
model_multiselect.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
|
789 |
+
ranking_type_radio.change(update_ranking_text, inputs=[model_multiselect, ranking_type_radio], outputs=ranking_text_display)
|
790 |
+
model_multiselect.change(update_worst_cases_text, inputs=model_multiselect, outputs=worst_cases_display)
|
791 |
+
model_multiselect.change(update_query_rate, inputs=[model_multiselect], outputs=query_rate_plot)
|
792 |
+
|
793 |
+
reset_data = gr.Button("Back to upload data section")
|
794 |
+
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
795 |
+
|
796 |
+
# Hidden button to force UI refresh on load
|
797 |
+
force_update_button = gr.Button("", visible=False)
|
798 |
+
|
799 |
+
# State variable to track first load
|
800 |
+
load_trigger = gr.State(value=True)
|
801 |
+
|
802 |
+
# Function to force initial load
|
803 |
+
def force_update(is_first_load):
    """Produce values for every chart/text output plus the new load flag.

    Args:
        is_first_load: Truthy when the outputs still have to be populated.

    Returns:
        A 6-tuple: bar chart, cumulative flow chart, radar chart, ranking
        text, worst-cases text and ``False`` (the next value of the load
        flag). When *is_first_load* is falsy the first five items are
        ``gr.update()`` placeholders instead of freshly built values.
    """
    if not is_first_load:
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), False

    bar_chart = update_plot(metrics, group_options["Model"], models)
    flow_chart = update_query_rate(models)
    radar_chart = update_radar(models)
    ranking_markdown = update_ranking_text(models, "valid_efficiency_score")
    worst_markdown = update_worst_cases_text(models)

    # The trailing False flips the flag so later triggers take the no-op path.
    return (
        bar_chart,
        flow_chart,
        radar_chart,
        ranking_markdown,
        worst_markdown,
        False
    )
|
814 |
+
|
815 |
+
# The invisible button forces chart loading only the first time
|
816 |
+
force_update_button.click(
|
817 |
+
fn=force_update,
|
818 |
+
inputs=[load_trigger],
|
819 |
+
outputs=[output_plot, query_rate_plot, radar_plot, ranking_text_display, worst_cases_display, load_trigger]
|
820 |
+
)
|
821 |
+
|
822 |
+
# Populate every chart and text panel when the UI loads by calling force_update(True) directly (the hidden button itself is not clicked)
|
823 |
+
with gr.Blocks() as demo:
|
824 |
+
demo.load(
|
825 |
+
lambda: force_update(True),
|
826 |
+
outputs=[output_plot, query_rate_plot, radar_plot, ranking_text_display, worst_cases_display, load_trigger]
|
827 |
+
)
|
828 |
+
|
829 |
interface.launch()
|
style.css
CHANGED
@@ -22,4 +22,9 @@
|
|
22 |
display: block;
|
23 |
margin: 10px auto 0;
|
24 |
border-radius: 2px;
|
25 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
22 |
display: block;
|
23 |
margin: 10px auto 0;
|
24 |
border-radius: 2px;
|
25 |
+
}
|
26 |
+
|
27 |
+
#bar_plot, #line_plot {
|
28 |
+
width: 100% !important;
|
29 |
+
max-width: none !important;
|
30 |
+
}
|
test_results.csv
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model,tbl_name,test_category,question,query,predicted_sql,cell_precision,cell_recall,execution_accuracy,tuple_cardinality,tuple_constraint,time,valid_efficiency_score,tuple_order
|
2 |
+
Model_B,Table_2,WHERE,Mostra il prodotto più venduto,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.15,0.89,0.84,0.92,0.97,13,10,
|
3 |
+
Model_C,Table_3,GROUPBY,Quali clienti hanno speso di più?,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.09,0.1,0.02,0.22,0.42,3,4,
|
4 |
+
Model_A,Table_3,ORDERBY,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.9,0.81,0.18,0.56,0.52,11,11,0.39
|
5 |
+
Model_C,Table_3,WHERE,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.87,0.26,0.83,0.37,0.83,2,1,
|
6 |
+
Model_B,Table_3,ORDERBY,Quali clienti hanno speso di più?,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.62,0.49,0.66,0.95,0.68,13,2,0.38
|
7 |
+
Model_C,Table_3,JOIN,Qual è la media dei prezzi?,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.54,0.2,0.83,0.43,0.64,11,14,
|
8 |
+
Model_B,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.35,0.12,0.51,0.78,0.94,9,12,0.06
|
9 |
+
Model_C,Table_2,SELECT,Qual è la media dei prezzi?,SELECT AVG(price) FROM products;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.96,0.23,0.38,0.26,0.13,11,8,
|
10 |
+
Model_C,Table_2,JOIN,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.5,0.07,0.02,0.82,0.75,2,5,
|
11 |
+
Model_A,Table_1,JOIN,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.03,0.81,0.15,0.64,0.98,12,8,
|
12 |
+
Model_B,Table_3,JOIN,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.95,0.63,0.94,0.31,0.94,12,15,
|
13 |
+
Model_A,Table_3,SELECT,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.26,0.81,0.45,0.7,0.77,9,0,
|
14 |
+
Model_C,Table_3,WHERE,Ordina i clienti per spesa,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.65,0.22,0.61,0.56,0.84,5,11,
|
15 |
+
Model_C,Table_1,WHERE,Trova il totale delle vendite,SELECT * FROM users;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.3,0.83,0.14,0.22,0.7,14,2,
|
16 |
+
Model_A,Table_3,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.66,0.57,0.07,0.93,0.16,15,11,
|
17 |
+
Model_C,Table_1,WHERE,Elenca tutti gli utenti,SELECT SUM(sales) FROM orders;,SELECT SUM(sales) FROM orders;,0.39,0.97,0.87,0.99,0.27,13,4,
|
18 |
+
Model_A,Table_2,SELECT,Elenca tutti gli utenti,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.64,0.44,0.03,0.81,0.43,14,6,
|
19 |
+
Model_A,Table_3,GROUPBY,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.55,0.32,0.28,0.12,0.79,7,3,
|
20 |
+
Model_B,Table_3,WHERE,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.56,0.27,0.34,0.59,0.59,10,15,
|
21 |
+
Model_A,Table_2,SELECT,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.6,0.83,0.41,0.28,0.02,10,8,
|
22 |
+
Model_C,Table_2,SELECT,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT AVG(price) FROM products;,0.77,0.5,0.13,0.74,0.36,2,4,
|
23 |
+
Model_C,Table_1,ORDERBY,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.5,0.33,0.91,0.05,0.71,3,11,0.83
|
24 |
+
Model_C,Table_2,WHERE,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.74,0.62,0.64,0.26,0.05,1,5,
|
25 |
+
Model_B,Table_3,JOIN,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.98,0.73,0.52,0.12,0.72,1,6,
|
26 |
+
Model_A,Table_2,WHERE,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.9,0.84,0.57,0.86,0.66,10,1,
|
27 |
+
Model_A,Table_1,WHERE,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.21,0.73,0.97,0.58,0.04,5,4,
|
28 |
+
Model_C,Table_3,GROUPBY,Qual è la media dei prezzi?,SELECT * FROM users;,SELECT * FROM users;,0.63,0.51,0.14,0.24,0.62,3,6,
|
29 |
+
Model_A,Table_1,ORDERBY,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.39,0.42,0.88,0.45,0.24,9,1,0.95
|
30 |
+
Model_B,Table_3,JOIN,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.86,0.59,0.53,0.91,0.9,4,8,
|
31 |
+
Model_C,Table_3,JOIN,Ordina i clienti per spesa,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.86,0.57,0.26,0.47,0.5,13,13,
|
32 |
+
Model_B,Table_2,GROUPBY,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.67,0.7,0.1,0.42,0.31,9,14,
|
33 |
+
Model_C,Table_3,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.74,0.26,0.64,0.33,0.09,9,7,
|
34 |
+
Model_C,Table_2,GROUPBY,Mostra il prodotto più venduto,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.31,0.75,0.73,0.3,0.25,8,10,
|
35 |
+
Model_B,Table_2,JOIN,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.83,0.63,0.43,0.0,0.1,13,9,
|
36 |
+
Model_C,Table_1,ORDERBY,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM customers ORDER BY total_spent DESC;,1.0,0.48,0.96,0.45,0.66,4,5,0.8
|
37 |
+
Model_A,Table_3,JOIN,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM customers ORDER BY total_spent DESC;,0.97,0.46,0.34,0.57,0.21,15,4,
|
38 |
+
Model_C,Table_3,SELECT,Quali clienti hanno speso di più?,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.72,0.02,0.64,0.62,0.83,11,7,
|
39 |
+
Model_A,Table_2,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.16,0.67,0.2,0.62,0.82,9,15,
|
40 |
+
Model_C,Table_2,ORDERBY,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT AVG(price) FROM products;,0.38,0.13,0.96,0.36,0.9,14,5,0.01
|
41 |
+
Model_A,Table_2,GROUPBY,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.73,0.77,0.24,0.35,0.77,1,2,
|
42 |
+
Model_C,Table_1,JOIN,Mostra il prodotto più venduto,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.46,0.51,0.79,0.1,0.87,7,5,
|
43 |
+
Model_C,Table_3,WHERE,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.03,0.53,0.5,0.69,0.45,7,3,
|
44 |
+
Model_C,Table_3,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT AVG(price) FROM products;,0.27,0.94,0.41,0.07,0.61,7,14,
|
45 |
+
Model_A,Table_2,WHERE,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT SUM(sales) FROM orders;,0.04,0.8,0.59,0.06,0.18,9,10,
|
46 |
+
Model_C,Table_3,WHERE,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.99,0.17,0.2,0.74,0.59,3,14,
|
47 |
+
Model_B,Table_1,WHERE,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.0,0.39,0.77,0.04,0.5,1,2,
|
48 |
+
Model_C,Table_3,SELECT,Ordina i clienti per spesa,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.69,0.77,0.07,0.97,0.21,7,4,
|
49 |
+
Model_C,Table_2,JOIN,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT SUM(sales) FROM orders;,0.05,0.6,0.47,0.08,0.83,10,11,
|
50 |
+
Model_B,Table_2,WHERE,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.18,0.8,0.01,0.26,0.79,4,9,
|
51 |
+
Model_A,Table_2,WHERE,Quali clienti hanno speso di più?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.83,0.69,0.25,0.27,0.73,6,15,
|
52 |
+
Model_C,Table_2,GROUPBY,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.8,0.0,0.2,0.11,0.09,6,10,
|
53 |
+
Model_A,Table_1,JOIN,Qual è la media dei prezzi?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.33,0.0,0.01,0.65,0.38,12,2,
|
54 |
+
Model_B,Table_2,GROUPBY,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT AVG(price) FROM products;,0.3,0.79,0.37,0.66,0.07,2,14,
|
55 |
+
Model_A,Table_1,GROUPBY,Mostra il prodotto più venduto,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.8,0.21,0.4,0.93,0.61,3,8,
|
56 |
+
Model_B,Table_2,GROUPBY,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.48,0.8,0.62,0.72,0.64,8,8,
|
57 |
+
Model_B,Table_3,ORDERBY,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT SUM(sales) FROM orders;,0.94,0.61,0.07,0.42,0.74,9,6,0.63
|
58 |
+
Model_B,Table_2,WHERE,Ordina i clienti per spesa,SELECT * FROM users;,SELECT AVG(price) FROM products;,0.02,0.63,0.97,0.62,0.34,2,6,
|
59 |
+
Model_C,Table_2,WHERE,Trova il totale delle vendite,SELECT AVG(price) FROM products;,SELECT SUM(sales) FROM orders;,0.61,0.87,0.56,0.55,0.11,6,4,
|
60 |
+
Model_B,Table_2,WHERE,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT AVG(price) FROM products;,0.4,0.44,0.72,0.01,0.78,2,1,
|
61 |
+
Model_A,Table_3,ORDERBY,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.89,0.48,0.37,0.8,0.57,2,11,0.83
|
62 |
+
Model_A,Table_2,GROUPBY,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.95,0.51,0.17,0.55,0.06,10,4,
|
63 |
+
Model_A,Table_3,JOIN,Qual è la media dei prezzi?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT * FROM customers ORDER BY total_spent DESC;,0.47,0.56,0.89,0.86,0.06,1,2,
|
64 |
+
Model_A,Table_2,ORDERBY,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.3,0.5,0.69,0.51,0.07,13,1,0.59
|
65 |
+
Model_C,Table_2,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.81,0.37,0.51,0.6,0.45,3,9,
|
66 |
+
Model_B,Table_3,WHERE,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT AVG(price) FROM products;,0.62,0.65,0.26,0.52,0.05,4,3,
|
67 |
+
Model_A,Table_1,GROUPBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.17,0.28,0.87,0.95,0.81,10,7,
|
68 |
+
Model_B,Table_1,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.2,0.16,0.28,0.95,0.64,9,2,
|
69 |
+
Model_C,Table_2,JOIN,Qual è la media dei prezzi?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT SUM(sales) FROM orders;,0.89,0.78,0.56,0.84,0.13,3,11,
|
70 |
+
Model_B,Table_1,ORDERBY,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.91,0.35,0.97,0.99,0.97,12,5,0.9
|
71 |
+
Model_A,Table_1,WHERE,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT SUM(sales) FROM orders;,0.59,0.72,0.77,0.64,0.75,14,2,
|
72 |
+
Model_B,Table_2,JOIN,Ordina i clienti per spesa,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.05,0.3,0.37,0.22,0.31,6,6,
|
73 |
+
Model_C,Table_3,JOIN,Mostra il prodotto più venduto,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.69,0.34,0.94,0.22,0.94,11,7,
|
74 |
+
Model_B,Table_2,JOIN,Qual è la media dei prezzi?,SELECT SUM(sales) FROM orders;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.4,0.79,0.72,0.82,0.98,6,5,
|
75 |
+
Model_C,Table_1,ORDERBY,Elenca tutti gli utenti,SELECT * FROM users;,SELECT SUM(sales) FROM orders;,0.57,0.71,0.04,0.32,0.55,12,0,0.43
|
76 |
+
Model_A,Table_1,SELECT,Quali clienti hanno speso di più?,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.84,0.62,0.32,0.28,0.16,2,5,
|
77 |
+
Model_A,Table_3,ORDERBY,Ordina i clienti per spesa,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.77,0.81,0.47,0.46,0.82,2,10,0.66
|
78 |
+
Model_C,Table_2,ORDERBY,Trova il totale delle vendite,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.86,0.28,0.14,0.75,0.37,13,14,0.25
|
79 |
+
Model_A,Table_1,ORDERBY,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.9,0.93,0.89,0.88,0.25,13,3,0.92
|
80 |
+
Model_C,Table_3,JOIN,Quali clienti hanno speso di più?,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.59,0.9,0.12,0.58,0.69,14,2,
|
81 |
+
Model_C,Table_1,JOIN,Quali clienti hanno speso di più?,SELECT SUM(sales) FROM orders;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.72,0.2,0.67,0.36,0.42,5,14,
|
82 |
+
Model_A,Table_2,JOIN,Trova il totale delle vendite,SELECT * FROM users;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.17,0.74,0.12,0.4,0.27,11,15,
|
83 |
+
Model_A,Table_1,SELECT,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM customers ORDER BY total_spent DESC;,0.34,0.59,0.8,0.15,0.58,11,3,
|
84 |
+
Model_A,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.4,0.64,0.18,0.22,0.03,15,8,0.33
|
85 |
+
Model_C,Table_2,ORDERBY,Qual è la media dei prezzi?,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM customers ORDER BY total_spent DESC;,0.32,0.31,0.95,0.97,0.21,7,6,0.02
|
86 |
+
Model_B,Table_2,JOIN,Elenca tutti gli utenti,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;","SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.37,0.41,0.36,0.3,0.6,11,10,
|
87 |
+
Model_A,Table_1,SELECT,Trova il totale delle vendite,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.5,0.88,0.64,0.54,0.63,8,11,
|
88 |
+
Model_C,Table_1,GROUPBY,Trova il totale delle vendite,SELECT * FROM users;,SELECT * FROM users;,0.19,0.6,0.4,0.49,0.21,13,3,
|
89 |
+
Model_C,Table_3,JOIN,Trova il totale delle vendite,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT * FROM users;,0.4,0.92,0.33,0.19,0.67,13,3,
|
90 |
+
Model_B,Table_3,JOIN,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,SELECT * FROM users;,0.76,0.67,0.51,1.0,0.22,10,14,
|
91 |
+
Model_C,Table_1,ORDERBY,Quali clienti hanno speso di più?,SELECT AVG(price) FROM products;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.07,0.72,0.69,0.21,0.15,15,3,0.63
|
92 |
+
Model_C,Table_1,GROUPBY,Ordina i clienti per spesa,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.71,0.54,0.24,0.35,0.23,7,2,
|
93 |
+
Model_C,Table_1,WHERE,Elenca tutti gli utenti,SELECT AVG(price) FROM products;,SELECT * FROM users;,0.83,0.36,0.52,0.18,0.06,10,7,
|
94 |
+
Model_C,Table_2,JOIN,Elenca tutti gli utenti,SELECT * FROM users;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.19,0.01,0.34,0.79,0.29,14,7,
|
95 |
+
Model_C,Table_1,GROUPBY,Trova il totale delle vendite,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.14,0.74,0.76,0.88,0.01,5,9,
|
96 |
+
Model_A,Table_2,JOIN,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT * FROM users;,0.66,0.36,0.97,0.46,0.85,5,1,
|
97 |
+
Model_A,Table_2,ORDERBY,Mostra il prodotto più venduto,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.39,0.8,0.88,0.06,0.84,4,14,0.96
|
98 |
+
Model_C,Table_3,GROUPBY,Elenca tutti gli utenti,SELECT * FROM customers ORDER BY total_spent DESC;,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,0.17,0.93,0.64,0.84,0.72,6,8,
|
99 |
+
Model_A,Table_3,ORDERBY,Elenca tutti gli utenti,SELECT product FROM sales ORDER BY count DESC LIMIT 1;,"SELECT customer, SUM(amount) FROM payments GROUP BY customer ORDER BY SUM(amount) DESC;",0.55,0.16,0.54,0.87,0.55,8,1,0.94
|
100 |
+
Model_B,Table_1,ORDERBY,Ordina i clienti per spesa,SELECT SUM(sales) FROM orders;,SELECT * FROM customers ORDER BY total_spent DESC;,0.92,0.59,0.25,0.57,0.08,2,5,0.47
|
101 |
+
Model_C,Table_1,JOIN,Mostra il prodotto più venduto,SELECT * FROM users;,SELECT * FROM customers ORDER BY total_spent DESC;,0.7,0.88,0.65,0.64,0.92,13,13,
|
utilities.py
CHANGED
@@ -3,8 +3,9 @@ import pandas as pd
|
|
3 |
import sqlite3
|
4 |
import gradio as gr
|
5 |
import os
|
|
|
6 |
|
7 |
-
def carica_sqlite(file_path):
|
8 |
conn = sqlite3.connect(file_path)
|
9 |
cursor = conn.cursor()
|
10 |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
@@ -17,7 +18,11 @@ def carica_sqlite(file_path):
|
|
17 |
df = pd.read_sql_query(f"SELECT * FROM {nome_tabella}", conn)
|
18 |
dfs[nome_tabella] = df
|
19 |
conn.close()
|
20 |
-
data_output = {'data_frames': dfs,'db':
|
|
|
|
|
|
|
|
|
21 |
return data_output
|
22 |
|
23 |
# Funzione per leggere un file CSV
|
@@ -37,7 +42,7 @@ def load_data(data_path : str, db_name : str):
|
|
37 |
data_output = {'data_frames': {} ,'db': None}
|
38 |
table_name = os.path.splitext(os.path.basename(data_path))[0]
|
39 |
if data_path.endswith(".sqlite") :
|
40 |
-
data_output = carica_sqlite(data_path)
|
41 |
elif data_path.endswith(".csv"):
|
42 |
data_output['data_frames'] = {f"{table_name}_table" : carica_csv(data_path)}
|
43 |
elif data_path.endswith(".xlsx"):
|
|
|
3 |
import sqlite3
|
4 |
import gradio as gr
|
5 |
import os
|
6 |
+
from qatch.connectors.sqlite_connector import SqliteConnector
|
7 |
|
8 |
+
def carica_sqlite(file_path, db_id):
|
9 |
conn = sqlite3.connect(file_path)
|
10 |
cursor = conn.cursor()
|
11 |
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
|
|
18 |
df = pd.read_sql_query(f"SELECT * FROM {nome_tabella}", conn)
|
19 |
dfs[nome_tabella] = df
|
20 |
conn.close()
|
21 |
+
data_output = {'data_frames': dfs,'db': None}
|
22 |
+
# data_output['db'] = SqliteConnector(
|
23 |
+
# relative_db_path=file_path,
|
24 |
+
# db_name=db_id,
|
25 |
+
# )
|
26 |
return data_output
|
27 |
|
28 |
# Funzione per leggere un file CSV
|
|
|
42 |
data_output = {'data_frames': {} ,'db': None}
|
43 |
table_name = os.path.splitext(os.path.basename(data_path))[0]
|
44 |
if data_path.endswith(".sqlite") :
|
45 |
+
data_output = carica_sqlite(data_path, db_name)
|
46 |
elif data_path.endswith(".csv"):
|
47 |
data_output['data_frames'] = {f"{table_name}_table" : carica_csv(data_path)}
|
48 |
elif data_path.endswith(".xlsx"):
|
utils_get_db_tables_info.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sqlite3
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
def utils_extract_db_schema_as_string(
    db_id, base_path, normalize=False, sql: str | None = None
):
    """
    Extract the schema of an SQLite database as a single string.

    :param db_id: Database identifier. Currently unused by this function;
        kept in the signature so existing callers keep working.
    :param base_path: Path handed directly to ``sqlite3.connect``, i.e. the
        SQLite database file itself (not a parent directory).
    :param normalize: Whether to compact each schema statement via
        ``_combine_schema_entries`` instead of returning it verbatim.
    :param sql: Optional SQL query used by ``_get_schema_entries`` to
        restrict the output to the tables the query mentions.
    :return: Schema statements joined with newlines.
    """
    # NOTE: sqlite3.connect creates an empty database when the file does not
    # exist, so a wrong path yields an empty schema string rather than an
    # error.
    connection = sqlite3.connect(base_path)
    try:
        # Fetch everything while the connection is open; fetchall() returns a
        # plain list, so the rows remain usable after close().
        schema_entries = _get_schema_entries(connection.cursor(), sql)
    finally:
        # Always release the database handle, even if the query fails.
        connection.close()

    # Combine all schema definitions into a single string.
    return _combine_schema_entries(schema_entries, normalize)
|
35 |
+
|
36 |
+
|
37 |
+
def _get_schema_entries(cursor, sql):
    """
    Retrieve schema entries from the SQLite database.

    When ``sql`` is given, only tables whose names occur in it as whole
    identifiers (case-insensitive) are returned. If ``sql`` is empty or
    mentions no known table, every ``sqlite_master`` entry that has a
    definition is returned instead.

    :param cursor: SQLite cursor object.
    :param sql: Optional SQL query to filter specific tables.
    :return: List of one-element tuples, each holding a schema statement.
    """
    all_entries_query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"

    if not sql:
        cursor.execute(all_entries_query)
        return cursor.fetchall()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    sql_lower = sql.lower()
    # Match whole identifiers only: a plain substring test would select the
    # table "user" for a query that only references "users". Lookarounds are
    # used instead of \b so names starting/ending with non-word characters
    # still match.
    tables = [
        name
        for (name,) in cursor.fetchall()
        if re.search(rf"(?<!\w){re.escape(name.lower())}(?!\w)", sql_lower)
    ]

    if not tables:
        # The query references no known table: fall back to the full schema.
        cursor.execute(all_entries_query)
        return cursor.fetchall()

    # Bind the names as parameters so table names containing quotes cannot
    # break (or inject into) the statement.
    placeholders = ", ".join("?" for _ in tables)
    cursor.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type='table' AND name IN ({placeholders}) AND sql IS NOT NULL;",
        tables,
    )
    return cursor.fetchall()
|
58 |
+
|
59 |
+
|
60 |
+
def _combine_schema_entries(schema_entries, normalize):
    """
    Join schema entries into one newline-separated string.

    :param schema_entries: Iterable of rows whose first element is a schema
        statement.
    :param normalize: When false, statements are joined verbatim. When true,
        each statement is compacted first: the ``CREATE TABLE`` text is
        dropped, whitespace is collapsed, the leading name is wrapped in
        backticks and padding around parentheses is removed.
    :return: Combined schema string.
    """
    statements = [row[0] for row in schema_entries]
    if not normalize:
        return "\n".join(statements)

    # Ordered (pattern, replacement) rewrites applied after whitespace has
    # been collapsed; the order matters because later steps rely on the
    # backticks added by the first one.
    rewrites = (
        (r"^\s*([^\s(]+)", r"`\1`"),  # wrap the leading name in backticks
        (r"(`\w+`)\s+\(", r"\1("),  # no gap between `name` and "("
        (r"\(\s*", "("),  # no padding after "("
        (r"\s*\)", ")"),  # no padding before ")"
    )

    compacted = []
    for statement in statements:
        text = statement.replace("CREATE TABLE", "").replace("\t", " ")
        text = re.sub(r"\s+", " ", text).strip()
        for pattern, replacement in rewrites:
            text = re.sub(pattern, replacement, text)
        compacted.append(text)
    return "\n".join(compacted)
|