Update app.py
app.py
CHANGED
@@ -1,824 +1,600 @@
-"""
-Modern Image Evaluation Tool with Aesthetic and Quality Prediction Models
-
-This refactored version features:
-- Modern async/await patterns with proper error handling
-- Type hints throughout for better code maintainability
-- Dependency injection and factory patterns
-- Proper resource management with context managers
-- Configuration-driven model loading
-- Improved batch processing with memory optimization
-- Clean separation of concerns with proper abstraction layers
-"""
-
-import asyncio
-import base64
-import csv
-import logging
 import os
 import tempfile
-from io import BytesIO, StringIO
 from pathlib import Path
-from typing import Dict, List, Optional, Protocol, Tuple, Union, Any
-from abc import ABC, abstractmethod

 import cv2
-import gradio as gr
 import numpy as np
-import onnxruntime as ort
 import torch
 from PIL import Image
 from transformers import pipeline
 from huggingface_hub import hf_hub_download

-logger = logging.getLogger(__name__)
-
-
-# =============================================================================
-# Configuration and Data Models
-# =============================================================================
-
-class ModelType(Enum):
-    """Enumeration of available model types."""
-    AESTHETIC_SHADOW = "aesthetic_shadow"
-    WAIFU_SCORER = "waifu_scorer"
-    AESTHETIC_PREDICTOR_V2_5 = "aesthetic_predictor_v2_5"
-    ANIME_AESTHETIC = "anime_aesthetic"
-
-
-@dataclass
-class ModelConfig:
-    """Configuration for individual models."""
-    name: str
-    display_name: str
-    enabled: bool = True
-    batch_supported: bool = True
-    model_path: Optional[str] = None
-    cache_dir: Optional[str] = None
-
-
-@dataclass
-class ProcessingConfig:
-    """Configuration for processing parameters."""
-    auto_batch: bool = False
-    manual_batch_size: int = 1
-    max_batch_size: int = 64
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    score_range: Tuple[float, float] = (0.0, 10.0)


 @dataclass
 class EvaluationResult:
-    """Data class for
     file_name: str
     final_score: Optional[float] = None
-    processing_time: float = 0.0
-    error: Optional[str] = None


-    results: List[EvaluationResult]
-    logs: List[str]
-    processing_time: float
-    batch_size_used: int
-    success_count: int
-    error_count: int
-
-
-# =============================================================================
-# Model Interfaces and Implementations
-# =============================================================================
-
-class BaseModel(Protocol):
-    """Protocol defining the interface for all evaluation models."""
-
-    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
-        """Predict scores for a batch of images."""
-        ...
-
-    def is_available(self) -> bool:
-        """Check if the model is available and ready for inference."""
-        ...
-
-    def cleanup(self) -> None:
-        """Clean up model resources."""
-        ...
-
-
-class ModernMLP(nn.Module):
-    """Modern implementation of MLP with improved architecture."""
-
-    def __init__(
-        self,
-        input_size: int,
-        hidden_dims: List[int] = None,
-        dropout_rates: List[float] = None,
-        use_batch_norm: bool = True,
-        activation: nn.Module = nn.ReLU
-    ):
         super().__init__()
-            layers.append(nn.BatchNorm1d(hidden_dim))
-            if dropout_rate > 0:
-                layers.append(nn.Dropout(dropout_rate))
-            prev_dim = hidden_dim
-
-        # Final output layer
-        layers.append(nn.Linear(prev_dim, 1))
-        self.network = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.network(x)

-class
-    def __init__(self,
-        self._available = False
-        self._model = None
-        self._clip_model = None
-        self._preprocess = None
-
-        self._initialize_model()

         try:
             import clip

-            model_path =
-
-            # Load weights
-            if model_path.endswith(".safetensors"):
-                from safetensors.torch import load_file
-                state_dict = load_file(model_path)
-            else:
-                state_dict = torch.load(model_path, map_location=self.device)

-            self._model.eval()
-
-            # Load CLIP model
-            self._clip_model, self._preprocess = clip.load("ViT-L/14", device=self.device)
-            self._available = True
-
-            logger.info(f"WaifuScorer model loaded successfully on {self.device}")

         except Exception as e:

-        return self.config.model_path
-
-        # Default download path
-        model_path = "Eugeoter/waifu-scorer-v3/model.pth"
-        username, repo_id, model_name = model_path.split("/")[-3:]
-        return hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=self.config.cache_dir)
-
-    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
-        """Predict scores for a batch of images."""
-        if not self._available:
-            return [None] * len(images)

         try:
-            # Extract features and predict
-            with torch.no_grad():
-                image_features = self._clip_model.encode_image(image_batch)
-                # Normalize features
-                norm = image_features.norm(2, dim=-1, keepdim=True)
-                norm[norm == 0] = 1
-                normalized_features = (image_features / norm).to(device=self.device, dtype=self.dtype)
-
-                predictions = self._model(normalized_features)
-                scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
-
-            return scores[:len(images)]

         except Exception as e:

-    def is_available(self) -> bool:
-        return self._available


-class
-    def __init__(self
-        self._available = False
-        self._model = None
-
-        self._initialize_model()

             return [None] * len(images)

         try:
-            results = self.
             scores = []
             for result in results:
                 try:
                     hq_score = next(p for p in result if p['label'] == 'hq')['score']
-                except (StopIteration, KeyError, TypeError):
                     scores.append(None)
             return scores
         except Exception as e:
             return [None] * len(images)


-class AestheticPredictorV25Model:
-    """Wrapper for Aesthetic Predictor V2.5 model."""
-
-    def __init__(self, config: ModelConfig, device: str):
-        self.config = config
-        self.device = device
-        self._available = False
-        self._model = None
-        self._preprocessor = None
-
-        self._initialize_model()
-
-    def _initialize_model(self) -> None:
-        """Initialize the model."""
         try:
         except Exception as e:

             return [None] * len(images)

         try:
             rgb_images = [img.convert("RGB") for img in images]
-            pixel_values =

             if torch.cuda.is_available():
                 pixel_values = pixel_values.to(torch.bfloat16).cuda()

             with torch.inference_mode():
-                scores =
-
-            return [float(np.round(np.clip(s, 0.0, 10.0), 4)) for s in scores]

         except Exception as e:
             return [None] * len(images)

-        if self._model is not None:
-            del self._model
-
-
-class AnimeAestheticModel:
-    """ONNX-based Anime Aesthetic model."""
-
-    def __init__(self, config: ModelConfig, device: str):
-        self.config = config
-        self.device = device
-        self._available = False
-        self._session = None
-
-        self._initialize_model()
-
-    def _initialize_model(self) -> None:
-        """Initialize the ONNX model."""
-        try:
-            model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
-            self._session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-            self._available = True
-            logger.info("Anime Aesthetic model loaded successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to initialize Anime Aesthetic: {e}")
-            self._available = False
-
-    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
-        """Predict scores for images (single image processing for ONNX)."""
-        if not self._available:
             return [None] * len(images)

         scores = []
         for img in images:
             try:
             except Exception as e:
                 scores.append(None)

         return scores

-        s = 768
-        h, w = img_np.shape[:2]
-
-        # Resize while maintaining aspect ratio
-        if h > w:
-            new_h, new_w = s, int(s * w / h)
-        else:
-            new_h, new_w = int(s * h / w), s
-
-        resized = cv2.resize(img_np, (new_w, new_h))
-
-        # Center crop/pad to square
-        canvas = np.zeros((s, s, 3), dtype=np.float32)
-        pad_h = (s - new_h) // 2
-        pad_w = (s - new_w) // 2
-        canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
-
-        # Prepare input
-        input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
-        return self._session.run(None, {"img": input_tensor})[0].item()
-
-    def is_available(self) -> bool:
-        return self._available
-
-    def cleanup(self) -> None:
-        if self._session is not None:
-            del self._session
-
-
-# =============================================================================
-# Model Factory and Manager
-# =============================================================================
-
-class ModelFactory:
-    """Factory for creating model instances."""
-
-    _MODEL_CLASSES = {
-        ModelType.AESTHETIC_SHADOW: AestheticShadowModel,
-        ModelType.WAIFU_SCORER: WaifuScorerModel,
-        ModelType.AESTHETIC_PREDICTOR_V2_5: AestheticPredictorV25Model,
-        ModelType.ANIME_AESTHETIC: AnimeAestheticModel,
-    }
-
-    @classmethod
-    def create_model(cls, model_type: ModelType, config: ModelConfig, device: str) -> BaseModel:
-        """Create a model instance based on type."""
-        model_class = cls._MODEL_CLASSES.get(model_type)
-        if not model_class:
-            raise ValueError(f"Unknown model type: {model_type}")
-
-        return model_class(config, device)
-
-
-class ModelManager:
-    """Advanced model manager with async processing and resource management."""
-
-    def __init__(self, processing_config: ProcessingConfig):
-        self.config = processing_config
-        self.models: Dict[ModelType, BaseModel] = {}
-        self.model_configs = self._create_default_configs()
-        self._processing_queue = asyncio.Queue()
-        self._worker_task: Optional[asyncio.Task] = None
-        self._temp_dir = Path(tempfile.mkdtemp())
-
-        self._initialize_models()
-
-    def _create_default_configs(self) -> Dict[ModelType, ModelConfig]:
-        """Create default model configurations."""
-        return {
-            ModelType.AESTHETIC_SHADOW: ModelConfig(
-                name="aesthetic_shadow",
-                display_name="Aesthetic Shadow"
-            ),
-            ModelType.WAIFU_SCORER: ModelConfig(
-                name="waifu_scorer",
-                display_name="Waifu Scorer"
-            ),
-            ModelType.AESTHETIC_PREDICTOR_V2_5: ModelConfig(
-                name="aesthetic_predictor_v2_5",
-                display_name="Aesthetic V2.5"
-            ),
-            ModelType.ANIME_AESTHETIC: ModelConfig(
-                name="anime_aesthetic",
-                display_name="Anime Score",
-                batch_supported=False
-            ),
-        }
-
-    def _initialize_models(self) -> None:
-        """Initialize all models."""
-        logger.info("Initializing models...")
-
-        for model_type, config in self.model_configs.items():
-            if config.enabled:
-                try:
-                    model = ModelFactory.create_model(model_type, config, self.config.device)
-                    if model.is_available():
-                        self.models[model_type] = model
-                        logger.info(f"{config.display_name} loaded successfully")
-                    else:
-                        logger.warning(f"{config.display_name} failed to load")
-                except Exception as e:
-                    logger.error(f"{config.display_name} initialization error: {e}")
-
-        logger.info(f"Initialized {len(self.models)} models successfully")
-
-    async def start_worker(self) -> None:
-        """Start the background processing worker."""
-        if self._worker_task is None:
-            self._worker_task = asyncio.create_task(self._worker_loop())
-            logger.info("Background worker started")
-
-    async def _worker_loop(self) -> None:
-        """Main worker loop for processing requests."""
-        while True:
-            request = await self._processing_queue.get()
-
-            if request is None:  # Shutdown signal
-                break
-
-            try:
-                result = await self._process_request(request)
-                request['future'].set_result(result)
-            except Exception as e:
-                request['future'].set_exception(e)
-            finally:
-                self._processing_queue.task_done()
-
-    async def process_images(
-        self,
-        file_paths: List[str],
-        selected_models: List[ModelType],
-        auto_batch: bool = False,
-        manual_batch_size: int = 1
-    ) -> BatchResult:
-        """Process images with selected models."""
-        future = asyncio.Future()
-        request = {
-            'file_paths': file_paths,
-            'selected_models': selected_models,
-            'auto_batch': auto_batch,
-            'manual_batch_size': manual_batch_size,
-            'future': future
-        }
-
-        await self._processing_queue.put(request)
-        return await future
-
-    async def _process_request(self, request: Dict) -> BatchResult:
-        """Process a single batch request."""
-        start_time = asyncio.get_event_loop().time()
-        logs = []
-        results = []
-
-        file_paths = request['file_paths']
-        selected_models = request['selected_models']
-        auto_batch = request['auto_batch']
-        manual_batch_size = request['manual_batch_size']
-
-        # Load images
-        images, valid_paths = await self._load_images(file_paths, logs)
-
-        if not images:
-            return BatchResult([], logs, 0.0, 0, 0, len(file_paths))
-
-        # Determine batch size
-        batch_size = await self._determine_batch_size(images, auto_batch, manual_batch_size, logs)
-
-        # Process in batches
-        for i in range(0, len(images), batch_size):
-            batch_images = images[i:i+batch_size]
-            batch_paths = valid_paths[i:i+batch_size]
-
-            batch_results = await self._process_batch(batch_images, batch_paths, selected_models, logs)
-            results.extend(batch_results)
-
-        processing_time = asyncio.get_event_loop().time() - start_time
-        success_count = sum(1 for r in results if r.error is None)
-        error_count = len(results) - success_count
-
-        return BatchResult(
-            results=results,
-            logs=logs,
-            processing_time=processing_time,
-            batch_size_used=batch_size,
-            success_count=success_count,
-            error_count=error_count
-        )
-
-    async def _load_images(self, file_paths: List[str], logs: List[str]) -> Tuple[List[Image.Image], List[str]]:
-        """Load and validate images."""
-        images = []
-        valid_paths = []

-            try:
-                img = Image.open(path).convert("RGB")
-                images.append(img)
-                valid_paths.append(path)
-            except Exception as e:
-                logs.append(f"Failed to load {path}: {e}")
-
-        logs.append(f"Successfully loaded {len(images)} images")
-        return images, valid_paths

         batch_size = 1

-        while batch_size <=
             try:
                 batch_size *= 2
             except Exception:
                 break

-            model_predictions[model_type.value] = [None] * len(images)
-
-        # Create results
-        for i, (image, path) in enumerate(zip(images, paths)):
-            # Collect scores for this image
-            scores = {}
-            valid_scores = []

-            final_score = float(np.clip(final_score, *self.config.score_range))

-                thumbnail_b64=thumbnail_b64,
-                model_scores=scores,
-                final_score=final_score
-            )

-    def get_available_models(self) -> Dict[ModelType, str]:
-        """Get available models with their display names."""
-        return {
-            model_type: self.model_configs[model_type].display_name
-            for model_type in self.models.keys()
-        }
-
-    async def cleanup(self) -> None:
-        """Clean up resources."""
-        # Shutdown worker
-        if self._worker_task:
-            await self._processing_queue.put(None)
-            await self._worker_task
-
-        # Clean up models
-        for model in self.models.values():
-            model.cleanup()
-
-        # Clean up temp directory
-        if self._temp_dir.exists():
-            shutil.rmtree(self._temp_dir)
-
-        logger.info("Model manager cleanup completed")
-
-
-# =============================================================================
-# Results Processing and Export
-# =============================================================================
-
-class ResultsProcessor:
-    """Handle result processing, sorting, and export functionality."""
-
-    @staticmethod
-    def sort_results(results: List[EvaluationResult], sort_by: str, reverse: bool = True) -> List[EvaluationResult]:
-        """Sort results by specified criteria."""
-        sort_key_map = {
-            "Final Score": lambda r: r.final_score if r.final_score is not None else -float('inf'),
-            "File Name": lambda r: r.file_name.lower(),
-            **{f"model_{model_type.value}": lambda r, mt=model_type.value: r.model_scores.get(mt) or -float('inf')
-               for model_type in ModelType}
-        }
-
-        sort_key = sort_key_map.get(sort_by, sort_key_map["Final Score"])
-        return sorted(results, key=sort_key, reverse=reverse and sort_by != "File Name")
-
-    @staticmethod
-    def generate_html_table(results: List[EvaluationResult], selected_models: List[ModelType]) -> str:
-        """Generate HTML table for results display."""
-        if not results:
-            return "<p>No results to display</p>"
-
-        # CSS styles
-        styles = """
-        <style>
-        .results-table {
-            width: 100%; border-collapse: collapse; margin: 20px 0;
-            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-        }
-        .results-table th, .results-table td {
-            border: 1px solid #ddd; padding: 12px; text-align: center;
-        }
-        .results-table th {
-            background-color: #f8f9fa; font-weight: 600; color: #495057;
-        }
-        .results-table tr:nth-child(even) { background-color: #f8f9fa; }
-        .results-table tr:hover { background-color: #e9ecef; }
-        .image-preview {
-            max-width: 120px; max-height: 120px; border-radius: 8px;
-            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
-        }
-        .score-excellent { color: #28a745; font-weight: bold; }
-        .score-good { color: #ffc107; font-weight: bold; }
-        .score-poor { color: #dc3545; font-weight: bold; }
-        .score-na { color: #6c757d; font-style: italic; }
-        </style>
-        """

-            # Final score
-            html += ResultsProcessor._format_score_cell(result.final_score)
-            html += '</tr>'
 import os
 import tempfile
+import base64
+from io import BytesIO
+from typing import List, Dict, Any, Optional, Tuple
+from dataclasses import dataclass
 from pathlib import Path

 import cv2
 import numpy as np
 import torch
+import onnxruntime as rt
 from PIL import Image
+import gradio as gr
+import pandas as pd
 from transformers import pipeline
 from huggingface_hub import hf_hub_download

+# Import necessary function from aesthetic_predictor_v2_5
+from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip


 @dataclass
 class EvaluationResult:
+    """Data class for storing image evaluation results."""
     file_name: str
+    image: Image.Image
+    aesthetic_shadow: Optional[float] = None
+    waifu_scorer: Optional[float] = None
+    aesthetic_v2_5: Optional[float] = None
+    anime_aesthetic: Optional[float] = None
     final_score: Optional[float] = None


+class MLP(torch.nn.Module):
+    """Optimized MLP for image feature regression."""
+    def __init__(self, input_size: int = 768):
         super().__init__()
+        self.network = torch.nn.Sequential(
+            torch.nn.Linear(input_size, 1024),
+            torch.nn.ReLU(),
+            torch.nn.BatchNorm1d(1024),
+            torch.nn.Dropout(0.2),
+            torch.nn.Linear(1024, 256),
+            torch.nn.ReLU(),
+            torch.nn.BatchNorm1d(256),
+            torch.nn.Dropout(0.1),
+            torch.nn.Linear(256, 64),
+            torch.nn.ReLU(),
+            torch.nn.Linear(64, 1)
+        )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.network(x)


+class ModelLoader:
+    """Centralized model loading and management."""

+    def __init__(self, device: str = None):
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        self.models = {}
+        self._load_all_models()

+    def _load_all_models(self):
+        """Load all models during initialization."""
+        try:
+            self._load_aesthetic_shadow()
+            self._load_waifu_scorer()
+            self._load_aesthetic_v2_5()
+            self._load_anime_aesthetic()
+            print("All models loaded successfully!")
+        except Exception as e:
+            print(f"Error loading models: {e}")
+
+    def _load_aesthetic_shadow(self):
+        """Load Aesthetic Shadow model."""
+        print("Loading Aesthetic Shadow...")
+        self.models['aesthetic_shadow'] = pipeline(
+            "image-classification",
+            model="NeoChen1024/aesthetic-shadow-v2-backup",
+            device=self.device
+        )
+
+    def _load_waifu_scorer(self):
+        """Load Waifu Scorer model."""
+        print("Loading Waifu Scorer...")
         try:
             import clip

+            # Load MLP
+            model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
+            mlp = MLP()
+            state_dict = torch.load(model_path, map_location=self.device)
+            mlp.load_state_dict(state_dict)
+            mlp.to(self.device).eval()

+            # Load CLIP
+            clip_model, preprocess = clip.load("ViT-L/14", device=self.device)

+            self.models['waifu_scorer'] = {
+                'mlp': mlp,
+                'clip_model': clip_model,
+                'preprocess': preprocess
+            }
         except Exception as e:
+            print(f"Waifu Scorer not available: {e}")
+            self.models['waifu_scorer'] = None

+    def _load_aesthetic_v2_5(self):
+        """Load Aesthetic Predictor V2.5."""
+        print("Loading Aesthetic V2.5...")
         try:
+            model, preprocessor = convert_v2_5_from_siglip(
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            )
+            if torch.cuda.is_available():
+                model = model.to(torch.bfloat16).cuda()

+            self.models['aesthetic_v2_5'] = {
+                'model': model,
+                'preprocessor': preprocessor
+            }
         except Exception as e:
+            print(f"Aesthetic V2.5 not available: {e}")
+            self.models['aesthetic_v2_5'] = None

+    def _load_anime_aesthetic(self):
+        """Load Anime Aesthetic model."""
+        print("Loading Anime Aesthetic...")
+        try:
+            model_path = hf_hub_download("skytnt/anime-aesthetic", "model.onnx")
+            self.models['anime_aesthetic'] = rt.InferenceSession(
+                model_path,
+                providers=['CPUExecutionProvider']
+            )
+        except Exception as e:
+            print(f"Anime Aesthetic not available: {e}")
+            self.models['anime_aesthetic'] = None


+class ImageEvaluator:
+    """Main image evaluation class with batch processing."""

+    def __init__(self):
+        self.loader = ModelLoader()
+        self.temp_dir = Path(tempfile.mkdtemp())

+    def evaluate_images(
+        self,
+        images: List[Image.Image],
+        file_names: List[str],
+        selected_models: List[str],
+        batch_size: int = 4,
+        progress_callback=None
+    ) -> List[EvaluationResult]:
+        """Evaluate images using selected models."""
+        results = []
+        total_batches = (len(images) + batch_size - 1) // batch_size
+
+        for batch_idx in range(0, len(images), batch_size):
+            batch_images = images[batch_idx:batch_idx + batch_size]
+            batch_names = file_names[batch_idx:batch_idx + batch_size]

+            # Update progress
+            if progress_callback:
+                progress = (batch_idx // batch_size + 1) / total_batches
+                progress_callback(progress, f"Processing batch {batch_idx//batch_size + 1}/{total_batches}")
+
+            # Process batch
+            batch_results = self._process_batch(batch_images, batch_names, selected_models)
+            results.extend(batch_results)
+
+        return results

+    def _process_batch(
+        self,
+        images: List[Image.Image],
+        file_names: List[str],
+        selected_models: List[str]
+    ) -> List[EvaluationResult]:
+        """Process a single batch of images."""
+        batch_results = []
+
+        # Initialize results
+        for i, (img, name) in enumerate(zip(images, file_names)):
+            result = EvaluationResult(file_name=name, image=img)
+            batch_results.append(result)
+
+        # Process each selected model
+        if 'aesthetic_shadow' in selected_models:
+            scores = self._eval_aesthetic_shadow(images)
+            for result, score in zip(batch_results, scores):
+                result.aesthetic_shadow = score
+
+        if 'waifu_scorer' in selected_models:
+            scores = self._eval_waifu_scorer(images)
+            for result, score in zip(batch_results, scores):
+                result.waifu_scorer = score
+
+        if 'aesthetic_v2_5' in selected_models:
+            scores = self._eval_aesthetic_v2_5(images)
+            for result, score in zip(batch_results, scores):
+                result.aesthetic_v2_5 = score
+
+        if 'anime_aesthetic' in selected_models:
+            scores = self._eval_anime_aesthetic(images)
+            for result, score in zip(batch_results, scores):
+                result.anime_aesthetic = score
+
+        # Calculate final scores
+        for result in batch_results:
+            result.final_score = self._calculate_final_score(result, selected_models)
+
+        return batch_results
+
+    def _eval_aesthetic_shadow(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Evaluate using Aesthetic Shadow model."""
+        if not self.loader.models.get('aesthetic_shadow'):
             return [None] * len(images)

         try:
+            results = self.loader.models['aesthetic_shadow'](images)
             scores = []
             for result in results:
                 try:
                     hq_score = next(p for p in result if p['label'] == 'hq')['score']
+                    scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
+                except:
                     scores.append(None)
             return scores
         except Exception as e:
+            print(f"Error in Aesthetic Shadow: {e}")
             return [None] * len(images)

+    def _eval_waifu_scorer(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Evaluate using Waifu Scorer model."""
+        model_dict = self.loader.models.get('waifu_scorer')
+        if not model_dict:
+            return [None] * len(images)
+
         try:
+            with torch.no_grad():
+                # Preprocess images
+                image_tensors = [model_dict['preprocess'](img).unsqueeze(0) for img in images]
+                if len(image_tensors) == 1:
+                    image_tensors = image_tensors * 2  # CLIP requirement
+
+                image_batch = torch.cat(image_tensors).to(self.loader.device)
+                image_features = model_dict['clip_model'].encode_image(image_batch)
+
+                # Normalize features
+                norm = image_features.norm(2, dim=-1, keepdim=True)
+                norm[norm == 0] = 1
+                im_emb = (image_features / norm).to(self.loader.device)
+
+                predictions = model_dict['mlp'](im_emb)
+                scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
+
+            return scores[:len(images)]
         except Exception as e:
+            print(f"Error in Waifu Scorer: {e}")
+            return [None] * len(images)

+    def _eval_aesthetic_v2_5(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Evaluate using Aesthetic Predictor V2.5."""
+        model_dict = self.loader.models.get('aesthetic_v2_5')
+        if not model_dict:
             return [None] * len(images)

         try:
             rgb_images = [img.convert("RGB") for img in images]
+            pixel_values = model_dict['preprocessor'](images=rgb_images, return_tensors="pt").pixel_values

             if torch.cuda.is_available():
                 pixel_values = pixel_values.to(torch.bfloat16).cuda()

             with torch.inference_mode():
+                scores = model_dict['model'](pixel_values).logits.squeeze().float().cpu().numpy()
+            if scores.ndim == 0:
+                scores = np.array([scores])
+
+            return [float(np.clip(s, 0.0, 10.0)) for s in scores.tolist()]
         except Exception as e:
+            print(f"Error in Aesthetic V2.5: {e}")
             return [None] * len(images)

+    def _eval_anime_aesthetic(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Evaluate using Anime Aesthetic model."""
+        model = self.loader.models.get('anime_aesthetic')
+        if not model:
             return [None] * len(images)

         scores = []
         for img in images:
             try:
+                # Preprocess image
+                img_np = np.array(img).astype(np.float32) / 255.0
+                h, w = img_np.shape[:2]
+                s = 768
+
+                if h > w:
+                    new_h, new_w = s, int(s * w / h)
+                else:
+                    new_h, new_w = int(s * h / w), s
+
+                resized = cv2.resize(img_np, (new_w, new_h))
+                canvas = np.zeros((s, s, 3), dtype=np.float32)
+
+                pad_h = (s - new_h) // 2
+                pad_w = (s - new_w) // 2
+                canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
+
+                input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
+                pred = model.run(None, {"img": input_tensor})[0].item()
+                scores.append(float(np.clip(pred * 10.0, 0.0, 10.0)))
             except Exception as e:
+                print(f"Error processing image: {e}")
                 scores.append(None)

         return scores

+    def _calculate_final_score(self, result: EvaluationResult, selected_models: List[str]) -> Optional[float]:
+        """Calculate final score from selected model results."""
+        scores = []

+        for model in selected_models:
+            score = getattr(result, model, None)
+            if score is not None:
+                scores.append(score)

+        return float(np.mean(scores)) if scores else None

+    def results_to_dataframe(self, results: List[EvaluationResult]) -> pd.DataFrame:
+        """Convert results to pandas DataFrame."""
+        data = []
+        for result in results:
+            row = {
+                'File Name': result.file_name,
+                'Final Score': result.final_score,
+            }
+            if result.aesthetic_shadow is not None:
+                row['Aesthetic Shadow'] = result.aesthetic_shadow
+            if result.waifu_scorer is not None:
+                row['Waifu Scorer'] = result.waifu_scorer
+            if result.aesthetic_v2_5 is not None:
+                row['Aesthetic V2.5'] = result.aesthetic_v2_5
+            if result.anime_aesthetic is not None:
+                row['Anime Aesthetic'] = result.anime_aesthetic
+            data.append(row)
+
+        return pd.DataFrame(data)
+
+    def optimize_batch_size(self, sample_images: List[Image.Image]) -> int:
+        """Automatically determine optimal batch size."""
+        if not sample_images:
+            return 1
+
+        test_image = sample_images[0]
         batch_size = 1
+        max_test = min(16, len(sample_images))

+        while batch_size <= max_test:
             try:
+                test_batch = [test_image] * batch_size
+                # Test with a lightweight model
+                if self.loader.models.get('aesthetic_shadow'):
+                    _ = self.loader.models['aesthetic_shadow'](test_batch)
                 batch_size *= 2
             except Exception:
                 break

+        optimal = max(1, batch_size // 2)
+        return min(optimal, 8)  # Cap at reasonable size
+
+
+def create_interface():
+    """Create the Gradio interface."""
+    evaluator = ImageEvaluator()
+
+    # Available models
+    model_choices = [
+        ("Aesthetic Shadow", "aesthetic_shadow"),
+        ("Waifu Scorer", "waifu_scorer"),
+        ("Aesthetic V2.5", "aesthetic_v2_5"),
+        ("Anime Aesthetic", "anime_aesthetic")
+    ]
+    available_models = [choice[1] for choice in model_choices]
+
+    with gr.Blocks(title="Image Evaluation Tool", theme=gr.themes.Soft()) as app:
+        gr.Markdown("""
+        # Modern Image Evaluation Tool
+
+        Upload images to evaluate them using state-of-the-art aesthetic and quality prediction models.
+
+        **Features:**
+        - Multiple AI models for comprehensive evaluation
+        - Batch processing with automatic optimization
+        - Interactive results table with sorting and filtering
+        - CSV export functionality
+        - Real-time progress tracking
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                # Input components
+                input_files = gr.File(
+                    label="Upload Images",
+                    file_count="multiple",
+                    file_types=["image"]
+                )
+
+                model_selection = gr.CheckboxGroup(
+                    choices=model_choices,
+                    value=available_models,
+                    label="Select Models",
+                    info="Choose which models to use for evaluation"
+                )
+
+                with gr.Row():
+                    auto_batch = gr.Checkbox(
+                        label="Auto Batch Size",
+                        value=True,
+                        info="Automatically optimize batch size"
+                    )
+
+                    manual_batch = gr.Slider(
+                        minimum=1,
+                        maximum=16,
+                        value=4,
+                        step=1,
+                        label="Batch Size",
+                        interactive=False,
+                        info="Manual batch size (when auto is disabled)"
+                    )
+
+                evaluate_btn = gr.Button(
+                    "Evaluate Images",
+                    variant="primary",
+                    size="lg"
+                )
+
+                clear_btn = gr.Button("Clear Results", variant="secondary")
+
+            with gr.Column(scale=2):
+                # Progress and status
+                progress_bar = gr.Progress()
+                status_text = gr.Textbox(
+                    label="Status",
+                    interactive=False,
+                    max_lines=2
+                )
+
+                # Results display
+                results_table = gr.DataFrame(
+                    label="Evaluation Results",
+                    interactive=False,
+                    wrap=True,
+                    max_height=400
+                )
+
+                # Export functionality
+                with gr.Row():
+                    export_csv = gr.Button("Export CSV", variant="secondary")
+                    download_file = gr.File(
+                        label="Download",
+                        visible=False
+                    )

+        # State management
+        results_state = gr.State([])
+
+        # Event handlers
+        def toggle_batch_slider(auto_enabled):
+            return gr.update(interactive=not auto_enabled)
+
+        def process_images(files, models, auto_batch_enabled, manual_batch_size, progress=gr.Progress()):
+            if not files or not models:
+                return "Please upload images and select at least one model", pd.DataFrame(), []

+            try:
+                # Load images
+                images = []
+                file_names = []
+
+                progress(0.1, "Loading images...")
+
+                for file in files:
+                    try:
+                        img = Image.open(file.name).convert("RGB")
+                        images.append(img)
+                        file_names.append(os.path.basename(file.name))
+                    except Exception as e:
+                        print(f"Error loading {file.name}: {e}")
+
+                if not images:
+                    return "No valid images loaded", pd.DataFrame(), []
+
+                # Determine batch size
+                if auto_batch_enabled:
+                    batch_size = evaluator.optimize_batch_size(images[:2])
+                    progress(0.2, f"Optimized batch size: {batch_size}")
+                else:
+                    batch_size = int(manual_batch_size)
+
+                # Process images
+                def progress_callback(prog, msg):
+                    progress(0.2 + prog * 0.7, msg)
+
+                results = evaluator.evaluate_images(
+                    images, file_names, models, batch_size, progress_callback
+                )
+
+                progress(0.95, "Generating results table...")
+
+                # Convert to DataFrame
+                df = evaluator.results_to_dataframe(results)
+                df = df.sort_values('Final Score', ascending=False, na_position='last')
+
+                progress(1.0, f"Processed {len(results)} images successfully!")
+
+                return f"Evaluated {len(results)} images using {len(models)} models", df, results
+
+            except Exception as e:
+                return f"Error during processing: {str(e)}", pd.DataFrame(), []
+
+        def update_results_table(models, current_results):
+            if not current_results:
+                return pd.DataFrame()

+            # Recalculate final scores based on selected models
+            for result in current_results:
+                result.final_score = evaluator._calculate_final_score(result, models)

+            df = evaluator.results_to_dataframe(current_results)
+            return df.sort_values('Final Score', ascending=False, na_position='last')
+
+        def export_results(current_results):
+            if not current_results:
+                return gr.update(visible=False)

+            df = evaluator.results_to_dataframe(current_results)
+            csv_path = evaluator.temp_dir / "evaluation_results.csv"
+            df.to_csv(csv_path, index=False)

+            return gr.update(value=str(csv_path), visible=True)

+        def clear_all():
+            return (
+                "Ready for new evaluation",
+                pd.DataFrame(),
+                [],
+                gr.update(visible=False)
+            )

+        # Wire up events
+        auto_batch.change(
+            toggle_batch_slider,
+            inputs=[auto_batch],
+            outputs=[manual_batch]
+        )

+        evaluate_btn.click(
+            process_images,
+            inputs=[input_files, model_selection, auto_batch, manual_batch],
+            outputs=[status_text, results_table, results_state]
+        )

+        model_selection.change(
+            update_results_table,
+            inputs=[model_selection, results_state],
+            outputs=[results_table]
+        )

+        export_csv.click(
+            export_results,
+            inputs=[results_state],
+            outputs=[download_file]
+        )
+
+        clear_btn.click(
+            clear_all,
+            outputs=[status_text, results_table, results_state, download_file]
+        )

+        # Initial setup
+        app.load(lambda: "Ready for evaluation - Upload images to get started!")
+
+    return app
+
+
+if __name__ == "__main__":
+    app = create_interface()
+    app.queue(max_size=10).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True
+    )
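
A minimal usage sketch of the refactored API above, for driving the evaluator outside the Gradio UI. It assumes it runs in the same module as ImageEvaluator; the image paths are placeholders.

import os
from PIL import Image

evaluator = ImageEvaluator()  # loads every model once at construction

paths = ["example_1.png", "example_2.png"]  # placeholder file names
images = [Image.open(p).convert("RGB") for p in paths]
names = [os.path.basename(p) for p in paths]

# Probe a workable batch size, then score with two of the four models.
batch_size = evaluator.optimize_batch_size(images[:2])
results = evaluator.evaluate_images(
    images, names,
    selected_models=["aesthetic_shadow", "waifu_scorer"],
    batch_size=batch_size,
)

# Final Score is the mean of the selected models' scores for each image.
df = evaluator.results_to_dataframe(results)
print(df.sort_values("Final Score", ascending=False, na_position="last"))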