VOIDER committed on
Commit
1bc1e75
·
verified ·
1 Parent(s): 8b461d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +390 -472
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
- import tempfile
3
- import base64
4
- from io import BytesIO
5
- from typing import List, Dict, Any, Optional, Tuple
6
- from dataclasses import dataclass
7
  from pathlib import Path
 
8
 
9
  import cv2
10
  import numpy as np
@@ -12,589 +11,508 @@ import torch
12
  import onnxruntime as rt
13
  from PIL import Image
14
  import gradio as gr
15
- import pandas as pd
16
  from transformers import pipeline
17
  from huggingface_hub import hf_hub_download
 
 
 
 
 
18
 
19
- # Import necessary function from aesthetic_predictor_v2_5
20
  from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
21
 
22
 
23
  @dataclass
24
  class EvaluationResult:
25
- """Data class for storing image evaluation results."""
26
  file_name: str
27
- image: Image.Image
28
- aesthetic_shadow: Optional[float] = None
29
- waifu_scorer: Optional[float] = None
30
- aesthetic_v2_5: Optional[float] = None
31
- anime_aesthetic: Optional[float] = None
32
  final_score: Optional[float] = None
 
 
 
 
 
 
 
 
33
 
34
 
35
- class MLP(torch.nn.Module):
36
- """Optimized MLP for image feature regression."""
37
- def __init__(self, input_size: int = 768):
38
- super().__init__()
39
- self.network = torch.nn.Sequential(
40
- torch.nn.Linear(input_size, 1024),
41
- torch.nn.ReLU(),
42
- torch.nn.BatchNorm1d(1024),
43
- torch.nn.Dropout(0.2),
44
- torch.nn.Linear(1024, 256),
45
- torch.nn.ReLU(),
46
- torch.nn.BatchNorm1d(256),
47
- torch.nn.Dropout(0.1),
48
- torch.nn.Linear(256, 64),
49
- torch.nn.ReLU(),
50
- torch.nn.Linear(64, 1)
51
- )
52
-
53
- def forward(self, x: torch.Tensor) -> torch.Tensor:
54
- return self.network(x)
55
 
56
 
57
- class ModelLoader:
58
- """Centralized model loading and management."""
59
-
60
- def __init__(self, device: str = None):
61
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
62
- self.models = {}
63
- self._load_all_models()
64
-
65
- def _load_all_models(self):
66
- """Load all models during initialization."""
67
- try:
68
- self._load_aesthetic_shadow()
69
- self._load_waifu_scorer()
70
- self._load_aesthetic_v2_5()
71
- self._load_anime_aesthetic()
72
- print("βœ… All models loaded successfully!")
73
- except Exception as e:
74
- print(f"❌ Error loading models: {e}")
75
-
76
- def _load_aesthetic_shadow(self):
77
- """Load Aesthetic Shadow model."""
78
- print("πŸ”„ Loading Aesthetic Shadow...")
79
- self.models['aesthetic_shadow'] = pipeline(
80
- "image-classification",
81
  model="NeoChen1024/aesthetic-shadow-v2-backup",
82
- device=self.device
83
  )
84
 
85
- def _load_waifu_scorer(self):
86
- """Load Waifu Scorer model."""
87
- print("πŸ”„ Loading Waifu Scorer...")
88
- try:
89
- import clip
90
-
91
- # Load MLP
92
- model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
93
- mlp = MLP()
94
- state_dict = torch.load(model_path, map_location=self.device)
95
- mlp.load_state_dict(state_dict)
96
- mlp.to(self.device).eval()
97
-
98
- # Load CLIP
99
- clip_model, preprocess = clip.load("ViT-L/14", device=self.device)
100
-
101
- self.models['waifu_scorer'] = {
102
- 'mlp': mlp,
103
- 'clip_model': clip_model,
104
- 'preprocess': preprocess
105
- }
106
- except Exception as e:
107
- print(f"⚠️ Waifu Scorer not available: {e}")
108
- self.models['waifu_scorer'] = None
109
-
110
- def _load_aesthetic_v2_5(self):
111
- """Load Aesthetic Predictor V2.5."""
112
- print("πŸ”„ Loading Aesthetic V2.5...")
113
- try:
114
- model, preprocessor = convert_v2_5_from_siglip(
115
- low_cpu_mem_usage=True,
116
- trust_remote_code=True,
117
- )
118
- if torch.cuda.is_available():
119
- model = model.to(torch.bfloat16).cuda()
120
-
121
- self.models['aesthetic_v2_5'] = {
122
- 'model': model,
123
- 'preprocessor': preprocessor
124
- }
125
- except Exception as e:
126
- print(f"⚠️ Aesthetic V2.5 not available: {e}")
127
- self.models['aesthetic_v2_5'] = None
128
-
129
- def _load_anime_aesthetic(self):
130
- """Load Anime Aesthetic model."""
131
- print("πŸ”„ Loading Anime Aesthetic...")
132
  try:
133
- model_path = hf_hub_download("skytnt/anime-aesthetic", "model.onnx")
134
- self.models['anime_aesthetic'] = rt.InferenceSession(
135
- model_path,
136
- providers=['CPUExecutionProvider']
137
- )
 
138
  except Exception as e:
139
- print(f"⚠️ Anime Aesthetic not available: {e}")
140
- self.models['anime_aesthetic'] = None
141
 
142
 
143
- class ImageEvaluator:
144
- """Main image evaluation class with batch processing."""
145
-
146
  def __init__(self):
147
- self.loader = ModelLoader()
148
- self.temp_dir = Path(tempfile.mkdtemp())
 
149
 
150
- def evaluate_images(
151
- self,
152
- images: List[Image.Image],
153
- file_names: List[str],
154
- selected_models: List[str],
155
- batch_size: int = 4,
156
- progress_callback=None
157
- ) -> List[EvaluationResult]:
158
- """Evaluate images using selected models."""
159
- results = []
160
- total_batches = (len(images) + batch_size - 1) // batch_size
161
-
162
- for batch_idx in range(0, len(images), batch_size):
163
- batch_images = images[batch_idx:batch_idx + batch_size]
164
- batch_names = file_names[batch_idx:batch_idx + batch_size]
165
 
166
- # Update progress
167
- if progress_callback:
168
- progress = (batch_idx // batch_size + 1) / total_batches
169
- progress_callback(progress, f"Processing batch {batch_idx//batch_size + 1}/{total_batches}")
 
 
170
 
171
- # Process batch
172
- batch_results = self._process_batch(batch_images, batch_names, selected_models)
173
- results.extend(batch_results)
174
-
175
- return results
 
176
 
177
- def _process_batch(
178
- self,
179
- images: List[Image.Image],
180
- file_names: List[str],
181
- selected_models: List[str]
182
- ) -> List[EvaluationResult]:
183
- """Process a single batch of images."""
184
- batch_results = []
185
-
186
- # Initialize results
187
- for i, (img, name) in enumerate(zip(images, file_names)):
188
- result = EvaluationResult(file_name=name, image=img)
189
- batch_results.append(result)
190
-
191
- # Process each selected model
192
- if 'aesthetic_shadow' in selected_models:
193
- scores = self._eval_aesthetic_shadow(images)
194
- for result, score in zip(batch_results, scores):
195
- result.aesthetic_shadow = score
196
-
197
- if 'waifu_scorer' in selected_models:
198
- scores = self._eval_waifu_scorer(images)
199
- for result, score in zip(batch_results, scores):
200
- result.waifu_scorer = score
201
-
202
- if 'aesthetic_v2_5' in selected_models:
203
- scores = self._eval_aesthetic_v2_5(images)
204
- for result, score in zip(batch_results, scores):
205
- result.aesthetic_v2_5 = score
206
-
207
- if 'anime_aesthetic' in selected_models:
208
- scores = self._eval_anime_aesthetic(images)
209
- for result, score in zip(batch_results, scores):
210
- result.anime_aesthetic = score
211
-
212
- # Calculate final scores
213
- for result in batch_results:
214
- result.final_score = self._calculate_final_score(result, selected_models)
215
-
216
- return batch_results
217
 
218
- def _eval_aesthetic_shadow(self, images: List[Image.Image]) -> List[Optional[float]]:
219
- """Evaluate using Aesthetic Shadow model."""
220
- if not self.loader.models.get('aesthetic_shadow'):
221
  return [None] * len(images)
222
 
223
  try:
224
- results = self.loader.models['aesthetic_shadow'](images)
225
- scores = []
226
- for result in results:
227
- try:
228
- hq_score = next(p for p in result if p['label'] == 'hq')['score']
229
- scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
230
- except:
231
- scores.append(None)
 
 
232
  return scores
233
  except Exception as e:
234
- print(f"Error in Aesthetic Shadow: {e}")
235
- return [None] * len(images)
236
-
237
- def _eval_waifu_scorer(self, images: List[Image.Image]) -> List[Optional[float]]:
238
- """Evaluate using Waifu Scorer model."""
239
- model_dict = self.loader.models.get('waifu_scorer')
240
- if not model_dict:
241
- return [None] * len(images)
242
-
243
- try:
244
- with torch.no_grad():
245
- # Preprocess images
246
- image_tensors = [model_dict['preprocess'](img).unsqueeze(0) for img in images]
247
- if len(image_tensors) == 1:
248
- image_tensors = image_tensors * 2 # CLIP requirement
249
-
250
- image_batch = torch.cat(image_tensors).to(self.loader.device)
251
- image_features = model_dict['clip_model'].encode_image(image_batch)
252
-
253
- # Normalize features
254
- norm = image_features.norm(2, dim=-1, keepdim=True)
255
- norm[norm == 0] = 1
256
- im_emb = (image_features / norm).to(self.loader.device)
257
-
258
- predictions = model_dict['mlp'](im_emb)
259
- scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
260
-
261
- return scores[:len(images)]
262
- except Exception as e:
263
- print(f"Error in Waifu Scorer: {e}")
264
  return [None] * len(images)
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- def _eval_aesthetic_v2_5(self, images: List[Image.Image]) -> List[Optional[float]]:
267
- """Evaluate using Aesthetic Predictor V2.5."""
268
- model_dict = self.loader.models.get('aesthetic_v2_5')
269
- if not model_dict:
270
- return [None] * len(images)
271
-
272
  try:
273
- rgb_images = [img.convert("RGB") for img in images]
274
- pixel_values = model_dict['preprocessor'](images=rgb_images, return_tensors="pt").pixel_values
275
 
276
- if torch.cuda.is_available():
277
  pixel_values = pixel_values.to(torch.bfloat16).cuda()
278
 
279
- with torch.inference_mode():
280
- scores = model_dict['model'](pixel_values).logits.squeeze().float().cpu().numpy()
281
- if scores.ndim == 0:
282
- scores = np.array([scores])
283
-
284
- return [float(np.clip(s, 0.0, 10.0)) for s in scores.tolist()]
285
  except Exception as e:
286
- print(f"Error in Aesthetic V2.5: {e}")
287
  return [None] * len(images)
 
 
 
 
 
 
 
 
 
288
 
289
- def _eval_anime_aesthetic(self, images: List[Image.Image]) -> List[Optional[float]]:
290
- """Evaluate using Anime Aesthetic model."""
291
- model = self.loader.models.get('anime_aesthetic')
292
- if not model:
293
- return [None] * len(images)
294
-
295
  scores = []
296
  for img in images:
297
  try:
298
- # Preprocess image
299
- img_np = np.array(img).astype(np.float32) / 255.0
300
- h, w = img_np.shape[:2]
301
- s = 768
302
-
303
- if h > w:
304
- new_h, new_w = s, int(s * w / h)
305
- else:
306
- new_h, new_w = int(s * h / w), s
307
-
308
- resized = cv2.resize(img_np, (new_w, new_h))
309
- canvas = np.zeros((s, s, 3), dtype=np.float32)
310
-
311
- pad_h = (s - new_h) // 2
312
- pad_w = (s - new_w) // 2
313
- canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
314
-
315
- input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
316
- pred = model.run(None, {"img": input_tensor})[0].item()
317
- scores.append(float(np.clip(pred * 10.0, 0.0, 10.0)))
318
  except Exception as e:
319
- print(f"Error processing image: {e}")
320
  scores.append(None)
321
-
322
  return scores
323
 
324
- def _calculate_final_score(self, result: EvaluationResult, selected_models: List[str]) -> Optional[float]:
325
- """Calculate final score from selected model results."""
326
- scores = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- for model in selected_models:
329
- score = getattr(result, model, None)
330
- if score is not None:
331
- scores.append(score)
332
 
333
- return float(np.mean(scores)) if scores else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
- def results_to_dataframe(self, results: List[EvaluationResult]) -> pd.DataFrame:
336
- """Convert results to pandas DataFrame."""
 
 
 
337
  data = []
338
- for result in results:
339
  row = {
340
  'File Name': result.file_name,
341
- 'Final Score': result.final_score,
342
  }
343
- if result.aesthetic_shadow is not None:
344
- row['Aesthetic Shadow'] = result.aesthetic_shadow
345
- if result.waifu_scorer is not None:
346
- row['Waifu Scorer'] = result.waifu_scorer
347
- if result.aesthetic_v2_5 is not None:
348
- row['Aesthetic V2.5'] = result.aesthetic_v2_5
349
- if result.anime_aesthetic is not None:
350
- row['Anime Aesthetic'] = result.anime_aesthetic
351
  data.append(row)
352
 
353
  return pd.DataFrame(data)
354
-
355
- def optimize_batch_size(self, sample_images: List[Image.Image]) -> int:
356
- """Automatically determine optimal batch size."""
357
- if not sample_images:
358
- return 1
359
-
360
- test_image = sample_images[0]
361
- batch_size = 1
362
- max_test = min(16, len(sample_images))
363
-
364
- while batch_size <= max_test:
365
- try:
366
- test_batch = [test_image] * batch_size
367
- # Test with a lightweight model
368
- if self.loader.models.get('aesthetic_shadow'):
369
- _ = self.loader.models['aesthetic_shadow'](test_batch)
370
- batch_size *= 2
371
- except Exception:
372
- break
373
-
374
- optimal = max(1, batch_size // 2)
375
- return min(optimal, 8) # Cap at reasonable size
376
 
377
 
378
  def create_interface():
379
- """Create the Gradio interface."""
380
  evaluator = ImageEvaluator()
381
 
382
- # Available models
383
- model_choices = [
384
  ("Aesthetic Shadow", "aesthetic_shadow"),
385
  ("Waifu Scorer", "waifu_scorer"),
386
- ("Aesthetic V2.5", "aesthetic_v2_5"),
387
- ("Anime Aesthetic", "anime_aesthetic")
388
  ]
389
- available_models = [choice[1] for choice in model_choices]
390
 
391
- with gr.Blocks(title="Image Evaluation Tool", theme=gr.themes.Soft()) as app:
392
  gr.Markdown("""
393
- # 🎨 Modern Image Evaluation Tool
394
 
395
- Upload images to evaluate them using state-of-the-art aesthetic and quality prediction models.
396
-
397
- **Features:**
398
- - Multiple AI models for comprehensive evaluation
399
- - Batch processing with automatic optimization
400
- - Interactive results table with sorting and filtering
401
- - CSV export functionality
402
- - Real-time progress tracking
403
  """)
404
 
405
  with gr.Row():
406
  with gr.Column(scale=1):
407
- # Input components
408
  input_files = gr.File(
409
- label="πŸ“ Upload Images",
410
  file_count="multiple",
411
  file_types=["image"]
412
  )
413
 
414
- model_selection = gr.CheckboxGroup(
415
- choices=model_choices,
416
- value=available_models,
417
- label="πŸ€– Select Models",
418
  info="Choose which models to use for evaluation"
419
  )
420
 
421
  with gr.Row():
422
- auto_batch = gr.Checkbox(
423
- label="πŸ”„ Auto Batch Size",
424
- value=True,
425
- info="Automatically optimize batch size"
426
- )
427
-
428
- manual_batch = gr.Slider(
429
  minimum=1,
430
- maximum=16,
431
- value=4,
432
  step=1,
433
- label="πŸ“Š Batch Size",
434
- interactive=False,
435
- info="Manual batch size (when auto is disabled)"
436
  )
437
 
438
- evaluate_btn = gr.Button(
439
- "πŸš€ Evaluate Images",
440
- variant="primary",
441
- size="lg"
442
- )
443
-
444
- clear_btn = gr.Button("πŸ—‘οΈ Clear Results", variant="secondary")
445
 
446
  with gr.Column(scale=2):
447
- # Progress and status
448
- progress_bar = gr.Progress()
449
- status_text = gr.Textbox(
450
- label="πŸ“Š Status",
451
- interactive=False,
452
- max_lines=2
453
  )
454
 
455
- # Results display
456
- results_table = gr.DataFrame(
457
- label="πŸ“‹ Evaluation Results",
458
  interactive=False,
459
- wrap=True,
460
- max_height=400
461
  )
462
 
463
- # Export functionality
464
- with gr.Row():
465
- export_csv = gr.Button("πŸ“₯ Export CSV", variant="secondary")
466
- download_file = gr.File(
467
- label="πŸ’Ύ Download",
468
- visible=False
469
- )
470
 
471
- # State management
472
  results_state = gr.State([])
473
 
474
- # Event handlers
475
- def toggle_batch_slider(auto_enabled):
476
- return gr.update(interactive=not auto_enabled)
477
-
478
- def process_images(files, models, auto_batch_enabled, manual_batch_size, progress=gr.Progress()):
479
- if not files or not models:
480
- return "❌ Please upload images and select at least one model", pd.DataFrame(), []
481
 
482
- try:
483
- # Load images
484
- images = []
485
- file_names = []
486
-
487
- progress(0.1, "πŸ“‚ Loading images...")
488
-
489
- for file in files:
490
- try:
491
- img = Image.open(file.name).convert("RGB")
492
- images.append(img)
493
- file_names.append(os.path.basename(file.name))
494
- except Exception as e:
495
- print(f"Error loading {file.name}: {e}")
496
-
497
- if not images:
498
- return "❌ No valid images loaded", pd.DataFrame(), []
499
-
500
- # Determine batch size
501
- if auto_batch_enabled:
502
- batch_size = evaluator.optimize_batch_size(images[:2])
503
- progress(0.2, f"πŸ”§ Optimized batch size: {batch_size}")
504
- else:
505
- batch_size = int(manual_batch_size)
506
-
507
- # Process images
508
- def progress_callback(prog, msg):
509
- progress(0.2 + prog * 0.7, msg)
510
-
511
- results = evaluator.evaluate_images(
512
- images, file_names, models, batch_size, progress_callback
513
- )
514
-
515
- progress(0.95, "πŸ“Š Generating results table...")
516
-
517
- # Convert to DataFrame
518
- df = evaluator.results_to_dataframe(results)
519
- df = df.sort_values('Final Score', ascending=False, na_position='last')
520
-
521
- progress(1.0, f"βœ… Processed {len(results)} images successfully!")
522
-
523
- return f"βœ… Evaluated {len(results)} images using {len(models)} models", df, results
524
-
525
- except Exception as e:
526
- return f"❌ Error during processing: {str(e)}", pd.DataFrame(), []
527
 
528
- def update_results_table(models, current_results):
529
- if not current_results:
 
530
  return pd.DataFrame()
531
 
532
- # Recalculate final scores based on selected models
533
- for result in current_results:
534
- result.final_score = evaluator._calculate_final_score(result, models)
535
 
536
- df = evaluator.results_to_dataframe(current_results)
537
- return df.sort_values('Final Score', ascending=False, na_position='last')
538
-
539
- def export_results(current_results):
540
- if not current_results:
541
- return gr.update(visible=False)
542
 
543
- df = evaluator.results_to_dataframe(current_results)
544
- csv_path = evaluator.temp_dir / "evaluation_results.csv"
545
- df.to_csv(csv_path, index=False)
546
 
547
- return gr.update(value=str(csv_path), visible=True)
 
548
 
549
- def clear_all():
550
- return (
551
- "πŸ”„ Ready for new evaluation",
552
- pd.DataFrame(),
553
- [],
554
- gr.update(visible=False)
555
- )
556
 
557
- # Wire up events
558
- auto_batch.change(
559
- toggle_batch_slider,
560
- inputs=[auto_batch],
561
- outputs=[manual_batch]
562
- )
 
 
 
 
 
 
 
 
 
 
563
 
 
564
  evaluate_btn.click(
565
- process_images,
566
- inputs=[input_files, model_selection, auto_batch, manual_batch],
567
- outputs=[status_text, results_table, results_state]
568
  )
569
 
570
- model_selection.change(
571
- update_results_table,
572
- inputs=[model_selection, results_state],
573
- outputs=[results_table]
574
  )
575
 
576
- export_csv.click(
577
- export_results,
578
- inputs=[results_state],
579
- outputs=[download_file]
580
  )
581
 
582
- clear_btn.click(
583
- clear_all,
584
- outputs=[status_text, results_table, results_state, download_file]
 
585
  )
586
 
587
- # Initial setup
588
- app.load(lambda: "πŸ”„ Ready for evaluation - Upload images to get started!")
 
 
 
 
 
 
 
 
 
 
589
 
590
- return app
591
 
592
 
593
  if __name__ == "__main__":
594
- app = create_interface()
595
- app.queue(max_size=10).launch(
596
- server_name="0.0.0.0",
597
- server_port=7860,
598
- share=False,
599
- show_error=True
600
- )
 
1
  import os
2
+ import asyncio
3
+ from typing import List, Dict, Optional, Tuple, Any
4
+ from dataclasses import dataclass, field
 
 
5
  from pathlib import Path
6
+ import logging
7
 
8
  import cv2
9
  import numpy as np
 
11
  import onnxruntime as rt
12
  from PIL import Image
13
  import gradio as gr
 
14
  from transformers import pipeline
15
  from huggingface_hub import hf_hub_download
16
+ import pandas as pd
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+ logger = logging.getLogger(__name__)
21
 
22
+ # Import aesthetic predictor function
23
  from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
24
 
25
 
26
@dataclass
class EvaluationResult:
    """Data class for storing the evaluation results of a single image"""

    # Display name of the evaluated file.
    file_name: str
    # Path the image was read from.
    image_path: str
    # Score per model key; None marks a model that produced no score.
    scores: Dict[str, Optional[float]] = field(default_factory=dict)
    # Mean of the selected models' scores; None until calculated or when
    # no selected model produced a score.
    final_score: Optional[float] = None

    def calculate_final_score(self, selected_models: List[str]) -> None:
        """Store the average score of the selected models in ``final_score``

        Only entries of ``scores`` whose key is in ``selected_models`` and
        whose value is not None take part in the average. When no such
        entry exists, ``final_score`` is set to None.

        Args:
            selected_models: Model keys to include in the average.
        """
        valid_scores = [
            score
            for model, score in self.scores.items()
            if model in selected_models and score is not None
        ]
        # np.mean returns np.float64; convert so the field really holds the
        # plain float promised by its annotation.
        self.final_score = float(np.mean(valid_scores)) if valid_scores else None
41
 
42
 
43
class BaseModel:
    """Common parent of every evaluation model

    Stores the display name and picks the torch device ('cuda' when
    available, otherwise 'cpu'). Subclasses override ``evaluate_batch``.
    """

    def __init__(self, name: str):
        if torch.cuda.is_available():
            selected_device = 'cuda'
        else:
            selected_device = 'cpu'
        self.device = selected_device
        self.name = name

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Score a batch of images; must be implemented by subclasses"""
        raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
class AestheticShadowModel(BaseModel):
    """Aesthetic Shadow V2 model implementation

    Wraps a transformers ``image-classification`` pipeline and reports the
    score of the ``hq`` label scaled to the 0-10 range.
    """

    def __init__(self):
        super().__init__("Aesthetic Shadow")
        logger.info(f"Loading {self.name} model...")
        self.model = pipeline(
            "image-classification",
            model="NeoChen1024/aesthetic-shadow-v2-backup",
            # The pipeline takes a device index: 0 for the first GPU, -1 for CPU.
            device=0 if self.device == 'cuda' else -1,
        )

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Score a batch of images

        Args:
            images: Images to classify.

        Returns:
            One entry per image: the ``hq`` probability multiplied by 10 and
            clipped to [0, 10], or None when the classifier output has no
            ``hq`` label. If the pipeline call itself fails, every entry is
            None.
        """
        if not images:
            return []

        try:
            results = self.model(images)
            scores: List[Optional[float]] = []
            for result in results:
                hq_score = next((p['score'] for p in result if p['label'] == 'hq'), None)
                if hq_score is None:
                    # A missing label is missing data, not a score of zero;
                    # reporting 0 would silently drag the final average down.
                    scores.append(None)
                else:
                    scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
            return scores
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)
76
 
77
 
78
class WaifuScorerModel(BaseModel):
    """Waifu Scorer V3 model implementation

    Scores images by feeding normalized CLIP ViT-L/14 image features into
    an MLP regression head downloaded from the Hugging Face Hub.
    """

    def __init__(self):
        super().__init__("Waifu Scorer")
        # Defined up front so evaluate_batch can always read it.
        self.available = False
        logger.info(f"Loading {self.name} model...")
        self._load_model()

    def _load_model(self):
        """Load the MLP head and the CLIP backbone

        Sets ``self.available`` to True on success. Any failure (including
        a missing ``clip`` package) is logged and leaves the model marked
        unavailable instead of raising.
        """
        try:
            import clip

            # Load MLP model
            self.mlp = self._create_mlp()
            model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
            state_dict = torch.load(model_path, map_location=self.device)
            self.mlp.load_state_dict(state_dict)
            self.mlp.to(self.device).eval()

            # Load CLIP model
            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
            self.available = True
        except Exception as e:
            logger.error(f"Failed to load {self.name}: {e}")
            self.available = False

    def _create_mlp(self) -> torch.nn.Module:
        """Create the MLP architecture (768 inputs, single regression output)"""
        return torch.nn.Sequential(
            torch.nn.Linear(768, 2048),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Score a batch of images

        Args:
            images: Images to score.

        Returns:
            One score per image clamped to [0, 10]. Every entry is None when
            the model is unavailable or inference fails.
        """
        if not self.available:
            return [None] * len(images)
        if not images:
            return []

        try:
            # The context manager must live inside the coroutine body: used as
            # a decorator on an ``async def``, torch.no_grad() only covers the
            # creation of the coroutine object, so the body would run with
            # autograd enabled and ``.numpy()`` below would raise.
            with torch.no_grad():
                # Process images
                image_tensors = torch.stack([self.preprocess(img) for img in images])
                image_tensors = image_tensors.to(self.device)

                # Extract features. CLIP may return half-precision features
                # on CUDA while the MLP weights are float32, so align dtypes.
                mlp_dtype = next(self.mlp.parameters()).dtype
                features = self.clip_model.encode_image(image_tensors).to(dtype=mlp_dtype)

                # Normalize; a zero norm is replaced by 1 to avoid NaN scores.
                norm = features.norm(dim=-1, keepdim=True)
                norm = torch.where(norm == 0, torch.ones_like(norm), norm)
                predictions = self.mlp(features / norm)

                scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
            return scores
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)
147
+
148
+
149
class AestheticPredictorV25Model(BaseModel):
    """Aesthetic Predictor V2.5 model implementation

    Loads the SigLIP-based predictor and its preprocessor; on CUDA the model
    is moved to the GPU in bfloat16.
    """

    def __init__(self):
        super().__init__("Aesthetic V2.5")
        logger.info(f"Loading {self.name} model...")
        self.model, self.preprocessor = convert_v2_5_from_siglip(
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        if self.device == 'cuda':
            self.model = self.model.to(torch.bfloat16).cuda()

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Score a batch of images

        Args:
            images: Images to score; each one is converted to RGB first.

        Returns:
            One score per image clipped to [0, 10], or a list of None with
            the same length when inference fails.
        """
        if not images:
            return []

        try:
            images_rgb = [img.convert("RGB") for img in images]
            pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values

            if self.device == 'cuda':
                # Match the dtype/device the model was moved to in __init__.
                pixel_values = pixel_values.to(torch.bfloat16).cuda()

            # The context manager must live inside the coroutine body: used as
            # a decorator on an ``async def``, torch.no_grad() only covers the
            # creation of the coroutine object, so the forward pass would
            # track gradients and ``.numpy()`` below would raise.
            with torch.no_grad():
                logits = self.model(pixel_values).logits

            # squeeze() collapses a single-image batch to a 0-d array;
            # atleast_1d makes it iterable again.
            scores = np.atleast_1d(logits.squeeze().float().cpu().numpy())

            return [float(np.clip(s, 0.0, 10.0)) for s in scores]
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)
178
+
179
+
180
class AnimeAestheticModel(BaseModel):
    """Anime Aesthetic model implementation

    Runs the ``skytnt/anime-aesthetic`` ONNX model on the CPU execution
    provider, one image at a time.
    """

    def __init__(self):
        super().__init__("Anime Score")
        logger.info(f"Loading {self.name} model...")
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
        self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Score a batch of images

        Args:
            images: Images to score.

        Returns:
            One entry per image: the model output multiplied by 10 and
            clipped to [0, 10], or None for an image that failed.
        """
        scores: List[Optional[float]] = []
        for img in images:
            try:
                score = self._process_single_image(img)
                scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
            except Exception as e:
                # A single bad image must not discard the rest of the batch.
                logger.error(f"Error in {self.name} for single image: {e}")
                scores.append(None)
        return scores

    def _process_single_image(self, img: Image.Image) -> float:
        """Process a single image through the model

        The image is scaled so its longer side is 768 px, centered on a
        black 768x768 canvas and passed to the session as a (1, 3, 768, 768)
        float32 array with values in [0, 1].

        Args:
            img: Image to score.

        Returns:
            The raw scalar output of the ONNX session.
        """
        # The canvas below has exactly three channels, so force RGB; this
        # keeps grayscale, palette and RGBA inputs from failing on assignment.
        img_np = np.array(img.convert("RGB")).astype(np.float32) / 255.0
        size = 768
        h, w = img_np.shape[:2]

        # Calculate new dimensions; clamp to 1 px so extreme aspect ratios
        # never hand cv2.resize a zero-sized target.
        if h > w:
            new_h, new_w = size, max(1, int(size * w / h))
        else:
            new_h, new_w = max(1, int(size * h / w)), size

        # Resize and center (cv2.resize takes the target as (width, height))
        resized = cv2.resize(img_np, (new_w, new_h))
        canvas = np.zeros((size, size, 3), dtype=np.float32)
        pad_h = (size - new_h) // 2
        pad_w = (size - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

        # Prepare input: HWC -> CHW plus a leading batch axis
        input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
        return self.session.run(None, {"img": input_tensor})[0].item()
221
+
222
+
223
class ImageEvaluator:
    """Main class for managing image evaluation

    Owns one instance of every model that loaded successfully and keeps the
    results of the most recent completed evaluation in ``self.results``.
    """

    def __init__(self):
        self.models: Dict[str, BaseModel] = {}
        self.results: List[EvaluationResult] = []
        self._initialize_models()

    def _initialize_models(self):
        """Initialize all evaluation models

        A model whose constructor raises is logged and left out of
        ``self.models``; the remaining models stay usable.
        """
        model_classes = [
            ("aesthetic_shadow", AestheticShadowModel),
            ("waifu_scorer", WaifuScorerModel),
            ("aesthetic_predictor_v2_5", AestheticPredictorV25Model),
            ("anime_aesthetic", AnimeAestheticModel),
        ]

        for key, model_class in model_classes:
            try:
                self.models[key] = model_class()
                logger.info(f"Successfully loaded {key}")
            except Exception as e:
                logger.error(f"Failed to load {key}: {e}")

    async def evaluate_images(
        self,
        file_paths: List[str],
        selected_models: List[str],
        batch_size: int = 8,
        progress_callback=None
    ) -> Tuple[List[EvaluationResult], List[str]]:
        """Evaluate images with selected models

        Args:
            file_paths: Paths of the images to evaluate. Files that cannot
                be opened are skipped and reported in the logs.
            selected_models: Keys of the models to run. Keys without a
                loaded model are skipped and reported in the logs.
            batch_size: Number of images per batch. Coerced to an integer of
                at least 1, since UI sliders may deliver floats.
            progress_callback: Optional callable receiving the completion
                percentage (0-100) after each batch.

        Returns:
            A tuple ``(results, logs)``: one EvaluationResult per loaded
            image, and human-readable log lines.
        """
        logs: List[str] = []
        results: List[EvaluationResult] = []

        # range() below needs a positive int step.
        batch_size = max(1, int(batch_size))

        # Load images
        images = []
        valid_paths = []
        for path in file_paths:
            try:
                # convert() returns an independent copy, so the file handle
                # can be released right away instead of lingering until GC.
                with Image.open(path) as source:
                    img = source.convert("RGB")
            except Exception as e:
                logs.append(f"Failed to load {Path(path).name}: {e}")
                continue
            images.append(img)
            valid_paths.append(path)

        if not images:
            logs.append("No valid images to process")
            return results, logs

        logs.append(f"Loaded {len(images)} images")

        # Make skipped models visible instead of dropping them silently.
        for model_key in selected_models:
            if model_key not in self.models:
                logs.append(f"Skipping {model_key}: model is not loaded")

        # Process in batches
        total_batches = (len(images) + batch_size - 1) // batch_size

        for batch_idx in range(0, len(images), batch_size):
            batch_images = images[batch_idx:batch_idx + batch_size]
            batch_paths = valid_paths[batch_idx:batch_idx + batch_size]

            # Evaluate with each selected model
            batch_results = {}
            for model_key in selected_models:
                if model_key in self.models:
                    scores = await self.models[model_key].evaluate_batch(batch_images)
                    batch_results[model_key] = scores
                    logs.append(f"Processed batch {batch_idx//batch_size + 1}/{total_batches} with {self.models[model_key].name}")

            # Create results
            for i, path in enumerate(batch_paths):
                result = EvaluationResult(
                    file_name=Path(path).name,
                    image_path=path
                )

                for model_key in selected_models:
                    if model_key in batch_results:
                        result.scores[model_key] = batch_results[model_key][i]

                result.calculate_final_score(selected_models)
                results.append(result)

            # Update progress; the last batch may be short, so cap at 100.
            if progress_callback:
                progress = (batch_idx + batch_size) / len(images) * 100
                progress_callback(min(progress, 100))

        self.results = results
        return results, logs

    def get_results_dataframe(self, selected_models: List[str]) -> pd.DataFrame:
        """Convert results to pandas DataFrame

        Args:
            selected_models: Keys of the models whose scores become columns;
                keys without a loaded model are ignored.

        Returns:
            A DataFrame with 'File Name', 'Image', one column per selected
            model (named after the model) and 'Final Score'. Scores are
            formatted with four decimals, or "N/A" when missing. The frame
            is empty when there are no stored results.
        """
        if not self.results:
            return pd.DataFrame()

        data = []
        for result in self.results:
            row = {
                'File Name': result.file_name,
                'Image': result.image_path,
            }

            # Add model scores
            for model_key in selected_models:
                if model_key in self.models:
                    score = result.scores.get(model_key)
                    row[self.models[model_key].name] = f"{score:.4f}" if score is not None else "N/A"

            row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
            data.append(row)

        return pd.DataFrame(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
 
335
def create_interface():
    """Build the Gradio Blocks UI for the image evaluation tool.

    The layout has an input column (file upload, model selection, batch
    size, evaluate/clear buttons) and an output column (processing logs,
    results table, CSV download). Event handlers are defined as closures
    over a single ``ImageEvaluator`` instance.

    Returns:
        The assembled ``gr.Blocks`` app. It is not queued or launched here.
    """
    evaluator = ImageEvaluator()

    # (display label, model key) pairs: labels are shown in the checkbox
    # group, keys are what the evaluator is called with.
    model_options = [
        ("Aesthetic Shadow", "aesthetic_shadow"),
        ("Waifu Scorer", "waifu_scorer"),
        ("Aesthetic V2.5", "aesthetic_predictor_v2_5"),
        ("Anime Score", "anime_aesthetic"),
    ]
    model_labels = [label for label, _ in model_options]

    def labels_to_keys(selected_model_labels):
        """Map checked display labels to model keys, keeping option order."""
        chosen = set(selected_model_labels or [])
        return [key for label, key in model_options if label in chosen]

    with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
        gr.Markdown("""
        # 🎨 Advanced Image Evaluation Tool

        Evaluate images using state-of-the-art aesthetic and quality prediction models.
        Upload your images and select the models you want to use for evaluation.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_files = gr.File(
                    label="Upload Images",
                    file_count="multiple",
                    file_types=["image"]
                )

                model_checkboxes = gr.CheckboxGroup(
                    choices=model_labels,
                    value=model_labels,
                    label="Select Models",
                    info="Choose which models to use for evaluation"
                )

                with gr.Row():
                    batch_size = gr.Slider(
                        minimum=1,
                        maximum=64,
                        value=8,
                        step=1,
                        label="Batch Size",
                        info="Number of images to process at once"
                    )

                with gr.Row():
                    evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", scale=2)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)

            with gr.Column(scale=2):
                # NOTE: no gr.Progress() here. It is a tracker that Gradio
                # injects into handlers that declare it as a default
                # argument (see process_images), not a layout component.
                logs = gr.Textbox(
                    label="Processing Logs",
                    lines=10,
                    max_lines=10,
                    autoscroll=True
                )

                results_df = gr.Dataframe(
                    label="Evaluation Results",
                    interactive=False,
                    wrap=True
                )

                download_btn = gr.Button("📥 Download Results (CSV)", variant="secondary")
                # Hidden until a CSV has been prepared; prepare_download
                # flips it to visible together with the file path.
                download_file = gr.File(visible=False)

        # Per-session storage for the result objects of the last run.
        results_state = gr.State([])

        async def process_images(files, selected_model_labels, batch_size, progress=gr.Progress()):
            """Evaluate the uploaded images and return (log text, table, results)."""
            if not files:
                return "Please upload images first", pd.DataFrame(), []

            selected_models = labels_to_keys(selected_model_labels)
            if not selected_models:
                # Nothing to score with; avoid running the pipeline for
                # a table that could only contain "N/A".
                return "Please select at least one model", pd.DataFrame(), []

            # Uploads may arrive as file wrappers exposing ``.name`` or as
            # plain path strings depending on the Gradio version.
            file_paths = [getattr(f, "name", f) for f in files]

            def update_progress(value):
                # The evaluator reports 0-100; gr.Progress expects 0-1.
                progress(value / 100, desc=f"Processing images... {value:.0f}%")

            results, run_logs = await evaluator.evaluate_images(
                file_paths,
                selected_models,
                int(batch_size),  # sliders can deliver floats
                update_progress
            )

            df = evaluator.get_results_dataframe(selected_models)

            # The log box has 10 lines, so show only the most recent 10.
            log_text = "\n".join(run_logs[-10:])

            return log_text, df, results

        def update_results_on_model_change(selected_model_labels, results):
            """Recompute final scores and the table for a new model selection."""
            if not results:
                return pd.DataFrame()

            selected_models = labels_to_keys(selected_model_labels)

            for result in results:
                result.calculate_final_score(selected_models)

            # The evaluator builds the table from its own ``results``
            # attribute, so point it at this session's results first.
            evaluator.results = results

            return evaluator.get_results_dataframe(selected_models)

        def clear_interface():
            """Reset logs, table and stored results, and hide the download."""
            return "", pd.DataFrame(), [], gr.update(value=None, visible=False)

        def prepare_download(selected_model_labels, results):
            """Write the current table to a temporary CSV and expose it."""
            if not results:
                return gr.update(value=None, visible=False)

            # Same reason as in update_results_on_model_change: the table
            # is built from evaluator.results, which must match the session
            # results being exported.
            evaluator.results = results
            df = evaluator.get_results_dataframe(labels_to_keys(selected_model_labels))

            import tempfile
            # delete=False: the file must outlive this handler so Gradio
            # can serve it. newline='' lets the CSV writer control line
            # endings; utf-8 keeps non-ASCII file names intact.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.csv', delete=False, encoding='utf-8', newline=''
            ) as f:
                df.to_csv(f, index=False)

            return gr.update(value=f.name, visible=True)

        # Event handlers
        evaluate_btn.click(
            fn=process_images,
            inputs=[input_files, model_checkboxes, batch_size],
            outputs=[logs, results_df, results_state]
        )

        model_checkboxes.change(
            fn=update_results_on_model_change,
            inputs=[model_checkboxes, results_state],
            outputs=[results_df]
        )

        clear_btn.click(
            fn=clear_interface,
            outputs=[logs, results_df, results_state, download_file]
        )

        download_btn.click(
            fn=prepare_download,
            inputs=[model_checkboxes, results_state],
            outputs=[download_file]
        )

        gr.Markdown("""
        ### 📝 Notes
        - **Model Selection**: Choose which models to use for evaluation. Final score is the average of selected models.
        - **Batch Size**: Adjust based on your GPU memory. Larger batches process faster.
        - **Results Table**: Click column headers to sort. The table updates automatically when models are selected/deselected.
        - **Download**: Export results as CSV for further analysis.

        ### 🎯 Score Interpretation
        - **7-10**: High quality/aesthetic appeal
        - **5-7**: Medium quality
        - **0-5**: Lower quality
        """)

    return demo
513
 
514
 
515
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then
    # start serving.
    demo = create_interface()
    demo.queue()
    demo.launch()