VOIDER committed on
Commit fab9033 · verified · 1 Parent(s): d093305

Update app.py

Files changed (1)
  1. app.py +344 -604
app.py CHANGED
@@ -1,667 +1,407 @@
  import os
- import asyncio
- from typing import List, Dict, Optional, Tuple, Any
- from dataclasses import dataclass, field
  from pathlib import Path
- import logging

  import cv2
  import numpy as np
  import torch
  import onnxruntime as rt
  from PIL import Image
- import gradio as gr
- from transformers import pipeline
  from huggingface_hub import hf_hub_download
- import pandas as pd
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- # Import aesthetic predictor function
- # Ensure 'aesthetic_predictor_v2_5.py' is in the same directory or accessible in PYTHONPATH
- # from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
- # Placeholder for the import if the file is missing, to allow syntax checking
- def convert_v2_5_from_siglip(low_cpu_mem_usage=True, trust_remote_code=True):
-     # This is a placeholder. Replace with actual import and ensure the function exists.
-     logger.warning("Using placeholder for convert_v2_5_from_siglip. Ensure the actual implementation is available.")
-     # Mocking a model and preprocessor structure
-     mock_model = torch.nn.Sequential(torch.nn.Linear(10, 1))  # Dummy model
-     mock_preprocessor = lambda images, return_tensors: {"pixel_values": torch.randn(len(images), 3, 224, 224)}  # Dummy preprocessor
-     return mock_model, mock_preprocessor
-
-
- @dataclass
- class EvaluationResult:
-     """Data class for storing image evaluation results"""
-     file_name: str
-     image_path: str
-     scores: Dict[str, Optional[float]] = field(default_factory=dict)
-     final_score: Optional[float] = None
-
-     def calculate_final_score(self, selected_models: List[str]) -> None:
-         """Calculate the average score from selected models"""
-         valid_scores = [
-             score for model, score in self.scores.items()
-             if model in selected_models and score is not None
-         ]
-         self.final_score = np.mean(valid_scores) if valid_scores else None
-
-
- class BaseModel:
-     """Base class for all evaluation models"""
-     def __init__(self, name: str):
-         self.name = name
-         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         """Evaluate a batch of images"""
-         raise NotImplementedError
-
-
- class AestheticShadowModel(BaseModel):
-     """Aesthetic Shadow V2 model implementation"""
-     def __init__(self):
-         super().__init__("Aesthetic Shadow")
-         logger.info(f"Loading {self.name} model...")
-         self.model = pipeline(
-             "image-classification",
-             model="NeoChen1024/aesthetic-shadow-v2-backup",
-             device=0 if self.device == 'cuda' else -1
-         )
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         try:
-             results = self.model(images)
-             scores = []
-             for result_set in results:  # self.model(images) returns a list of lists of dicts if multiple images
-                 if not isinstance(result_set, list):  # If single image, it returns a list of dicts
-                     result_set = [result_set]
-
-                 # Correctly handle varying structures from the pipeline
-                 hq_score = 0
-                 # The pipeline might return a list of dicts for each image, or just a list of dicts for a single image.
-                 # For multiple images, results is List[List[Dict]];
-                 # for a single image, results is List[Dict] - the pipeline might batch internally.
-                 # This code expects `results` to be a list of predictions, where each prediction is a list of class scores.
-                 current_image_predictions = result_set
-                 if isinstance(result_set, list) and len(result_set) > 0 and isinstance(result_set[0], list) and len(images) == 1:
-                     # Handle cases where the pipeline wraps a single-image result in an extra list
-                     current_image_predictions = result_set[0]
-
-                 hq_score_found = next((p['score'] for p in current_image_predictions if p['label'] == 'hq'), 0)
-                 scores.append(float(np.clip(hq_score_found * 10.0, 0.0, 10.0)))
-             return scores
-         except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
-             return [None] * len(images)
-
-
- class WaifuScorerModel(BaseModel):
-     """Waifu Scorer V3 model implementation"""
-     def __init__(self):
-         super().__init__("Waifu Scorer")
-         logger.info(f"Loading {self.name} model...")
-         self._load_model()
-
-     def _load_model(self):
-         try:
-             import clip
-
-             self.mlp = self._create_mlp()
-             model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
-             state_dict = torch.load(model_path, map_location=self.device)
-
-             # --- FIX for state_dict key mismatch ---
-             # Check if keys are prefixed (e.g., "layers.0.weight") and adjust
-             if any(key.startswith("layers.") for key in state_dict.keys()):
-                 new_state_dict = {}
-                 for k, v in state_dict.items():
-                     if k.startswith("layers."):
-                         new_state_dict[k[len("layers."):]] = v
-                     else:
-                         # Keep other keys if any (though the error suggests all relevant keys were prefixed)
-                         new_state_dict[k] = v
-                 state_dict = new_state_dict
-             # --- END FIX ---
-
-             self.mlp.load_state_dict(state_dict)
-             self.mlp.to(self.device).eval()
-
-             self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
-             self.available = True
-         except ImportError:
-             logger.error(f"Failed to load {self.name}: PyPI package 'clip' (openai-clip) not found. Please install it.")
-             self.available = False
-         except Exception as e:
-             logger.error(f"Failed to load {self.name}: {e}")
-             self.available = False
-
-     def _create_mlp(self) -> torch.nn.Module:
-         """Create the MLP architecture"""
-         return torch.nn.Sequential(
-             torch.nn.Linear(768, 2048),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(2048),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(2048, 512),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(512),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(512, 256),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(256),
-             torch.nn.Dropout(0.2),
-             torch.nn.Linear(256, 128),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(128),
-             torch.nn.Dropout(0.1),
-             torch.nn.Linear(128, 32),
-             torch.nn.ReLU(),
-             torch.nn.Linear(32, 1)
          )
-
-     @torch.no_grad()
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         if not self.available:
-             return [None] * len(images)
-
-         try:
-             image_tensors = torch.cat([self.preprocess(img).unsqueeze(0) for img in images])
-             image_tensors = image_tensors.to(self.device)
-
-             features = self.clip_model.encode_image(image_tensors)
-             features = features.float()  # Ensure features are float32 for MLP
-             features = features / features.norm(dim=-1, keepdim=True)
-             predictions = self.mlp(features)
-
-             scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
-             return scores
-         except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
-             return [None] * len(images)
-
-
- class AestheticPredictorV25Model(BaseModel):
-     """Aesthetic Predictor V2.5 model implementation"""
-     def __init__(self):
-         super().__init__("Aesthetic V2.5")
-         logger.info(f"Loading {self.name} model...")
-         try:
-             self.model, self.preprocessor = convert_v2_5_from_siglip(
-                 low_cpu_mem_usage=True,
-                 trust_remote_code=True,  # Be cautious with trust_remote_code=True
-             )
-             if self.device == 'cuda':
-                 self.model = self.model.to(torch.bfloat16).cuda()
-             self.available = True
-         except Exception as e:
-             logger.error(f"Failed to load {self.name}: {e}. Ensure 'aesthetic_predictor_v2_5.py' is correct and dependencies are installed.")
-             self.available = False
-             self.model, self.preprocessor = None, None
-

      @torch.no_grad()
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         if not self.available:
-             return [None] * len(images)
-         try:
-             images_rgb = [img.convert("RGB") for img in images]
-             pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt")["pixel_values"]  # Access pixel_values key
-
-             if self.device == 'cuda':
-                 pixel_values = pixel_values.to(torch.bfloat16).cuda()
-             else:
-                 pixel_values = pixel_values.float()  # Ensure correct dtype for CPU
-
-             logits = self.model(pixel_values).logits  # Get logits if model output is a dataclass/dict
-             # If the model directly returns a logits tensor:
-             # logits = self.model(pixel_values)
-
-             scores = logits.squeeze().float().cpu().numpy()
-             if scores.ndim == 0:  # Handle single image case
-                 scores = np.array([scores.item()])  # Use .item() for scalar tensor
-
-             return [float(np.clip(s, 0.0, 10.0)) for s in scores]
-         except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
-             return [None] * len(images)
-
-
- class AnimeAestheticModel(BaseModel):
-     """Anime Aesthetic model implementation"""
-     def __init__(self):
-         super().__init__("Anime Score")
-         logger.info(f"Loading {self.name} model...")
-         try:
-             model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
-             self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-             self.available = True
-         except Exception as e:
-             logger.error(f"Failed to load {self.name}: {e}")
-             self.available = False
-             self.session = None
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         if not self.available:
-             return [None] * len(images)
          scores = []
-         for img in images:
              try:
-                 score = self._process_single_image(img)
-                 scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
-             except Exception as e:
-                 logger.error(f"Error in {self.name} for single image processing: {e}")
-                 scores.append(None)
          return scores
-
-     def _process_single_image(self, img: Image.Image) -> float:
-         """Process a single image through the model"""
-         # Ensure image is RGB
-         img_rgb = img.convert("RGB")
-         img_np = np.array(img_rgb).astype(np.float32) / 255.0
-
-         # Original model expects BGR, but most image ops are RGB.
-         # If the ONNX model was trained on BGR, conversion might be needed.
-         # Assuming the model takes RGB based on common practice unless specified.
-         # If it expects BGR: img_np = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0

-         size = 224  # Typical size for many aesthetic models; 768 is very large for direct input.
-         # The original notebook for skytnt/anime-aesthetic uses 224x224.
-         # Let's assume 224 unless documentation says 768.
-         # The error log doesn't specify input size issues, but 768 is unusually large for this type of ONNX model.
-         # Sticking to the original code's 768 for now, but this is a potential point of error if the model expects 224.

          h, w = img_np.shape[:2]
-
          if h > w:
-             new_h, new_w = size, int(size * w / h)
          else:
-             new_h, new_w = int(size * h / w), size

-         resized_img = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA)  # Use INTER_AREA for shrinking

-         canvas = np.ones((size, size, 3), dtype=np.float32) * 0.5  # Pad with gray, or use black (0)
-
-         pad_h = (size - new_h) // 2
-         pad_w = (size - new_w) // 2

-         canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w, :] = resized_img

-         input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :].astype(np.float32)
-         return self.session.run(None, {"img": input_tensor})[0].item()


- class ImageEvaluator:
-     """Main class for managing image evaluation"""
-     def __init__(self):
-         self.models: Dict[str, BaseModel] = {}
-         self._initialize_models()
-         self.results: List[EvaluationResult] = []

-     def _initialize_models(self):
-         """Initialize all evaluation models"""
-         model_classes = [
-             ("aesthetic_shadow", AestheticShadowModel),
-             ("waifu_scorer", WaifuScorerModel),
-             ("aesthetic_predictor_v2_5", AestheticPredictorV25Model),
-             ("anime_aesthetic", AnimeAestheticModel),
-         ]

-         for key, model_class in model_classes:
-             try:
-                 model_instance = model_class()
-                 # Store only if model is available (loaded successfully)
-                 if hasattr(model_instance, 'available') and model_instance.available:
-                     self.models[key] = model_instance
-                     logger.info(f"Successfully loaded and initialized {model_instance.name} ({key})")
-                 elif not hasattr(model_instance, 'available'):  # For models without an explicit 'available' flag
-                     self.models[key] = model_instance
-                     logger.info(f"Successfully loaded and initialized {model_instance.name} ({key}) (availability not explicitly tracked)")
-                 else:
-                     logger.warning(f"{model_instance.name} ({key}) was not loaded successfully and will be skipped.")
-             except Exception as e:
-                 logger.error(f"Failed to initialize {key}: {e}")
-
-     async def evaluate_images(
-         self,
-         file_paths: List[str],
-         selected_models: List[str],
-         batch_size: int = 8,
-         progress_callback = None
-     ) -> Tuple[List[EvaluationResult], List[str]]:
-         """Evaluate images with selected models"""
-         logs = []
-         current_results = []  # Use a local list for the current evaluation

-         images_data = []  # Store dicts of (image, original path, file name)
-         for path_obj in file_paths:  # file_paths are UploadFile objects from Gradio
-             path = path_obj.name  # .name gives the temporary file path
-             try:
-                 img = Image.open(path).convert("RGB")
-                 images_data.append({"image": img, "path": path, "name": Path(path).name})
-             except Exception as e:
-                 logs.append(f"Failed to load {Path(path).name}: {e}")

-         if not images_data:
-             logs.append("No valid images to process")
-             return current_results, logs

-         logs.append(f"Loaded {len(images_data)} images")

-         # Filter selected_models to only include those that were successfully initialized
-         active_selected_models = [m_key for m_key in selected_models if m_key in self.models]
-         if len(active_selected_models) != len(selected_models):
-             disabled_models = set(selected_models) - set(active_selected_models)
-             logs.append(f"Warning: The following models were selected but are not available: {', '.join(disabled_models)}")
-
-
-         # Initialize results for all images first
-         for data in images_data:
-             result = EvaluationResult(
-                 file_name=data["name"],
-                 image_path=data["path"]  # Store the original path for display if needed
-             )
-             current_results.append(result)
-
-         total_images = len(images_data)
-         processed_count = 0
-
-         for model_key in active_selected_models:
-             model_instance = self.models[model_key]
-             logs.append(f"Processing with {model_instance.name}...")

-             for i in range(0, total_images, batch_size):
-                 batch_data = images_data[i:i + batch_size]
-                 batch_images_pil = [d["image"] for d in batch_data]

-                 try:
-                     scores = await model_instance.evaluate_batch(batch_images_pil)
-                     for k, score in enumerate(scores):
-                         # Find the corresponding EvaluationResult object.
-                         # This assumes current_results is ordered the same as images_data.
-                         current_results[i+k].scores[model_key] = score
-                 except Exception as e:
-                     logger.error(f"Error evaluating batch with {model_instance.name}: {e}")
-                     for k in range(len(batch_images_pil)):
-                         current_results[i+k].scores[model_key] = None

-                 processed_count += len(batch_images_pil)
-                 if progress_callback:
-                     # Progress based on overall images processed per model, then averaged over models.
-                     # This logic might need refinement for a smoother progress bar experience.
-                     current_model_idx = active_selected_models.index(model_key)
-                     overall_progress = ((current_model_idx / len(active_selected_models)) + \
-                                         ((i + len(batch_data)) / total_images) / len(active_selected_models)) * 100
-                     progress_callback(min(overall_progress, 100), f"Model: {model_instance.name}, Batch {i//batch_size + 1}")
-
-         # Calculate final scores for all results
-         for result in current_results:
-             result.calculate_final_score(active_selected_models)

-         logs.append("Evaluation complete.")
-         self.results = current_results  # Update the main results list
-         return current_results, logs
-
-     def get_results_dataframe(self, selected_models_keys: List[str]) -> pd.DataFrame:
-         if not self.results:
-             return pd.DataFrame()

-         data = []
-         # Ensure selected_models_keys only contains keys of successfully loaded models
-         valid_selected_models_keys = [key for key in selected_models_keys if key in self.models]
-
-         for result in self.results:
-             row = {
-                 'File Name': result.file_name,
-                 # For Gradio display, we might want to show the image itself
-                 # 'Image': result.image_path,  # This will show the temp path
-                 'Image': gr.Image(result.image_path, type="pil", height=100, width=100)  # Display thumbnail
-             }
-
-             for model_key in valid_selected_models_keys:
-                 model_name = self.models[model_key].name
-                 score = result.scores.get(model_key)
-                 row[model_name] = f"{score:.4f}" if score is not None else "N/A"

-             row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
-             data.append(row)

-         # Define column order
-         column_order = ['File Name', 'Image'] + \
-                        [self.models[key].name for key in valid_selected_models_keys if key in self.models] + \
-                        ['Final Score']
-
-         df = pd.DataFrame(data)
-         if not df.empty:
-             df = df[column_order]  # Reorder columns
-         return df


- def create_interface():
-     """Create the Gradio interface"""
-     evaluator = ImageEvaluator()

-     model_options = [
-         (model.name, key) for key, model in evaluator.models.items()
-     ]
-     # If some models failed to load, model_options will be shorter.
-     # Provide default selected models based on successfully loaded ones.
-     default_selected_model_labels = [name for name, key in model_options]
-
-
-     with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
-         # NOTE on Gradio TypeError:
-         # The traceback "TypeError: argument of type 'bool' is not iterable" during Gradio startup
-         # (specifically in `gradio_client/utils.py` while processing component schemas)
-         # often indicates an incompatibility with the Gradio version being used or a bug
-         # in how Gradio generates schemas for certain component configurations.
-         # The most common recommendation is to:
-         # 1. Ensure your Gradio library is up-to-date: `pip install --upgrade gradio`
-         # 2. If the error persists, try simplifying complex component configurations or
-         #    testing with a known stable version of Gradio.
-         # The code below follows standard Gradio practices, so the error is likely
-         # environment-related if it persists after the WaifuScorer fix.
-
-         gr.Markdown("""
-         # 🎨 Advanced Image Evaluation Tool
-
-         Evaluate images using state-of-the-art aesthetic and quality prediction models.
-         Upload your images and select the models you want to use for evaluation.
-         """)
-
          with gr.Row():
              with gr.Column(scale=1):
-                 input_files = gr.File(
                      label="Upload Images",
                      file_count="multiple",
-                     file_types=["image"]
-                 )
-
-                 model_checkboxes = gr.CheckboxGroup(
-                     choices=[label for label, _ in model_options],  # Use labels for choices
-                     value=default_selected_model_labels,  # Default to all loaded models
-                     label="Select Models",
-                     info="Choose which models to use for evaluation. Models that failed to load will not be available."
-                 )
-
-                 batch_size_slider = gr.Slider(  # Renamed to avoid conflict with the batch_size variable name
-                     minimum=1,
-                     maximum=32,  # Max 64 might be too high for some GPUs
-                     value=8,
-                     step=1,
-                     label="Batch Size",
-                     info="Number of images to process at once per model."
                  )

                  with gr.Row():
-                     evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", scale=2)
-                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
-
-             with gr.Column(scale=3):  # Increased scale for results
-                 # Using gr.Textbox for logs, as gr.Progress is now a status tracker
-                 logs_display = gr.Textbox(
-                     label="Processing Logs",
-                     lines=10,
-                     max_lines=20,  # Allow more lines
-                     autoscroll=True,
-                     interactive=False
                  )
-
-                 # Using gr.Label for progress status updates
-                 progress_status = gr.Label(label="Progress")
-
-                 results_df_display = gr.Dataframe(
-                     label="Evaluation Results",
-                     interactive=False,
-                     wrap=True,
-                     # Define column types for better display, especially for images
-                     # headers=['File Name', 'Image'] + default_selected_model_labels + ['Final Score'],
-                     # col_count=(len(default_selected_model_labels) + 3, "fixed"),
-                     # datatype=['str', 'image'] + ['number'] * (len(default_selected_model_labels) + 1)
-                 )
-
-                 download_button = gr.Button("📥 Download Results (CSV)", variant="secondary")  # Could potentially use gr.DownloadButton later
-                 # download_file_output = gr.File(label="Download CSV", visible=False, interactive=False)
-                 # Using gr.File for download output triggered by a regular button
-                 download_file_output_component = gr.File(label="Download", visible=False)
-
-
-         # State for storing full EvaluationResult objects if needed for more complex interactions.
-         # For this setup, regenerating the DataFrame from evaluator.results is generally sufficient.
-         # results_state = gr.State([])  # If storing raw results is complex, simplify or manage carefully
-
-         async def run_evaluation(files, selected_model_labels, current_batch_size, progress=gr.Progress(track_tqdm=True)):
-             if not files:
-                 return "Please upload images first.", pd.DataFrame(), [], "No files uploaded."
-
-             # Convert display labels back to model keys
-             selected_model_keys = [key for label, key in model_options if label in selected_model_labels]
-
-             if not selected_model_keys:
-                 return "Please select at least one model.", pd.DataFrame(), [], "No models selected."

-             # file_paths = [f.name for f in files]  # .name gives the temp path of UploadFile
-
-             # Progress callback
-             # def update_progress_display(value, desc="Processing..."):
-             #     progress(value / 100, desc=f"{desc} {value:.0f}%")
-             #     return f"{desc} {value:.0f}%"  # For gr.Label
-
-             # Use the gr.Progress context for automatic updates with iterators.
-             # However, for manual updates with batching, direct calls are fine.
-             # We'll update logs_display and progress_status manually.
-
-             progress_updates = []
-             def progress_callback_for_eval(p_value, p_desc):
-                 progress(p_value / 100, desc=p_desc)  # Update the main progress component
-                 # logs_display.value += f"\n{p_desc} - {p_value:.0f}%"  # This will make logs messy
-                 progress_updates.append(f"{p_desc} - {p_value:.0f}%")
-
-
-             # Evaluate images
-             processed_results, log_messages = await evaluator.evaluate_images(
-                 files,  # Pass the list of UploadFile objects directly
-                 selected_model_keys,
-                 int(current_batch_size),  # Ensure batch_size is an int
-                 progress_callback_for_eval  # Pass the callback
-             )
-
-             df = evaluator.get_results_dataframe(selected_model_keys)
-             log_text = "\n".join(log_messages + progress_updates)
-
-             final_status = "Evaluation complete." if processed_results else "Evaluation failed or no results."
-             progress(1.0, desc=final_status)  # Mark progress as complete
-
-             return log_text, df, final_status  # Removed results_state for simplicity
-
-         def handle_model_selection_change(selected_model_labels_updated):
-             # Called when the checkbox group changes. evaluator.results should already be populated.
-             if not evaluator.results:
-                 return pd.DataFrame()  # No results to re-filter/re-calculate
-
-             selected_model_keys_updated = [key for label, key in model_options if label in selected_model_labels_updated]
-
-             # Recalculate final scores for all existing results based on the new selection
-             for res_obj in evaluator.results:
-                 res_obj.calculate_final_score(selected_model_keys_updated)
-
-             return evaluator.get_results_dataframe(selected_model_keys_updated)
-
-         def clear_all_outputs():
-             evaluator.results = []  # Clear stored results in the evaluator
-             return "", pd.DataFrame(), "Cleared.", None  # Log, DataFrame, Progress Status, Download File
-
-         def generate_csv_for_download(selected_model_labels_for_csv):
-             if not evaluator.results:
-                 gr.Warning("No results to download.")
-                 return None
-
-             selected_model_keys_for_csv = [key for label, key in model_options if label in selected_model_labels_for_csv]
-
-             # Get the DataFrame, but exclude the gr.Image column for CSV
-             df_for_csv = evaluator.get_results_dataframe(selected_model_keys_for_csv).copy()
-             if 'Image' in df_for_csv.columns:
-                 df_for_csv.drop(columns=['Image'], inplace=True)
-
-             if df_for_csv.empty:
-                 gr.Warning("No data to download based on current selection.")
-                 return None
-
-             import tempfile
-             with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', encoding='utf-8') as tmp_file:
-                 df_for_csv.to_csv(tmp_file.name, index=False)
-                 return tmp_file.name
-
-         evaluate_btn.click(
-             fn=run_evaluation,
              inputs=[input_files, model_checkboxes, batch_size_slider],
-             outputs=[logs_display, results_df_display, progress_status]  # Removed results_state
-         )
-
-         model_checkboxes.change(
-             fn=handle_model_selection_change,
-             inputs=[model_checkboxes],
-             outputs=[results_df_display]
-         )
-
-         clear_btn.click(
-             fn=clear_all_outputs,
-             outputs=[logs_display, results_df_display, progress_status, download_file_output_component]
          )
-
-         download_button.click(
-             fn=generate_csv_for_download,
-             inputs=[model_checkboxes],
-             outputs=[download_file_output_component]
          )

-         gr.Markdown("""
-         ### 📝 Notes
-         - **Model Selection**: Choose which models to use for evaluation. The final score is the average of the selected models. Models that failed to load during startup will not be listed or will be ignored.
-         - **Batch Size**: Adjust based on your system's VRAM and RAM. Smaller batches use less memory but may be slower overall.
-         - **Results Table**: Displays scores from selected models and the final average. Images are shown as thumbnails.
-         - **Download**: Export results (excluding image thumbnails) as a CSV file for further analysis.
-
-         ### 🎯 Score Interpretation (General Guide)
-         - **7-10**: High quality/aesthetic appeal
-         - **5-7**: Medium quality
-         - **0-5**: Lower quality
-         _(Note: Score ranges and interpretations can vary between models.)_
-         """)
-
      return demo


  if __name__ == "__main__":
-     # Ensure 'aesthetic_predictor_v2_5.py' exists and 'openai-clip' is installed for WaifuScorer.
-     # Example: pip install openai-clip transformers==4.30.2 onnxruntime gradio pandas Pillow opencv-python
-     # Check specific model requirements.

-     # Create and launch the interface
-     app_interface = create_interface()
-     # Adding .queue() is good for handling multiple users or long-running tasks.
-     # Set debug=True for more detailed Gradio errors during development.
-     app_interface.queue().launch(debug=True)
  import os
+ import gc
+ from abc import ABC, abstractmethod
  from pathlib import Path
+ from typing import List, Dict, Any

  import cv2
+ import gradio as gr
  import numpy as np
+ import pandas as pd
  import torch
  import onnxruntime as rt
  from PIL import Image
  from huggingface_hub import hf_hub_download
+ from transformers import pipeline, Pipeline
+ from tqdm import tqdm
+
+ # Disable PIL's decompression-bomb size check so very large images load without warnings
+ Image.MAX_IMAGE_PIXELS = None
+
+ # --- Configuration ---
+ CACHE_DIR = "./hf_cache"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32
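+ # bfloat16 is only chosen on GPUs with compute capability >= 8 (Ampere or newer),
+ # which support it natively; all other configurations fall back to float32.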
+
+ print(f"Using device: {DEVICE} with dtype: {DTYPE}")
+
+ # ==================================================================================
+ # 1. MODEL ABSTRACTION: A unified interface for all scorers.
+ # ==================================================================================
+
+ class AestheticScorer(ABC):
+     """Abstract base class for all aesthetic scoring models."""
+
+     def __init__(self, model_name: str, repo_id: str, filename: str = None):
+         self.model_name = model_name
+         self.repo_id = repo_id
+         self.filename = filename
+         self._model = None
+         print(f"Initializing scorer: {self.model_name}")
+
+     @property
+     def model(self):
+         """Lazy-loads the model on first access."""
+         if self._model is None:
+             print(f"Loading model for '{self.model_name}'...")
+             self._model = self.load_model()
+             print(f"'{self.model_name}' model loaded.")
+         return self._model
+
+     def _download_model(self) -> str:
+         """Downloads the model file from the Hugging Face Hub."""
+         return hf_hub_download(repo_id=self.repo_id, filename=self.filename, cache_dir=CACHE_DIR)
+
+     @abstractmethod
+     def load_model(self) -> Any:
+         """Loads the model and any necessary preprocessors."""
+         pass
+
+     @abstractmethod
+     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
+         """Scores a batch of images and returns a list of floats."""
+         pass
+
+     def release_model(self):
+         """Releases the model from memory."""
+         if self._model is not None:
+             print(f"Releasing model: {self.model_name}")
+             del self._model
+             self._model = None
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
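+ # A minimal sketch of how a new scorer plugs into this ABC (hypothetical class,
+ # for illustration only; not part of the registry below):
+ #
+ #     class ConstantScorer(AestheticScorer):
+ #         def load_model(self):
+ #             return object()  # nothing real to load; non-None keeps lazy loading happy
+ #
+ #         def score_batch(self, image_batch):
+ #             return [5.0] * len(image_batch)  # fixed mid-range score
+ #
+ # Subclasses only implement load_model() and score_batch(); downloading, lazy
+ # loading, and release_model() are inherited from the base class.
+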
+ class PipelineScorer(AestheticScorer):
+     """Scorer for models compatible with Hugging Face pipelines."""
+
+     def load_model(self) -> Pipeline:
+         """Loads a pipeline model."""
+         return pipeline(
+             "image-classification",
+             model=self.repo_id,
+             device=DEVICE,
          )

      @torch.no_grad()
+     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
+         """Scores a batch using the pipeline and extracts the 'hq' score."""
+         results = self.model(image_batch)
          scores = []
+         for res in results:
              try:
+                 # Find the score for the 'hq' (high quality) label
+                 hq_score = next(item['score'] for item in res if item['label'] == 'hq')
+                 scores.append(round(hq_score * 10.0, 4))
+             except (StopIteration, TypeError):
+                 scores.append(0.0)
          return scores
+
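+ # NOTE: the image-classification pipeline is expected to yield one list of
+ # {'label': ..., 'score': ...} dicts per input image; inputs whose predictions
+ # carry no 'hq' label fall back to 0.0 via the except branch above.
+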
+ class ONNXScorer(AestheticScorer):
+     """Scorer for ONNX-based models."""
+
+     def load_model(self) -> rt.InferenceSession:
+         """Loads an ONNX inference session."""
+         model_path = self._download_model()
+         return rt.InferenceSession(model_path, providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'])
+
+     def _preprocess(self, img: Image.Image) -> np.ndarray:
+         """Preprocesses a single image for the Anime Aesthetic model."""
+         img_np = np.array(img.convert("RGB")).astype(np.float32) / 255.0
+         s = 768
          h, w = img_np.shape[:2]
          if h > w:
+             new_h, new_w = s, int(s * w / h)
          else:
+             new_h, new_w = int(s * h / w), s
+         resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
+         canvas = np.zeros((s, s, 3), dtype=np.float32)
+         pad_h, pad_w = (s - new_h) // 2, (s - new_w) // 2
+         canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
+
+         return np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
+
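+     # Worked example of the letterboxing above: a 1024x512 image resizes to
+     # 768x384, then sits centered on a 768x768 black canvas with 192 px of
+     # padding top and bottom, preserving aspect ratio for the square input.
+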
+     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
+         """Scores images one by one, as this model doesn't support batching."""
+         scores = []
+         for img in image_batch:
+             try:
+                 input_tensor = self._preprocess(img)
+                 pred = self.model.run(None, {"img": input_tensor})[0].item()
+                 scores.append(round(pred * 10.0, 4))
+             except Exception:
+                 scores.append(0.0)
+         return scores
+
+ class CLIPMLPScorer(AestheticScorer):
+     """Scorer for models using a CLIP backbone and an MLP head."""
+
+     class MLP(torch.nn.Module):
+         def __init__(self, input_size: int):
+             super().__init__()
+             self.layers = torch.nn.Sequential(
+                 torch.nn.Linear(input_size, 1024),
+                 torch.nn.ReLU(),
+                 torch.nn.Dropout(0.2),
+                 torch.nn.Linear(1024, 128),
+                 torch.nn.ReLU(),
+                 torch.nn.Dropout(0.2),
+                 torch.nn.Linear(128, 64),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(64, 16),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(16, 1),
+             )
+
+         def forward(self, x):
+             return self.layers(x)
+
+     def load_model(self) -> Dict[str, Any]:
+         """Loads both the CLIP model and the custom MLP head."""
+         import clip  # Lazy import
+
+         model_path = self._download_model()
+
+         mlp = self.MLP(input_size=768)  # ViT-L/14 has 768 features
+         state_dict = torch.load(model_path, map_location=DEVICE)
+         mlp.load_state_dict(state_dict)
+         mlp.to(device=DEVICE, dtype=DTYPE)
+         mlp.eval()
+
+         clip_model, preprocess = clip.load("ViT-L/14", device=DEVICE)
+
+         return {"mlp": mlp, "clip": clip_model, "preprocess": preprocess}
+
+     @torch.no_grad()
+     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
+         """Scores a batch using CLIP features and the MLP head."""
+         preprocess = self.model['preprocess']
+         image_tensors = torch.cat([preprocess(img).unsqueeze(0) for img in image_batch]).to(DEVICE)
+
+         image_features = self.model['clip'].encode_image(image_tensors)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+
+         # Pass features through the MLP
+         predictions = self.model['mlp'](image_features.to(DTYPE)).squeeze(-1)
+         scores = predictions.float().cpu().numpy()
+
+         return [round(float(s), 4) for s in scores]
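+
+     # NOTE: features are L2-normalized before the MLP head on the assumption
+     # that it was trained on unit-norm CLIP embeddings; scores are not clamped,
+     # so values slightly outside [0, 10] are possible.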
+
+ # --- Model Registry ---
+ MODEL_REGISTRY: Dict[str, AestheticScorer] = {
+     "Aesthetic Shadow V2": PipelineScorer(
+         "Aesthetic Shadow V2", "shadowlilac/aesthetic-shadow-v2"
+     ),
+     "Waifu Scorer V2": CLIPMLPScorer(
+         "Waifu Scorer V2", "skytnt/waifu-aesthetic-scorer", "model.pth"
+     ),
+     "Anime Scorer": ONNXScorer(
+         "Anime Scorer", "skytnt/anime-aesthetic", "model.onnx"
+     )
+ }
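+
+ # Registering an additional model is one entry here; the repo id below is a
+ # placeholder for illustration, not a real checkpoint:
+ #   MODEL_REGISTRY["My Scorer"] = ONNXScorer("My Scorer", "user/my-model", "model.onnx")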
+
+ # In-memory cache for loaded model instances
+ _loaded_models_cache: Dict[str, AestheticScorer] = {}
+
+
+ # ==================================================================================
+ # 2. CORE PROCESSING LOGIC
+ # ==================================================================================
+
+ def get_scorers(model_names: List[str]) -> List[AestheticScorer]:
+     """Retrieves and caches scorer instances based on the selected names."""
+     # Release models that are no longer selected
+     for name, scorer in list(_loaded_models_cache.items()):
+         if name not in model_names:
+             scorer.release_model()
+             del _loaded_models_cache[name]
+
+     # Load newly selected models
+     scorers = []
+     for name in model_names:
+         if name in _loaded_models_cache:
+             scorers.append(_loaded_models_cache[name])
+         elif name in MODEL_REGISTRY:
+             scorer = MODEL_REGISTRY[name]
+             _loaded_models_cache[name] = scorer
+             scorers.append(scorer)
+     return scorers
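+
+ # Because the cache above is module-level, deselecting a model in the UI and
+ # re-running evaluation releases its weights (and VRAM) before anything new loads.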
+
+ def evaluate_images(
+     files: List[gr.File],
+     selected_model_names: List[str],
+     batch_size: int,
+     progress: gr.Progress = gr.Progress(track_tqdm=True),
+ ) -> pd.DataFrame:
+     """
+     Main function to process images, run them through the selected models,
+     and return the results as a pandas DataFrame.
+     """
+     if not files:
+         gr.Warning("No images uploaded. Please upload files to evaluate.")
+         return pd.DataFrame()
+
+     if not selected_model_names:
+         gr.Warning("No models selected. Please select at least one model.")
+         return pd.DataFrame()
+
+     try:
+         image_paths = [Path(f.name) for f in files]
+         all_results = []
+         scorers = get_scorers(selected_model_names)
+
+         # Use a single tqdm instance for progress tracking
+         pbar = tqdm(total=len(image_paths), desc="Processing images")
+
+         for i in range(0, len(image_paths), batch_size):
+             batch_paths = image_paths[i : i + batch_size]
+
+             # Load images for the current batch
+             try:
+                 batch_images = [Image.open(p).convert("RGB") for p in batch_paths]
+             except Exception as e:
+                 gr.Warning(f"Skipping a batch due to an error loading an image: {e}")
+                 pbar.update(len(batch_paths))
+                 continue
+
+             # Get scores from all selected models for the batch
+             batch_scores: Dict[str, List[float]] = {}
+             for scorer in scorers:
+                 batch_scores[scorer.model_name] = scorer.score_batch(batch_images)
+
+             # Collate results for the batch
+             for j, path in enumerate(batch_paths):
+                 result_row = {"Image": Image.open(path), "Filename": path.name}
+
+                 scores_for_avg = []
+                 for scorer in scorers:
+                     score = batch_scores[scorer.model_name][j]
+                     result_row[scorer.model_name] = score
+                     scores_for_avg.append(score)
+
+                 # Calculate the average score
+                 if scores_for_avg:
+                     result_row["Average Score"] = round(np.mean(scores_for_avg), 4)
+                 else:
+                     result_row["Average Score"] = 0.0
+
+                 all_results.append(result_row)
+
+             pbar.update(len(batch_paths))
+
+         pbar.close()
+
+         if not all_results:
+             gr.Warning("Processing completed, but no results were generated.")
+             return pd.DataFrame()
+
+         return pd.DataFrame(all_results)
+
+     except Exception as e:
+         # Clean up on failure, then surface the error in the UI
+         # (gr.Error is an exception and must be raised to be displayed)
+         for scorer in _loaded_models_cache.values():
+             scorer.release_model()
+         _loaded_models_cache.clear()
+         raise gr.Error(f"A critical error occurred: {e}")
+
+
+ # ==================================================================================
+ # 3. GRADIO USER INTERFACE
+ # ==================================================================================
+
+ def create_ui() -> gr.Blocks:
+     """Creates and configures the Gradio web interface."""
+
+     all_model_names = list(MODEL_REGISTRY.keys())
+
+     # Define headers and datatypes for the results table
+     dataframe_headers = ["Image", "Filename"] + all_model_names + ["Average Score"]
+     dataframe_datatypes = ["image", "str"] + ["number"] * (len(all_model_names) + 1)
+
+     with gr.Blocks(theme=gr.themes.Soft(), title="Image Aesthetic Scorer") as demo:
+         gr.Markdown(
+             """
+             # 🖼️ Modern Image Aesthetic Scorer
+             Upload your images, select the scoring models, and click 'Evaluate'.
+             The results table supports **interactive sorting** (click on headers) and can be **downloaded as a CSV**.
+             """
+         )
+
          with gr.Row():
              with gr.Column(scale=1):
+                 gr.Markdown("### ⚙️ Settings")
+                 input_files = gr.Files(
                      label="Upload Images",
                      file_count="multiple",
+                     file_types=["image"],
                  )

+                 with gr.Accordion("Advanced Configuration", open=False):
+                     model_checkboxes = gr.CheckboxGroup(
+                         choices=all_model_names,
+                         value=all_model_names,
+                         label="Scoring Models",
+                         info="Choose which models to use for evaluation.",
+                     )
+                     batch_size_slider = gr.Slider(
+                         minimum=1,
+                         maximum=64,
+                         value=8,
+                         step=1,
+                         label="Batch Size",
+                         info="Adjust based on your VRAM. Higher is faster.",
+                     )
+
                  with gr.Row():
+                     process_button = gr.Button("🚀 Evaluate Images", variant="primary")
+                     clear_button = gr.Button("🧹 Clear All")

+             with gr.Column(scale=3):
+                 gr.Markdown("### 📊 Results")
+                 results_dataframe = gr.DataFrame(
+                     headers=dataframe_headers,
+                     datatype=dataframe_datatypes,
+                     label="Evaluation Scores",
+                     interactive=True,
+                     # Components no longer expose a .style() method in Gradio 4+;
+                     # display options such as height are set on the constructor.
+                     height=800,
                  )
+
+         # --- Event Handlers ---
+         process_button.click(
+             fn=evaluate_images,
              inputs=[input_files, model_checkboxes, batch_size_slider],
+             outputs=[results_dataframe],
+             concurrency_limit=1,  # Only one evaluation at a time
          )
+
+         def clear_outputs():
+             # Release all models from memory when clearing
+             for scorer in _loaded_models_cache.values():
+                 scorer.release_model()
+             _loaded_models_cache.clear()
+             gr.Info("Cleared results and released models from memory.")
+             # Return an empty DataFrame to clear the table
+             return pd.DataFrame()
+
+         clear_button.click(
+             fn=clear_outputs,
+             inputs=[],
+             outputs=[results_dataframe],
          )

      return demo

+ # ==================================================================================
+ # 4. APPLICATION ENTRY POINT
+ # ==================================================================================

  if __name__ == "__main__":
+     # Ensure the cache directory exists
+     os.makedirs(CACHE_DIR, exist_ok=True)
+
+     app = create_ui()
+     app.queue().launch(share=False)