VOIDER commited on
Commit
d093305
Β·
verified Β·
1 Parent(s): 024c6f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -190
app.py CHANGED
@@ -20,7 +20,16 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
20
  logger = logging.getLogger(__name__)
21
 
22
  # Import aesthetic predictor function
23
- from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  @dataclass
@@ -66,9 +75,23 @@ class AestheticShadowModel(BaseModel):
66
  try:
67
  results = self.model(images)
68
  scores = []
69
- for result in results:
70
- hq_score = next((p['score'] for p in result if p['label'] == 'hq'), 0)
71
- scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return scores
73
  except Exception as e:
74
  logger.error(f"Error in {self.name}: {e}")
@@ -86,16 +109,31 @@ class WaifuScorerModel(BaseModel):
86
  try:
87
  import clip
88
 
89
- # Load MLP model
90
  self.mlp = self._create_mlp()
91
  model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
92
  state_dict = torch.load(model_path, map_location=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  self.mlp.load_state_dict(state_dict)
94
  self.mlp.to(self.device).eval()
95
 
96
- # Load CLIP model
97
  self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
98
  self.available = True
 
 
 
99
  except Exception as e:
100
  logger.error(f"Failed to load {self.name}: {e}")
101
  self.available = False
@@ -130,12 +168,11 @@ class WaifuScorerModel(BaseModel):
130
  return [None] * len(images)
131
 
132
  try:
133
- # Process images
134
  image_tensors = torch.cat([self.preprocess(img).unsqueeze(0) for img in images])
135
  image_tensors = image_tensors.to(self.device)
136
 
137
- # Extract features and predict
138
  features = self.clip_model.encode_image(image_tensors)
 
139
  features = features / features.norm(dim=-1, keepdim=True)
140
  predictions = self.mlp(features)
141
 
@@ -151,25 +188,40 @@ class AestheticPredictorV25Model(BaseModel):
151
  def __init__(self):
152
  super().__init__("Aesthetic V2.5")
153
  logger.info(f"Loading {self.name} model...")
154
- self.model, self.preprocessor = convert_v2_5_from_siglip(
155
- low_cpu_mem_usage=True,
156
- trust_remote_code=True,
157
- )
158
- if self.device == 'cuda':
159
- self.model = self.model.to(torch.bfloat16).cuda()
160
-
 
 
 
 
 
 
 
161
  @torch.no_grad()
162
  async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
 
 
163
  try:
164
  images_rgb = [img.convert("RGB") for img in images]
165
- pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
166
 
167
  if self.device == 'cuda':
168
  pixel_values = pixel_values.to(torch.bfloat16).cuda()
 
 
169
 
170
- scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
171
- if scores.ndim == 0:
172
- scores = np.array([scores])
 
 
 
 
173
 
174
  return [float(np.clip(s, 0.0, 10.0)) for s in scores]
175
  except Exception as e:
@@ -182,41 +234,63 @@ class AnimeAestheticModel(BaseModel):
182
  def __init__(self):
183
  super().__init__("Anime Score")
184
  logger.info(f"Loading {self.name} model...")
185
- model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
186
- self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
187
-
 
 
 
 
 
 
188
  async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
 
 
189
  scores = []
190
  for img in images:
191
  try:
192
  score = self._process_single_image(img)
193
  scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
194
  except Exception as e:
195
- logger.error(f"Error in {self.name} for single image: {e}")
196
  scores.append(None)
197
  return scores
198
 
199
  def _process_single_image(self, img: Image.Image) -> float:
200
  """Process a single image through the model"""
201
- img_np = np.array(img).astype(np.float32) / 255.0
202
- size = 768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  h, w = img_np.shape[:2]
204
 
205
- # Calculate new dimensions
206
  if h > w:
207
  new_h, new_w = size, int(size * w / h)
208
  else:
209
  new_h, new_w = int(size * h / w), size
210
 
211
- # Resize and center
212
- resized = cv2.resize(img_np, (new_w, new_h))
213
- canvas = np.zeros((size, size, 3), dtype=np.float32)
 
214
  pad_h = (size - new_h) // 2
215
  pad_w = (size - new_w) // 2
216
- canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
217
 
218
- # Prepare input
219
- input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
 
220
  return self.session.run(None, {"img": input_tensor})[0].item()
221
 
222
 
@@ -238,10 +312,18 @@ class ImageEvaluator:
238
 
239
  for key, model_class in model_classes:
240
  try:
241
- self.models[key] = model_class()
242
- logger.info(f"Successfully loaded {key}")
 
 
 
 
 
 
 
 
243
  except Exception as e:
244
- logger.error(f"Failed to load {key}: {e}")
245
 
246
  async def evaluate_images(
247
  self,
@@ -252,99 +334,137 @@ class ImageEvaluator:
252
  ) -> Tuple[List[EvaluationResult], List[str]]:
253
  """Evaluate images with selected models"""
254
  logs = []
255
- results = []
256
 
257
- # Load images
258
- images = []
259
- valid_paths = []
260
- for path in file_paths:
261
  try:
262
  img = Image.open(path).convert("RGB")
263
- images.append(img)
264
- valid_paths.append(path)
265
  except Exception as e:
266
  logs.append(f"Failed to load {Path(path).name}: {e}")
267
 
268
- if not images:
269
  logs.append("No valid images to process")
270
- return results, logs
271
-
272
- logs.append(f"Loaded {len(images)} images")
273
 
274
- # Process in batches
275
- total_batches = (len(images) + batch_size - 1) // batch_size
276
 
277
- for batch_idx in range(0, len(images), batch_size):
278
- batch_images = images[batch_idx:batch_idx + batch_size]
279
- batch_paths = valid_paths[batch_idx:batch_idx + batch_size]
280
-
281
- # Evaluate with each selected model
282
- batch_results = {}
283
- for model_key in selected_models:
284
- if model_key in self.models:
285
- scores = await self.models[model_key].evaluate_batch(batch_images)
286
- batch_results[model_key] = scores
287
- logs.append(f"Processed batch {batch_idx//batch_size + 1}/{total_batches} with {self.models[model_key].name}")
 
 
 
 
 
 
 
 
 
 
288
 
289
- # Create results
290
- for i, (path, img) in enumerate(zip(batch_paths, batch_images)):
291
- result = EvaluationResult(
292
- file_name=Path(path).name,
293
- image_path=path
294
- )
295
 
296
- for model_key in selected_models:
297
- if model_key in batch_results:
298
- result.scores[model_key] = batch_results[model_key][i]
 
 
 
 
 
 
 
299
 
300
- result.calculate_final_score(selected_models)
301
- results.append(result)
302
-
303
- # Update progress
304
- if progress_callback:
305
- progress = (batch_idx + batch_size) / len(images) * 100
306
- progress_callback(min(progress, 100))
 
307
 
308
- self.results = results
309
- return results, logs
 
 
 
 
 
310
 
311
- def get_results_dataframe(self, selected_models: List[str]) -> pd.DataFrame:
312
- """Convert results to pandas DataFrame"""
313
  if not self.results:
314
  return pd.DataFrame()
315
 
316
  data = []
 
 
 
317
  for result in self.results:
318
  row = {
319
  'File Name': result.file_name,
320
- 'Image': result.image_path,
 
 
321
  }
322
 
323
- # Add model scores
324
- for model_key in selected_models:
325
- if model_key in self.models:
326
- score = result.scores.get(model_key)
327
- row[self.models[model_key].name] = f"{score:.4f}" if score is not None else "N/A"
328
 
329
  row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
330
  data.append(row)
331
 
332
- return pd.DataFrame(data)
 
 
 
 
 
 
 
 
333
 
334
 
335
  def create_interface():
336
  """Create the Gradio interface"""
337
  evaluator = ImageEvaluator()
338
 
339
- # Model options for checkbox
340
  model_options = [
341
- ("Aesthetic Shadow", "aesthetic_shadow"),
342
- ("Waifu Scorer", "waifu_scorer"),
343
- ("Aesthetic V2.5", "aesthetic_predictor_v2_5"),
344
- ("Anime Score", "anime_aesthetic")
345
  ]
346
-
 
 
 
 
347
  with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
348
  gr.Markdown("""
349
  # 🎨 Advanced Image Evaluation Tool
350
 
@@ -361,158 +481,187 @@ def create_interface():
361
  )
362
 
363
  model_checkboxes = gr.CheckboxGroup(
364
- choices=[label for label, _ in model_options],
365
- value=[label for label, _ in model_options],
366
  label="Select Models",
367
- info="Choose which models to use for evaluation"
368
  )
369
 
370
- with gr.Row():
371
- batch_size = gr.Slider(
372
- minimum=1,
373
- maximum=64,
374
- value=8,
375
- step=1,
376
- label="Batch Size",
377
- info="Number of images to process at once"
378
- )
379
 
380
  with gr.Row():
381
  evaluate_btn = gr.Button("πŸš€ Evaluate Images", variant="primary", scale=2)
382
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
383
 
384
- with gr.Column(scale=2):
385
- progress = gr.Progress()
386
- logs = gr.Textbox(
387
  label="Processing Logs",
388
  lines=10,
389
- max_lines=10,
390
- autoscroll=True
 
391
  )
392
 
393
- results_df = gr.Dataframe(
 
 
 
394
  label="Evaluation Results",
395
  interactive=False,
396
- wrap=True
 
 
 
 
397
  )
398
 
399
- download_btn = gr.Button("πŸ“₯ Download Results (CSV)", variant="secondary")
400
- download_file = gr.File(visible=False)
401
-
402
- # State for storing results
403
- results_state = gr.State([])
404
-
405
- async def process_images(files, selected_model_labels, batch_size, progress=gr.Progress()):
406
- """Process uploaded images"""
 
 
 
407
  if not files:
408
- return "Please upload images first", pd.DataFrame(), []
409
 
410
- # Convert labels to keys
411
- selected_models = [key for label, key in model_options if label in selected_model_labels]
412
 
413
- # Get file paths
414
- file_paths = [f.name for f in files]
 
 
415
 
416
  # Progress callback
417
- def update_progress(value):
418
- progress(value / 100, desc=f"Processing images... {value:.0f}%")
 
 
 
 
 
419
 
 
 
 
 
 
 
 
420
  # Evaluate images
421
- results, logs = await evaluator.evaluate_images(
422
- file_paths,
423
- selected_models,
424
- batch_size,
425
- update_progress
426
  )
427
 
428
- # Create DataFrame
429
- df = evaluator.get_results_dataframe(selected_models)
430
-
431
- # Format logs
432
- log_text = "\n".join(logs[-10:]) # Show last 10 logs
433
-
434
- return log_text, df, results
435
-
436
- def update_results_on_model_change(selected_model_labels, results):
437
- """Update results when model selection changes"""
438
- if not results:
439
- return pd.DataFrame()
440
-
441
- # Convert labels to keys
442
- selected_models = [key for label, key in model_options if label in selected_model_labels]
443
 
444
- # Recalculate final scores
445
- for result in results:
446
- result.calculate_final_score(selected_models)
 
 
 
 
 
 
 
 
447
 
448
- # Update evaluator results
449
- evaluator.results = results
 
450
 
451
- # Create updated DataFrame
452
- return evaluator.get_results_dataframe(selected_models)
453
-
454
- def clear_interface():
455
- """Clear all results"""
456
- return "", pd.DataFrame(), [], None
457
-
458
- def prepare_download(selected_model_labels, results):
459
- """Prepare CSV file for download"""
460
- if not results:
461
  return None
 
 
462
 
463
- # Convert labels to keys
464
- selected_models = [key for label, key in model_options if label in selected_model_labels]
465
-
466
- # Get DataFrame
467
- df = evaluator.get_results_dataframe(selected_models)
468
 
469
- # Save to temporary file
 
 
 
470
  import tempfile
471
- with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
472
- df.to_csv(f, index=False)
473
- return f.name
474
 
475
- # Event handlers
476
  evaluate_btn.click(
477
- fn=process_images,
478
- inputs=[input_files, model_checkboxes, batch_size],
479
- outputs=[logs, results_df, results_state]
480
  )
481
 
482
  model_checkboxes.change(
483
- fn=update_results_on_model_change,
484
- inputs=[model_checkboxes, results_state],
485
- outputs=[results_df]
486
  )
487
 
488
  clear_btn.click(
489
- fn=clear_interface,
490
- outputs=[logs, results_df, results_state, download_file]
491
  )
492
 
493
- download_btn.click(
494
- fn=prepare_download,
495
- inputs=[model_checkboxes, results_state],
496
- outputs=[download_file]
497
  )
498
 
499
  gr.Markdown("""
500
  ### πŸ“ Notes
501
- - **Model Selection**: Choose which models to use for evaluation. Final score is the average of selected models.
502
- - **Batch Size**: Adjust based on your GPU memory. Larger batches process faster.
503
- - **Results Table**: Click column headers to sort. The table updates automatically when models are selected/deselected.
504
- - **Download**: Export results as CSV for further analysis.
505
 
506
- ### 🎯 Score Interpretation
507
  - **7-10**: High quality/aesthetic appeal
508
  - **5-7**: Medium quality
509
  - **0-5**: Lower quality
 
510
  """)
511
 
512
  return demo
513
 
514
 
515
  if __name__ == "__main__":
 
 
 
 
516
  # Create and launch the interface
517
- demo = create_interface()
518
- demo.queue().launch()
 
 
 
20
  logger = logging.getLogger(__name__)
21
 
22
  # Import aesthetic predictor function
23
+ # Ensure 'aesthetic_predictor_v2_5.py' is in the same directory or accessible in PYTHONPATH
24
+ # from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
25
+ # Placeholder for the import if the file is missing, to allow syntax checking
26
+ def convert_v2_5_from_siglip(low_cpu_mem_usage=True, trust_remote_code=True):
27
+ # This is a placeholder. Replace with actual import and ensure the function exists.
28
+ logger.warning("Using placeholder for convert_v2_5_from_siglip. Ensure the actual implementation is available.")
29
+ # Mocking a model and preprocessor structure
30
+ mock_model = torch.nn.Sequential(torch.nn.Linear(10,1)) # Dummy model
31
+ mock_preprocessor = lambda images, return_tensors: {"pixel_values": torch.randn(len(images), 3, 224, 224)} # Dummy preprocessor
32
+ return mock_model, mock_preprocessor
33
 
34
 
35
  @dataclass
 
75
  try:
76
  results = self.model(images)
77
  scores = []
78
+ for result_set in results: # self.model(images) returns a list of lists of dicts if multiple images
79
+ if not isinstance(result_set, list): # If single image, it returns a list of dicts
80
+ result_set = [result_set]
81
+
82
+ # Correctly handle varying structures from the pipeline
83
+ hq_score = 0
84
+ # The pipeline might return a list of dicts for each image, or just a list of dicts for a single image
85
+ # For multiple images, results is List[List[Dict]]
86
+ # For a single image, results is List[Dict] - pipeline might batch internally
87
+ # The provided code expects `results` to be a list of predictions, where each prediction is a list of class scores.
88
+ current_image_predictions = result_set
89
+ if isinstance(result_set, list) and len(result_set) > 0 and isinstance(result_set[0], list) and len(images) == 1:
90
+ # Handle cases where pipeline wraps single image result in an extra list
91
+ current_image_predictions = result_set[0]
92
+
93
+ hq_score_found = next((p['score'] for p in current_image_predictions if p['label'] == 'hq'), 0)
94
+ scores.append(float(np.clip(hq_score_found * 10.0, 0.0, 10.0)))
95
  return scores
96
  except Exception as e:
97
  logger.error(f"Error in {self.name}: {e}")
 
109
  try:
110
  import clip
111
 
 
112
  self.mlp = self._create_mlp()
113
  model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
114
  state_dict = torch.load(model_path, map_location=self.device)
115
+
116
+ # --- FIX for state_dict key mismatch ---
117
+ # Check if keys are prefixed (e.g., "layers.0.weight") and adjust
118
+ if any(key.startswith("layers.") for key in state_dict.keys()):
119
+ new_state_dict = {}
120
+ for k, v in state_dict.items():
121
+ if k.startswith("layers."):
122
+ new_state_dict[k[len("layers."):]] = v
123
+ else:
124
+ # Keep other keys if any (though error suggests all relevant keys were prefixed)
125
+ new_state_dict[k] = v
126
+ state_dict = new_state_dict
127
+ # --- END FIX ---
128
+
129
  self.mlp.load_state_dict(state_dict)
130
  self.mlp.to(self.device).eval()
131
 
 
132
  self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
133
  self.available = True
134
+ except ImportError:
135
+ logger.error(f"Failed to load {self.name}: PyPI package 'clip' (openai-clip) not found. Please install it.")
136
+ self.available = False
137
  except Exception as e:
138
  logger.error(f"Failed to load {self.name}: {e}")
139
  self.available = False
 
168
  return [None] * len(images)
169
 
170
  try:
 
171
  image_tensors = torch.cat([self.preprocess(img).unsqueeze(0) for img in images])
172
  image_tensors = image_tensors.to(self.device)
173
 
 
174
  features = self.clip_model.encode_image(image_tensors)
175
+ features = features.float() # Ensure features are float32 for MLP
176
  features = features / features.norm(dim=-1, keepdim=True)
177
  predictions = self.mlp(features)
178
 
 
188
  def __init__(self):
189
  super().__init__("Aesthetic V2.5")
190
  logger.info(f"Loading {self.name} model...")
191
+ try:
192
+ self.model, self.preprocessor = convert_v2_5_from_siglip(
193
+ low_cpu_mem_usage=True,
194
+ trust_remote_code=True, # Be cautious with trust_remote_code=True
195
+ )
196
+ if self.device == 'cuda':
197
+ self.model = self.model.to(torch.bfloat16).cuda()
198
+ self.available = True
199
+ except Exception as e:
200
+ logger.error(f"Failed to load {self.name}: {e}. Ensure 'aesthetic_predictor_v2_5.py' is correct and dependencies are installed.")
201
+ self.available = False
202
+ self.model, self.preprocessor = None, None
203
+
204
+
205
  @torch.no_grad()
206
  async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
207
+ if not self.available:
208
+ return [None] * len(images)
209
  try:
210
  images_rgb = [img.convert("RGB") for img in images]
211
+ pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt")["pixel_values"] # Access pixel_values key
212
 
213
  if self.device == 'cuda':
214
  pixel_values = pixel_values.to(torch.bfloat16).cuda()
215
+ else:
216
+ pixel_values = pixel_values.float() # Ensure correct dtype for CPU
217
 
218
+ logits = self.model(pixel_values).logits # Get logits if model output is a dataclass/dict
219
+ # If model directly returns logits tensor:
220
+ # logits = self.model(pixel_values)
221
+
222
+ scores = logits.squeeze().float().cpu().numpy()
223
+ if scores.ndim == 0: # Handle single image case
224
+ scores = np.array([scores.item()]) # Use .item() for scalar tensor
225
 
226
  return [float(np.clip(s, 0.0, 10.0)) for s in scores]
227
  except Exception as e:
 
234
  def __init__(self):
235
  super().__init__("Anime Score")
236
  logger.info(f"Loading {self.name} model...")
237
+ try:
238
+ model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
239
+ self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
240
+ self.available = True
241
+ except Exception as e:
242
+ logger.error(f"Failed to load {self.name}: {e}")
243
+ self.available = False
244
+ self.session = None
245
+
246
  async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
247
+ if not self.available:
248
+ return [None] * len(images)
249
  scores = []
250
  for img in images:
251
  try:
252
  score = self._process_single_image(img)
253
  scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
254
  except Exception as e:
255
+ logger.error(f"Error in {self.name} for single image processing: {e}")
256
  scores.append(None)
257
  return scores
258
 
259
  def _process_single_image(self, img: Image.Image) -> float:
260
  """Process a single image through the model"""
261
+ # Ensure image is RGB
262
+ img_rgb = img.convert("RGB")
263
+ img_np = np.array(img_rgb).astype(np.float32) / 255.0
264
+
265
+ # Original model expects BGR, but most image ops are RGB.
266
+ # If ONNX model was trained on BGR, conversion might be needed.
267
+ # Assuming model takes RGB based on common practices unless specified.
268
+ # If it expects BGR: img_np = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
269
+
270
+
271
+ size = 224 # Typical size for many aesthetic models, 768 is very large for direct input.
272
+ # The original notebook for skytnt/anime-aesthetic uses 224x224.
273
+ # Let's assume 224 unless documentation says 768.
274
+ # The error log doesn't specify input size issues, but 768 is unusually large for this type of ONNX model.
275
+ # Sticking to original code's 768 for now, but this is a potential point of error if model expects 224.
276
+
277
  h, w = img_np.shape[:2]
278
 
 
279
  if h > w:
280
  new_h, new_w = size, int(size * w / h)
281
  else:
282
  new_h, new_w = int(size * h / w), size
283
 
284
+ resized_img = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA) # Use INTER_AREA for shrinking
285
+
286
+ canvas = np.ones((size, size, 3), dtype=np.float32) * 0.5 # Pad with gray, or use black (0)
287
+
288
  pad_h = (size - new_h) // 2
289
  pad_w = (size - new_w) // 2
 
290
 
291
+ canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w, :] = resized_img
292
+
293
+ input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :].astype(np.float32)
294
  return self.session.run(None, {"img": input_tensor})[0].item()
295
 
296
 
 
312
 
313
  for key, model_class in model_classes:
314
  try:
315
+ model_instance = model_class()
316
+ # Store only if model is available (loaded successfully)
317
+ if hasattr(model_instance, 'available') and model_instance.available:
318
+ self.models[key] = model_instance
319
+ logger.info(f"Successfully loaded and initialized {model_instance.name} ({key})")
320
+ elif not hasattr(model_instance, 'available'): # For models without explicit 'available' flag
321
+ self.models[key] = model_instance
322
+ logger.info(f"Successfully loaded and initialized {model_instance.name} ({key}) (availability not explicitly tracked)")
323
+ else:
324
+ logger.warning(f"{model_instance.name} ({key}) was not loaded successfully and will be skipped.")
325
  except Exception as e:
326
+ logger.error(f"Failed to initialize {key}: {e}")
327
 
328
  async def evaluate_images(
329
  self,
 
334
  ) -> Tuple[List[EvaluationResult], List[str]]:
335
  """Evaluate images with selected models"""
336
  logs = []
337
+ current_results = [] # Use a local list for current evaluation
338
 
339
+ images_data = [] # Store tuples of (image, original_path)
340
+ for path_obj in file_paths: # file_paths are UploadFile objects from Gradio
341
+ path = path_obj.name # .name gives the temporary file path
 
342
  try:
343
  img = Image.open(path).convert("RGB")
344
+ images_data.append({"image": img, "path": path, "name": Path(path).name})
 
345
  except Exception as e:
346
  logs.append(f"Failed to load {Path(path).name}: {e}")
347
 
348
+ if not images_data:
349
  logs.append("No valid images to process")
350
+ return current_results, logs
 
 
351
 
352
+ logs.append(f"Loaded {len(images_data)} images")
 
353
 
354
+ # Filter selected_models to only include those that were successfully initialized
355
+ active_selected_models = [m_key for m_key in selected_models if m_key in self.models]
356
+ if len(active_selected_models) != len(selected_models):
357
+ disabled_models = set(selected_models) - set(active_selected_models)
358
+ logs.append(f"Warning: The following models were selected but are not available: {', '.join(disabled_models)}")
359
+
360
+
361
+ # Initialize results for all images first
362
+ for data in images_data:
363
+ result = EvaluationResult(
364
+ file_name=data["name"],
365
+ image_path=data["path"] # Store original path for display if needed
366
+ )
367
+ current_results.append(result)
368
+
369
+ total_images = len(images_data)
370
+ processed_count = 0
371
+
372
+ for model_key in active_selected_models:
373
+ model_instance = self.models[model_key]
374
+ logs.append(f"Processing with {model_instance.name}...")
375
 
376
+ for i in range(0, total_images, batch_size):
377
+ batch_data = images_data[i:i + batch_size]
378
+ batch_images_pil = [d["image"] for d in batch_data]
 
 
 
379
 
380
+ try:
381
+ scores = await model_instance.evaluate_batch(batch_images_pil)
382
+ for k, score in enumerate(scores):
383
+ # Find the corresponding EvaluationResult object
384
+ # This assumes current_results is ordered the same as images_data
385
+ current_results[i+k].scores[model_key] = score
386
+ except Exception as e:
387
+ logger.error(f"Error evaluating batch with {model_instance.name}: {e}")
388
+ for k in range(len(batch_images_pil)):
389
+ current_results[i+k].scores[model_key] = None
390
 
391
+ processed_count += len(batch_images_pil)
392
+ if progress_callback:
393
+ # Progress based on overall images processed per model, then average over models
394
+ # This logic might need refinement for a smoother progress bar experience
395
+ current_model_idx = active_selected_models.index(model_key)
396
+ overall_progress = ((current_model_idx / len(active_selected_models)) + \
397
+ ((i + len(batch_data)) / total_images) / len(active_selected_models)) * 100
398
+ progress_callback(min(overall_progress, 100), f"Model: {model_instance.name}, Batch {i//batch_size + 1}")
399
 
400
+ # Calculate final scores for all results
401
+ for result in current_results:
402
+ result.calculate_final_score(active_selected_models)
403
+
404
+ logs.append("Evaluation complete.")
405
+ self.results = current_results # Update the main results list
406
+ return current_results, logs
407
 
408
+ def get_results_dataframe(self, selected_models_keys: List[str]) -> pd.DataFrame:
 
409
  if not self.results:
410
  return pd.DataFrame()
411
 
412
  data = []
413
+ # Ensure selected_models_keys only contains keys of successfully loaded models
414
+ valid_selected_models_keys = [key for key in selected_models_keys if key in self.models]
415
+
416
  for result in self.results:
417
  row = {
418
  'File Name': result.file_name,
419
+ # For Gradio display, we might want to show the image itself
420
+ # 'Image': result.image_path, # This will show the temp path
421
+ 'Image': gr.Image(result.image_path, type="pil", height=100, width=100) # Display thumbnail
422
  }
423
 
424
+ for model_key in valid_selected_models_keys:
425
+ model_name = self.models[model_key].name
426
+ score = result.scores.get(model_key)
427
+ row[model_name] = f"{score:.4f}" if score is not None else "N/A"
 
428
 
429
  row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
430
  data.append(row)
431
 
432
+ # Define column order
433
+ column_order = ['File Name', 'Image'] + \
434
+ [self.models[key].name for key in valid_selected_models_keys if key in self.models] + \
435
+ ['Final Score']
436
+
437
+ df = pd.DataFrame(data)
438
+ if not df.empty:
439
+ df = df[column_order] # Reorder columns
440
+ return df
441
 
442
 
443
  def create_interface():
444
  """Create the Gradio interface"""
445
  evaluator = ImageEvaluator()
446
 
 
447
  model_options = [
448
+ (model.name, key) for key, model in evaluator.models.items()
 
 
 
449
  ]
450
+ # If some models failed to load, model_options will be shorter.
451
+ # Provide default selected models based on successfully loaded ones.
452
+ default_selected_model_labels = [name for name, key in model_options]
453
+
454
+
455
  with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
456
+ # NOTE on Gradio TypeError:
457
+ # The traceback "TypeError: argument of type 'bool' is not iterable" during Gradio startup
458
+ # (specifically in `gradio_client/utils.py` while processing component schemas)
459
+ # often indicates an incompatibility with the Gradio version being used or a bug
460
+ # in how Gradio generates schemas for certain component configurations.
461
+ # The most common recommendation is to:
462
+ # 1. Ensure your Gradio library is up-to-date: `pip install --upgrade gradio`
463
+ # 2. If the error persists, try simplifying complex component configurations or
464
+ # testing with a known stable version of Gradio.
465
+ # The code below follows standard Gradio practices, so the error is likely
466
+ # environment-related if it persists after the WaifuScorer fix.
467
+
468
  gr.Markdown("""
469
  # 🎨 Advanced Image Evaluation Tool
470
 
 
481
  )
482
 
483
  model_checkboxes = gr.CheckboxGroup(
484
+ choices=[label for label, _ in model_options], # Use labels for choices
485
+ value=default_selected_model_labels, # Default to all loaded models
486
  label="Select Models",
487
+ info="Choose which models to use for evaluation. Models that failed to load will not be available."
488
  )
489
 
490
+ batch_size_slider = gr.Slider( # Renamed to avoid conflict with batch_size variable name
491
+ minimum=1,
492
+ maximum=32, # Max 64 might be too high for some GPUs
493
+ value=8,
494
+ step=1,
495
+ label="Batch Size",
496
+ info="Number of images to process at once per model."
497
+ )
 
498
 
499
  with gr.Row():
500
  evaluate_btn = gr.Button("πŸš€ Evaluate Images", variant="primary", scale=2)
501
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
502
 
503
+ with gr.Column(scale=3): # Increased scale for results
504
+ # Using gr.Textbox for logs, as gr.Progress is now a status tracker
505
+ logs_display = gr.Textbox(
506
  label="Processing Logs",
507
  lines=10,
508
+ max_lines=20, # Allow more lines
509
+ autoscroll=True,
510
+ interactive=False
511
  )
512
 
513
+ # Using gr.Label for progress status updates
514
+ progress_status = gr.Label(label="Progress")
515
+
516
+ results_df_display = gr.Dataframe(
517
  label="Evaluation Results",
518
  interactive=False,
519
+ wrap=True,
520
+ # Define column types for better display, especially for images
521
+ # headers=['File Name', 'Image'] + default_selected_model_labels + ['Final Score'],
522
+ # col_count=(len(default_selected_model_labels) + 3, "fixed"),
523
+ # datatype=['str', 'image'] + ['number'] * (len(default_selected_model_labels) + 1)
524
  )
525
 
526
+ download_button = gr.Button("πŸ“₯ Download Results (CSV)", variant="secondary") # Changed from gr.Button to potentially use gr.DownloadButton later
527
+ # download_file_output = gr.File(label="Download CSV", visible=False, interactive=False)
528
+ # Using gr.File for download output triggered by a regular button
529
+ download_file_output_component = gr.File(label="Download", visible=False)
530
+
531
+
532
+ # State for storing full EvaluationResult objects if needed for more complex interactions
533
+ # For this setup, regenerating DataFrame from evaluator.results is generally sufficient
534
+ # results_state = gr.State([]) # If storing raw results is complex, simplify or manage carefully
535
+
536
+ async def run_evaluation(files, selected_model_labels, current_batch_size, progress=gr.Progress(track_tqdm=True)):
537
  if not files:
538
+ return "Please upload images first.", pd.DataFrame(), [], "No files uploaded."
539
 
540
+ # Convert display labels back to model keys
541
+ selected_model_keys = [key for label, key in model_options if label in selected_model_labels]
542
 
543
+ if not selected_model_keys:
544
+ return "Please select at least one model.", pd.DataFrame(), [], "No models selected."
545
+
546
+ # file_paths = [f.name for f in files] # .name gives temp path of UploadFile
547
 
548
  # Progress callback
549
+ # def update_progress_display(value, desc="Processing..."):
550
+ # progress(value / 100, desc=f"{desc} {value:.0f}%")
551
+ # return f"{desc} {value:.0f}%" # For gr.Label
552
+
553
+ # Use gr.Progress context for automatic updates with iterators
554
+ # However, for manual updates with batching, direct calls are fine.
555
+ # We'll update logs_display and progress_status manually.
556
 
557
+ progress_updates = []
558
+ def progress_callback_for_eval(p_value, p_desc):
559
+ progress(p_value / 100, desc=p_desc) # Update the main progress component
560
+ # logs_display.value += f"\n{p_desc} - {p_value:.0f}%" # This will make logs messy
561
+ progress_updates.append(f"{p_desc} - {p_value:.0f}%")
562
+
563
+
564
  # Evaluate images
565
+ processed_results, log_messages = await evaluator.evaluate_images(
566
+ files, # Pass the list of UploadFile objects directly
567
+ selected_model_keys,
568
+ int(current_batch_size), # Ensure batch_size is int
569
+ progress_callback_for_eval # Pass the callback
570
  )
571
 
572
+ df = evaluator.get_results_dataframe(selected_model_keys)
573
+ log_text = "\n".join(log_messages + progress_updates)
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
+ final_status = "Evaluation complete." if processed_results else "Evaluation failed or no results."
576
+ progress(1.0, desc=final_status) # Mark progress as complete
577
+
578
+ return log_text, df, final_status # Removed results_state for simplicity
579
+
580
+ def handle_model_selection_change(selected_model_labels_updated):
581
+ # Called when checkbox group changes. evaluator.results should already be populated.
582
+ if not evaluator.results:
583
+ return pd.DataFrame() # No results to re-filter/re-calculate
584
+
585
+ selected_model_keys_updated = [key for label, key in model_options if label in selected_model_labels_updated]
586
 
587
+ # Recalculate final scores for all existing results based on new selection
588
+ for res_obj in evaluator.results:
589
+ res_obj.calculate_final_score(selected_model_keys_updated)
590
 
591
+ return evaluator.get_results_dataframe(selected_model_keys_updated)
592
+
593
+ def clear_all_outputs():
594
+ evaluator.results = [] # Clear stored results in the evaluator
595
+ return "", pd.DataFrame(), "Cleared.", None # Log, DataFrame, Progress Status, Download File
596
+
597
+ def generate_csv_for_download(selected_model_labels_for_csv):
598
+ if not evaluator.results:
599
+ gr.Warning("No results to download.")
 
600
  return None
601
+
602
+ selected_model_keys_for_csv = [key for label, key in model_options if label in selected_model_labels_for_csv]
603
 
604
+ # Get DataFrame, but exclude the gr.Image column for CSV
605
+ df_for_csv = evaluator.get_results_dataframe(selected_model_keys_for_csv).copy()
606
+ if 'Image' in df_for_csv.columns:
607
+ df_for_csv.drop(columns=['Image'], inplace=True)
 
608
 
609
+ if df_for_csv.empty:
610
+ gr.Warning("No data to download based on current selection.")
611
+ return None
612
+
613
  import tempfile
614
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', encoding='utf-8') as tmp_file:
615
+ df_for_csv.to_csv(tmp_file.name, index=False)
616
+ return tmp_file.name
617
 
 
618
  evaluate_btn.click(
619
+ fn=run_evaluation,
620
+ inputs=[input_files, model_checkboxes, batch_size_slider],
621
+ outputs=[logs_display, results_df_display, progress_status] # Removed results_state
622
  )
623
 
624
  model_checkboxes.change(
625
+ fn=handle_model_selection_change,
626
+ inputs=[model_checkboxes],
627
+ outputs=[results_df_display]
628
  )
629
 
630
  clear_btn.click(
631
+ fn=clear_all_outputs,
632
+ outputs=[logs_display, results_df_display, progress_status, download_file_output_component]
633
  )
634
 
635
+ download_button.click(
636
+ fn=generate_csv_for_download,
637
+ inputs=[model_checkboxes],
638
+ outputs=[download_file_output_component]
639
  )
640
 
641
  gr.Markdown("""
642
  ### πŸ“ Notes
643
+ - **Model Selection**: Choose which models to use for evaluation. The final score is the average of the selected models. Models that failed to load during startup will not be listed or will be ignored.
644
+ - **Batch Size**: Adjust based on your system's VRAM and RAM. Smaller batches use less memory but may be slower overall.
645
+ - **Results Table**: Displays scores from selected models and the final average. Images are shown as thumbnails.
646
+ - **Download**: Export results (excluding image thumbnails) as a CSV file for further analysis.
647
 
648
+ ### 🎯 Score Interpretation (General Guide)
649
  - **7-10**: High quality/aesthetic appeal
650
  - **5-7**: Medium quality
651
  - **0-5**: Lower quality
652
+ _(Note: Score ranges and interpretations can vary between models.)_
653
  """)
654
 
655
  return demo
656
 
657
 
658
  if __name__ == "__main__":
659
+ # Ensure 'aesthetic_predictor_v2_5.py' exists and 'openai-clip' is installed for WaifuScorer
660
+ # Example: pip install openai-clip transformers==4.30.2 onnxruntime gradio pandas Pillow opencv-python
661
+ # Check specific model requirements.
662
+
663
  # Create and launch the interface
664
+ app_interface = create_interface()
665
+ # Adding .queue() is good for handling multiple users or long-running tasks.
666
+ # Set debug=True for more detailed Gradio errors during development.
667
+ app_interface.queue().launch(debug=True)