Spaces:
Runtime error
Runtime error
Commit
·
18d08ed
1
Parent(s):
833d467
update
Browse files
app.py
CHANGED
@@ -126,19 +126,10 @@ def load_session_data(username):
|
|
126 |
|
127 |
def load_samples(methods):
|
128 |
logger.info(f"Loading samples for methods: {methods}")
|
129 |
-
samples =
|
130 |
categories = ["TP", "TN", "FP", "FN"]
|
131 |
|
132 |
-
method_dirs = []
|
133 |
-
for method in methods:
|
134 |
-
if method == 'No-XAI':
|
135 |
-
method_dirs.append('NO_XAI')
|
136 |
-
elif method == 'Dater':
|
137 |
-
method_dirs.append('DATER')
|
138 |
-
elif method == 'Chain-of-Table':
|
139 |
-
method_dirs.append('COT')
|
140 |
-
elif method == 'Plan-of-SQLs':
|
141 |
-
method_dirs.append('POS')
|
142 |
|
143 |
for category in categories:
|
144 |
dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
|
@@ -150,18 +141,29 @@ def load_samples(methods):
|
|
150 |
matching_files = files_a & files_b
|
151 |
|
152 |
for file in matching_files:
|
153 |
-
samples.
|
154 |
-
'category': category,
|
155 |
-
'file': file
|
156 |
-
})
|
157 |
|
158 |
-
|
|
|
159 |
|
|
|
|
|
160 |
|
161 |
def select_balanced_samples(samples):
|
162 |
try:
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
return selected_samples
|
166 |
except Exception as e:
|
167 |
logger.exception("Error selecting balanced samples")
|
@@ -285,7 +287,6 @@ def get_method_dir(method):
|
|
285 |
elif method == 'Plan-of-SQLs':
|
286 |
return 'POS'
|
287 |
|
288 |
-
|
289 |
def get_visualization_dir(method):
|
290 |
if method == "No-XAI":
|
291 |
return 'htmls_NO_XAI'
|
|
|
126 |
|
127 |
def load_samples(methods):
|
128 |
logger.info(f"Loading samples for methods: {methods}")
|
129 |
+
samples = set() # Use a set to avoid duplicates
|
130 |
categories = ["TP", "TN", "FP", "FN"]
|
131 |
|
132 |
+
method_dirs = [get_method_dir(method) for method in methods]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
for category in categories:
|
135 |
dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
|
|
|
141 |
matching_files = files_a & files_b
|
142 |
|
143 |
for file in matching_files:
|
144 |
+
samples.add((category, file))
|
|
|
|
|
|
|
145 |
|
146 |
+
# Convert set of tuples back to list of dictionaries
|
147 |
+
samples = [{'category': category, 'file': file} for category, file in samples]
|
148 |
|
149 |
+
logger.info(f"Loaded {len(samples)} unique samples across all categories")
|
150 |
+
return samples
|
151 |
|
152 |
def select_balanced_samples(samples):
|
153 |
try:
|
154 |
+
# Ensure we have at least 10 unique samples
|
155 |
+
unique_samples = list({(s['category'], s['file']) for s in samples})
|
156 |
+
|
157 |
+
if len(unique_samples) < 10:
|
158 |
+
logger.warning(f"Not enough unique samples. Only {len(unique_samples)} available.")
|
159 |
+
selected_samples = unique_samples
|
160 |
+
else:
|
161 |
+
selected_samples = random.sample(unique_samples, 10)
|
162 |
+
|
163 |
+
# Convert back to dictionary format
|
164 |
+
selected_samples = [{'category': category, 'file': file} for category, file in selected_samples]
|
165 |
+
|
166 |
+
logger.info(f"Selected {len(selected_samples)} unique samples")
|
167 |
return selected_samples
|
168 |
except Exception as e:
|
169 |
logger.exception("Error selecting balanced samples")
|
|
|
287 |
elif method == 'Plan-of-SQLs':
|
288 |
return 'POS'
|
289 |
|
|
|
290 |
def get_visualization_dir(method):
|
291 |
if method == "No-XAI":
|
292 |
return 'htmls_NO_XAI'
|