luulinh90s committed on
Commit
18d08ed
·
1 Parent(s): 833d467
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -126,19 +126,10 @@ def load_session_data(username):
126
 
127
  def load_samples(methods):
128
  logger.info(f"Loading samples for methods: {methods}")
129
- samples = []
130
  categories = ["TP", "TN", "FP", "FN"]
131
 
132
- method_dirs = []
133
- for method in methods:
134
- if method == 'No-XAI':
135
- method_dirs.append('NO_XAI')
136
- elif method == 'Dater':
137
- method_dirs.append('DATER')
138
- elif method == 'Chain-of-Table':
139
- method_dirs.append('COT')
140
- elif method == 'Plan-of-SQLs':
141
- method_dirs.append('POS')
142
 
143
  for category in categories:
144
  dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
@@ -150,18 +141,29 @@ def load_samples(methods):
150
  matching_files = files_a & files_b
151
 
152
  for file in matching_files:
153
- samples.append({
154
- 'category': category,
155
- 'file': file
156
- })
157
 
158
- return samples
 
159
 
 
 
160
 
161
  def select_balanced_samples(samples):
162
  try:
163
- selected_samples = random.sample(samples, min(10, len(samples)))
164
- logger.info(f"Selected balanced samples: {len(selected_samples)}")
 
 
 
 
 
 
 
 
 
 
 
165
  return selected_samples
166
  except Exception as e:
167
  logger.exception("Error selecting balanced samples")
@@ -285,7 +287,6 @@ def get_method_dir(method):
285
  elif method == 'Plan-of-SQLs':
286
  return 'POS'
287
 
288
-
289
  def get_visualization_dir(method):
290
  if method == "No-XAI":
291
  return 'htmls_NO_XAI'
 
126
 
127
  def load_samples(methods):
128
  logger.info(f"Loading samples for methods: {methods}")
129
+ samples = set() # Use a set to avoid duplicates
130
  categories = ["TP", "TN", "FP", "FN"]
131
 
132
+ method_dirs = [get_method_dir(method) for method in methods]
 
 
 
 
 
 
 
 
 
133
 
134
  for category in categories:
135
  dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
 
141
  matching_files = files_a & files_b
142
 
143
  for file in matching_files:
144
+ samples.add((category, file))
 
 
 
145
 
146
+ # Convert set of tuples back to list of dictionaries
147
+ samples = [{'category': category, 'file': file} for category, file in samples]
148
 
149
+ logger.info(f"Loaded {len(samples)} unique samples across all categories")
150
+ return samples
151
 
152
  def select_balanced_samples(samples):
153
  try:
154
+ # Ensure we have at least 10 unique samples
155
+ unique_samples = list({(s['category'], s['file']) for s in samples})
156
+
157
+ if len(unique_samples) < 10:
158
+ logger.warning(f"Not enough unique samples. Only {len(unique_samples)} available.")
159
+ selected_samples = unique_samples
160
+ else:
161
+ selected_samples = random.sample(unique_samples, 10)
162
+
163
+ # Convert back to dictionary format
164
+ selected_samples = [{'category': category, 'file': file} for category, file in selected_samples]
165
+
166
+ logger.info(f"Selected {len(selected_samples)} unique samples")
167
  return selected_samples
168
  except Exception as e:
169
  logger.exception("Error selecting balanced samples")
 
287
  elif method == 'Plan-of-SQLs':
288
  return 'POS'
289
 
 
290
  def get_visualization_dir(method):
291
  if method == "No-XAI":
292
  return 'htmls_NO_XAI'