VOIDER committed on
Commit 027d32e · verified · 1 Parent(s): 27c67d8

Update app.py

Files changed (1)
  1. app.py +772 -658
app.py CHANGED
@@ -1,710 +1,824 @@
  import os
- import shutil
  import tempfile
- import base64
- import asyncio
- from io import BytesIO

  import cv2
  import numpy as np
  import torch
- import onnxruntime as rt
  from PIL import Image
- import gradio as gr
  from transformers import pipeline
  from huggingface_hub import hf_hub_download

- # Import necessary function from aesthetic_predictor_v2_5
- from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
-
-
- #####################################
- # Model Definitions #
- #####################################
-
- class MLP(torch.nn.Module):
-     """A simple multi-layer perceptron for image feature regression."""
-     def __init__(self, input_size: int, batch_norm: bool = True):
          super().__init__()
-         self.input_size = input_size
-         self.layers = torch.nn.Sequential(
-             torch.nn.Linear(self.input_size, 2048),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(2048, 512),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(512, 256),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(),
-             torch.nn.Dropout(0.2),
-             torch.nn.Linear(256, 128),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(),
-             torch.nn.Dropout(0.1),
-             torch.nn.Linear(128, 32),
-             torch.nn.ReLU(),
-             torch.nn.Linear(32, 1)
-         )
-
      def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.layers(x)


- class WaifuScorer:
-     """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""
-     def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
-         self.verbose = verbose
          self.device = device
          self.dtype = torch.float32
-         self.available = False
-
          try:
-             import clip  # local import to avoid dependency issues
-             # Set default model path if not provided
-             if model_path is None:
-                 model_path = "Eugeoter/waifu-scorer-v3/model.pth"
-                 if self.verbose:
-                     print(f"Model path not provided. Using default: {model_path}")
-
-             # Download model if not found locally
-             if not os.path.isfile(model_path):
-                 username, repo_id, model_name = model_path.split("/")[-3:]
-                 model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)
-
-             if self.verbose:
-                 print(f"Loading WaifuScorer model from: {model_path}")
-
-             # Initialize MLP model
-             self.mlp = MLP(input_size=768)
-             # Load state dict
              if model_path.endswith(".safetensors"):
                  from safetensors.torch import load_file
                  state_dict = load_file(model_path)
              else:
-                 state_dict = torch.load(model_path, map_location=device)
-             self.mlp.load_state_dict(state_dict)
-             self.mlp.to(device)
-             self.mlp.eval()
-
-             # Load CLIP model for image preprocessing and feature extraction
-             self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device)
-             self.available = True
          except Exception as e:
-             print(f"Unable to initialize WaifuScorer: {e}")
-
-     @torch.no_grad()
-     def __call__(self, images):
-         if not self.available:
-             return [None] * (len(images) if isinstance(images, list) else 1)
-         if isinstance(images, Image.Image):
-             images = [images]
-         n = len(images)
-         # Ensure at least two images for CLIP model compatibility
-         if n == 1:
-             images = images * 2
-
-         image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
-         image_batch = torch.cat(image_tensors).to(self.device)
-         image_features = self.clip_model.encode_image(image_batch)
-         # Normalize features
-         norm = image_features.norm(2, dim=-1, keepdim=True)
-         norm[norm == 0] = 1
-         im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)
-         predictions = self.mlp(im_emb)
-         scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
-         return scores[:n]
-
-
- #####################################
- # Aesthetic Predictor Functions #
- #####################################
-
- def load_aesthetic_predictor_v2_5():
-     """Load and return an instance of Aesthetic Predictor V2.5 with batch processing support."""
-     class AestheticPredictorV2_5_Impl:
-         def __init__(self):
-             print("Loading Aesthetic Predictor V2.5...")
-             self.model, self.preprocessor = convert_v2_5_from_siglip(
                  low_cpu_mem_usage=True,
                  trust_remote_code=True,
              )
              if torch.cuda.is_available():
-                 self.model = self.model.to(torch.bfloat16).cuda()
-
-         def inference(self, image):
-             if isinstance(image, list):
-                 images_rgb = [img.convert("RGB") for img in image]
-                 pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
-                 if torch.cuda.is_available():
-                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
-                 with torch.inference_mode():
-                     scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
-                 if scores.ndim == 0:
-                     scores = np.array([scores])
-                 return scores.tolist()
-             else:
-                 pixel_values = self.preprocessor(images=image.convert("RGB"), return_tensors="pt").pixel_values
-                 if torch.cuda.is_available():
-                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
-                 with torch.inference_mode():
-                     score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
-                 return score
-
-     return AestheticPredictorV2_5_Impl()
-
-
- def load_anime_aesthetic_model():
-     """Load and return the Anime Aesthetic ONNX model."""
-     model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
-     return rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-
-
- def predict_anime_aesthetic(img, model):
-     """Predict Anime Aesthetic score for a single image."""
-     img_np = np.array(img).astype(np.float32) / 255.0
-     s = 768
-     h, w = img_np.shape[:2]
-     if h > w:
-         new_h, new_w = s, int(s * w / h)
-     else:
-         new_h, new_w = int(s * h / w), s
-     resized = cv2.resize(img_np, (new_w, new_h))
-     # Center the resized image in a square canvas
-     canvas = np.zeros((s, s, 3), dtype=np.float32)
-     pad_h = (s - new_h) // 2
-     pad_w = (s - new_w) // 2
-     canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
-     # Prepare input for model
-     input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
-     pred = model.run(None, {"img": input_tensor})[0].item()
-     return pred
-
-
- #####################################
- # Image Evaluation Tool #
- #####################################

  class ModelManager:
-     """Manages model loading and processing requests using a queue."""
-     def __init__(self):
-         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         print(f"Using device: {self.device}")
-         print("Loading models... This may take some time.")
-
-         # Load models once during initialization
-         print("Loading Aesthetic Shadow model...")
-         self.aesthetic_shadow_model = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
-         print("Loading Waifu Scorer model...")
-         self.waifu_scorer_model = WaifuScorer(device=self.device, verbose=True)
-         print("Loading Aesthetic Predictor V2.5...")
-         self.aesthetic_predictor_model = load_aesthetic_predictor_v2_5()
-         print("Loading Anime Aesthetic model...")
-         self.anime_aesthetic_model = load_anime_aesthetic_model()
-         print("All models loaded successfully!")
-
-         self.available_models = {
-             "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow, "model": self.aesthetic_shadow_model},
-             "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer, "model": self.waifu_scorer_model},
-             "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5, "model": self.aesthetic_predictor_model},
-             "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic, "model": self.anime_aesthetic_model},
          }
-         self.processing_queue: asyncio.Queue = asyncio.Queue()
-         self.worker_task = None  # Initialize worker_task to None
-         self.temp_dir = tempfile.mkdtemp()
-
-     async def start_worker(self):
-         """Start the background worker task."""
-         if self.worker_task is None:
-             self.worker_task = asyncio.create_task(self._worker())
-
-     async def _worker(self):
-         """Background worker to process image evaluation requests from the queue."""
          while True:
-             request = await self.processing_queue.get()
-             if request is None:  # Shutdown signal
-                 self.processing_queue.task_done()
                  break
              try:
-                 results = await self._process_request(request)
-                 request['results_future'].set_result(results)  # Fulfill the future with results
              except Exception as e:
-                 request['results_future'].set_exception(e)  # Set exception if processing fails
              finally:
-                 self.processing_queue.task_done()
-
-     async def submit_request(self, request_data):
-         """Submit a new image processing request to the queue."""
-         results_future = asyncio.Future()  # Future to hold the results
-         request = {**request_data, 'results_future': results_future}
-         await self.processing_queue.put(request)
-         return await results_future  # Wait for and return results
-
-     async def _process_request(self, request):
-         """Process a single image evaluation request."""
          file_paths = request['file_paths']
          auto_batch = request['auto_batch']
          manual_batch_size = request['manual_batch_size']
-         selected_models = request['selected_models']
-         log_events = []
          images = []
-         file_names = []
-         final_results = []
-
-         # Prepare images and file names
-         total_files = len(file_paths)
-         log_events.append(f"Starting to load {total_files} images...")
-         for f in file_paths:
              try:
-                 img = Image.open(f).convert("RGB")
                  images.append(img)
-                 file_names.append(os.path.basename(f))
              except Exception as e:
-                 log_events.append(f"Error opening {f}: {e}")
-
-         if not images:
-             log_events.append("No valid images loaded.")
-             return [], log_events, 0, manual_batch_size
-
-         log_events.append("Images loaded. Determining batch size...")
-
-         try:
-             manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
-         except ValueError:
-             manual_batch_size = 1
-             log_events.append("Invalid manual batch size. Defaulting to 1.")
-
-         optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
-         log_events.append(f"Using batch size: {optimal_batch}")
-
-         total_images = len(images)
-         for i in range(0, total_images, optimal_batch):
-             batch_images = images[i:i+optimal_batch]
-             batch_file_names = file_names[i:i+optimal_batch]
-             batch_index = i // optimal_batch + 1
-             log_events.append(f"Processing batch {batch_index}: images {i+1} to {min(i+optimal_batch, total_images)}")
-
-             batch_results = {}
-
-             # Process selected models
-             for model_key in selected_models:
-                 if self.available_models[model_key]['selected']:  # Ensure model is selected
-                     batch_results[model_key] = await self.available_models[model_key]['process'](batch_images, log_events)  # Removed 'self' here
-                 else:
-                     batch_results[model_key] = [None] * len(batch_images)
-
-             # Combine results and create final results list
-             for j in range(len(batch_images)):
-                 scores_to_average = []
-                 for model_key in selected_models:
-                     if self.available_models[model_key]['selected']:  # Ensure model is selected
-                         score = batch_results[model_key][j]
-                         if score is not None:
-                             scores_to_average.append(score)
-
-                 final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
-                 thumbnail = batch_images[j].copy()
-                 thumbnail.thumbnail((200, 200))
-                 result = {
-                     'file_name': batch_file_names[j],
-                     'img_data': self.image_to_base64(thumbnail),  # Keep this for the HTML display
-                     'final_score': final_score,
-                 }
-                 for model_key in selected_models:  # Add model scores to result
-                     if self.available_models[model_key]['selected']:
-                         result[model_key] = batch_results[model_key][j]
-                 final_results.append(result)
-
-         log_events.append("All images processed.")
-         return final_results, log_events, 100, optimal_batch
-
-
-     def image_to_base64(self, image: Image.Image) -> str:
-         """Convert PIL Image to base64 encoded JPEG string."""
-         buffered = BytesIO()
-         image.save(buffered, format="JPEG")
-         return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-     def auto_tune_batch_size(self, images: list) -> int:
-         """Automatically determine the optimal batch size for processing."""
          batch_size = 1
-         max_batch = len(images)
          test_image = images[0:1]
-         while batch_size <= max_batch:
              try:
-                 if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']:  # Check if model is available and selected
-                     _ = self.available_models["aesthetic_shadow"]['model'](test_image * batch_size)
-                 if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']:  # Check if model is available and selected
-                     _ = self.available_models["waifu_scorer"]['model'](test_image * batch_size)
-                 if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']:  # Check if model is available and selected
-                     _ = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(test_image * batch_size)
                  batch_size *= 2
-                 if batch_size > max_batch:
-                     break
              except Exception:
                  break
-         optimal = max(1, batch_size // 2)
-         if optimal > 64:
-             optimal = 64
-         print(f"Optimal batch size determined: {optimal}")
-         print(f"Optimal batch size determined: {optimal}")
-         return optimal
-
-     async def _process_aesthetic_shadow(self, batch_images, log_events):
-         try:
-             shadow_results = self.available_models["aesthetic_shadow"]['model'](batch_images)
-             log_events.append("Aesthetic Shadow processed for batch.")
-         except Exception as e:
-             log_events.append(f"Error in Aesthetic Shadow: {e}")
-             shadow_results = [None] * len(batch_images)
-         aesthetic_shadow_scores = []
-         for res in shadow_results:
-             try:
-                 hq_score = next(p for p in res if p['label'] == 'hq')['score']
-                 score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
-             except Exception:
-                 score = None
-             aesthetic_shadow_scores.append(score)
-         log_events.append("Aesthetic Shadow scores computed for batch.")
-         return aesthetic_shadow_scores
-
-     async def _process_waifu_scorer(self, batch_images, log_events):
-         try:
-             waifu_scores = self.available_models["waifu_scorer"]['model'](batch_images)
-             waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
-             log_events.append("Waifu Scorer processed for batch.")
-         except Exception as e:
-             log_events.append(f"Error in Waifu Scorer: {e}")
-             waifu_scores = [None] * len(batch_images)
-         return waifu_scores
-
-     async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
-         try:
-             v2_5_scores = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(batch_images)
-             v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
-             log_events.append("Aesthetic Predictor V2.5 processed for batch.")
-         except Exception as e:
-             log_events.append(f"Error in Aesthetic Predictor V2.5: {e}")
-             v2_5_scores = [None] * len(batch_images)
-         return v2_5_scores
-
-     async def _process_anime_aesthetic(self, batch_images, log_events):
-         anime_scores = []
-         for j, img in enumerate(batch_images):
-             try:
-                 score = predict_anime_aesthetic(img, self.available_models["anime_aesthetic"]['model'])
-                 anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
-                 log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
-             except Exception as e:
-                 log_events.append(f"Error in Anime Aesthetic for image {j + 1}: {e}")
-                 anime_scores.append(None)
-         return anime_scores
-
-
-     def _generate_progress_html(self, percentage: float) -> str:
-         """Generate HTML for a progress bar given a percentage."""
-         return f"""
-         <div style="width:100%;background-color:#ddd; border-radius:5px;">
-             <div style="width:{percentage:.1f}%; background-color:#4CAF50; text-align:center; padding:5px 0; border-radius:5px;">
-                 {percentage:.1f}%
-             </div>
-         </div>
-         """
-
-     def _format_logs(self, logs: list) -> str:
-         """Format log events into an HTML string."""
-         return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"
-
-     def sort_results(self, results, sort_by: str = "Final Score") -> list:
-         """Sort results based on the specified column."""
-         key_map = {
-             "Final Score": "final_score",
-             "File Name": "file_name",
-             "Aesthetic Shadow": "aesthetic_shadow",
-             "Waifu Scorer": "waifu_scorer",
-             "Aesthetic V2.5": "aesthetic_predictor_v2_5",
-             "Anime Score": "anime_aesthetic"
          }
-         key = key_map.get(sort_by, "final_score")
-         reverse = sort_by != "File Name"
-         results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (-float('inf') if not reverse else float('inf')), reverse=reverse)
-         return results
-
-     def generate_html_table(self, results: list, selected_models) -> str:
-         """Generate an HTML table to display the evaluation results."""
-         table_html = """
          <style>
-         .results-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
-         .results-table th, .results-table td { color: #eee; border: 1px solid #ddd; padding: 8px; text-align: center; }
-         .results-table th { font-weight: bold; }
-         .results-table tr:nth-child(even) { background-color: transparent; }
-         .results-table tr:hover { background-color: rgba(255, 255, 255, 0.1); }
-         .image-preview { max-width: 150px; max-height: 150px; display: block; margin: 0 auto; }
-         .good-score { color: #0f0; font-weight: bold; }
-         .bad-score { color: #f00; font-weight: bold; }
-         .medium-score { color: orange; font-weight: bold; }
          </style>
-         <table class="results-table">
-         <thead>
-         <tr>
-         <th>Image</th>
-         <th>File Name</th>
          """
-         visible_models = []  # Keep track of visible model columns
-         if "aesthetic_shadow" in selected_models:
-             table_html += "<th>Aesthetic Shadow</th>"
-             visible_models.append("aesthetic_shadow")
-         if "waifu_scorer" in selected_models:
-             table_html += "<th>Waifu Scorer</th>"
-             visible_models.append("waifu_scorer")
-         if "aesthetic_predictor_v2_5" in selected_models:
-             table_html += "<th>Aesthetic V2.5</th>"
-             visible_models.append("aesthetic_predictor_v2_5")
-         if "anime_aesthetic" in selected_models:
-             table_html += "<th>Anime Score</th>"
-             visible_models.append("anime_aesthetic")
-         table_html += "<th>Final Score</th>"
-         table_html += "</tr></thead><tbody>"
-
          for result in results:
-             table_html += "<tr>"
-             table_html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
-             table_html += f'<td>{result["file_name"]}</td>'
-             for model_key in visible_models:  # Iterate through visible models only
-                 score = result.get(model_key)
-                 table_html += self._format_score_cell(score)
-
-             score = result.get("final_score")
-             table_html += self._format_score_cell(score)
-             table_html += "</tr>"
-         table_html += """</tbody></table>"""
-         return table_html
-
-     def _format_score_cell(self, score):
-         score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
-         score_class = ""
-         if isinstance(score, (int, float)):
-             if score >= 7:
-                 score_class = "good-score"
-             elif score >= 5:
-                 score_class = "medium-score"
-             else:
-                 score_class = "bad-score"
-         return f'<td class="{score_class}">{score_str}</td>'
-
-
-     def cleanup(self):
-         """Clean up temporary directories and shutdown worker."""
-         if os.path.exists(self.temp_dir):
-             shutil.rmtree(self.temp_dir)
-         if self.worker_task is not None:  # Check if worker_task was started
-             asyncio.run(self.shutdown())  # Shutdown worker gracefully
-
-     async def shutdown(self):
-         """Send shutdown signal to worker and wait for it to finish."""
-         if self.worker_task is not None:  # Check if worker_task was started
-             await self.processing_queue.put(None)  # Send shutdown signal
-             await self.worker_task  # Wait for worker task to complete
-             await self.processing_queue.join()  # Wait for queue to be empty
-
-
- #####################################
- # Interface #
- #####################################
-
- model_manager = ModelManager()  # Initialize ModelManager once outside the interface function
-
- def create_interface():
-     sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
-     model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]
-
-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("""
-         # Comprehensive Image Evaluation Tool
-
-         Upload images to evaluate them using multiple aesthetic and quality prediction models.
-
-         **New features:**
-         - **Dynamic Final Score:** Final score recalculates on model selection changes.
-         - **Model Selection:** Choose which models to use for evaluation.
-         - **Dynamic Table Updates:** Table updates automatically based on model selection.
-         - **Automatic Sorting:** Table is automatically sorted by 'Final Score'.
-         - **Detailed Logs:** See major processing events (limited to the last 10).
-         - **Progress Bar:** Visual indication of processing status.
-         - **Asynchronous Updates:** Streaming status and logs during processing.
-         - **Batch Size Controls:** Choose manual batch size or let the tool auto-detect it.
-         - **Download Results:** Export the evaluation results as CSV.
-         """)
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 input_images = gr.Files(label="Upload Images", file_count="multiple")
-                 model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
-                 auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
-                 batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
-                 sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
-                 process_btn = gr.Button("Evaluate Images", variant="primary")
-                 clear_btn = gr.Button("Clear Results")
-                 download_csv = gr.Button("Download CSV", variant="secondary")
-
-             with gr.Column(scale=2):
-                 progress_bar = gr.HTML(label="Progress Bar", value="""
-                     <div style='width:100%;background-color:#ddd;'>
-                         <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
-                     </div>
-                 """)
-                 log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
-                 status_html = gr.HTML(label="Status")
-                 output_html = gr.HTML(label="Evaluation Results")
-                 download_file_output = gr.File()  # Initialize gr.File component without filename
-                 global_results_state = gr.State([])  # Initialize a global state to hold results
-
-         # Function to convert results to CSV format, excluding 'img_data'.
-         def results_to_csv(results, selected_models):  # Take results as input
-             import csv
-             import io
-             if not results:
-                 return None  # Return None when no results are available
-             output = io.StringIO()
-             fieldnames = ['file_name', 'final_score']  # Base fieldnames
-             for model_key in selected_models:  # Add selected model names as fieldnames
-                 if model_key in selected_models:  # Double check if model_key is indeed in selected_models list
-                     fieldnames.append(model_key)
-
-             writer = csv.DictWriter(output, fieldnames=fieldnames)
-             writer.writeheader()
-             for res in results:
-                 row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']}  # Base data
-                 for model_key in selected_models:  # Add selected model scores
-                     if model_key in selected_models:  # Double check before accessing res[model_key]
-                         row_dict[model_key] = res.get(model_key, 'N/A')  # Use get with default 'N/A' if model not in result (shouldn't happen but for safety)
-                 writer.writerow(row_dict)
-             return output.getvalue()
-
-
-         def update_batch_size_interactivity(auto_batch):
-             return gr.update(interactive=not auto_batch)
-
-         async def process_images_and_update(files, auto_batch, manual_batch, selected_models, current_results):
-             file_paths = [f.name for f in files]
-
-             # Prepare request data for the ModelManager
-             request_data = {
-                 'file_paths': file_paths,
-                 'auto_batch': auto_batch,
-                 'manual_batch_size': manual_batch,
-                 'selected_models': {model: {'selected': model in selected_models} for model in model_options}  # Pass model selections
-             }
-             # Submit request and get results from ModelManager
-             results, logs, progress_percent, updated_batch = await model_manager.submit_request(request_data)
-
-             updated_results = current_results + results  # Append new results to current results
-
-             html_table = model_manager.generate_html_table(updated_results, selected_models)
-             progress_html = model_manager._generate_progress_html(progress_percent)
-             log_html = model_manager._format_logs(logs[-10:])
-
-             return status_html, html_table, log_html, progress_html, gr.update(value=updated_batch, interactive=not auto_batch), updated_results
-
-
-         def update_table_sort(sort_by_column, selected_models, current_results):
-             sorted_results = model_manager.sort_results(current_results, sort_by_column)
-             return model_manager.generate_html_table(sorted_results, selected_models), sorted_results  # Return sorted results
-
-         def update_table_model_selection(selected_models, current_results):
-             # Recalculate final scores based on selected models
-             for result in current_results:
-                 scores_to_average = []
-                 for model_key in model_options:  # Use model_options here, not available_models from manager in UI context
-                     if model_key in selected_models and model_key in model_manager.available_models and model_manager.available_models[model_key]['selected']:  # consider only selected models from checkbox group and available_models
-                         score = result.get(model_key)
-                         if score is not None:
-                             scores_to_average.append(score)
-                 final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
-                 result['final_score'] = final_score
-
-             sorted_results = model_manager.sort_results(current_results, "Final Score")  # Keep sorting by Final Score when models change
-             return model_manager.generate_html_table(sorted_results, selected_models), sorted_results
-
-
-         def clear_results():
-             return (gr.update(value=""),
-                     gr.update(value=""),
-                     gr.update(value=""),
-                     gr.update(value="""
-                     <div style='width:100%;background-color:#ddd;'>
-                         <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
-                     </div>
-                     """),
-                     gr.update(value=1),
-                     [])  # Clear results state
-
-         def download_results_csv_trigger(selected_models, current_results):  # Changed function name to avoid conflict and clarify purpose
-             csv_content = results_to_csv(current_results, selected_models)
-             if csv_content is None:
-                 return None  # Indicate no file to download
-
-             # Create a temporary file to save the CSV data
-             with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
-                 tmp_file.write(csv_content.encode())
-                 temp_file_path = tmp_file.name  # Get the path to the temporary file
-
-             return temp_file_path  # Return the path to the temporary file
-
-
-         # Set initial selection state for models in ModelManager (important!)
-         for model_key in model_options:
-             model_manager.available_models[model_key]['selected'] = True  # Default to all selected initially
-
-         auto_batch_checkbox.change(
-             update_batch_size_interactivity,
-             inputs=[auto_batch_checkbox],
-             outputs=[batch_size_input]
-         )
-
-         process_btn.click(
-             process_images_and_update,
-             inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes, global_results_state],
-             outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
-         )
-         sort_dropdown.change(
-             update_table_sort,
-             inputs=[sort_dropdown, model_checkboxes, global_results_state],
-             outputs=[output_html, global_results_state]
-         )
-         model_checkboxes.change(  # Added change event for model checkboxes
-             update_table_model_selection,
-             inputs=[model_checkboxes, global_results_state],
-             outputs=[output_html, global_results_state]
-         )
-         clear_btn.click(
-             clear_results,
-             inputs=[],
-             outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
-         )
-         download_csv.click(
-             download_results_csv_trigger,  # Call the trigger function
-             inputs=[model_checkboxes, global_results_state],
-             outputs=[download_file_output]  # Output is now the gr.File component
-         )
-         demo.load(lambda: update_table_sort("Final Score", model_options, []), inputs=None, outputs=[output_html, global_results_state])  # Initial sort and table render, pass empty initial results
-         demo.load(model_manager.start_worker)  # Start the worker task on demo load
-
-         gr.Markdown("""
-         ### Notes
-         - Select models to use for evaluation using the checkboxes.
-         - The 'Final Score' recalculates dynamically when models are selected/deselected.
-         - The table updates automatically when models are selected/deselected and is always sorted by 'Final Score'.
-         - The log window displays the most recent 10 events.
-         - The progress bar shows overall processing status.
-         - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
-         - Use the download button to export your evaluation results as CSV.
-         """)
-
-     return demo
-
- if __name__ == "__main__":
-     demo = create_interface()
-     demo.queue().launch()
 
+ """
+ Modern Image Evaluation Tool with Aesthetic and Quality Prediction Models
+
+ This refactored version features:
+ - Modern async/await patterns with proper error handling
+ - Type hints throughout for better code maintainability
+ - Dependency injection and factory patterns
+ - Proper resource management with context managers
+ - Configuration-driven model loading
+ - Improved batch processing with memory optimization
+ - Clean separation of concerns with proper abstraction layers
+ """
+
+ import asyncio
+ import base64
+ import csv
+ import logging
  import os
  import tempfile
+ import shutil
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from io import BytesIO, StringIO
+ from pathlib import Path
+ from typing import Dict, List, Optional, Protocol, Tuple, Union, Any
+ from abc import ABC, abstractmethod

  import cv2
+ import gradio as gr
  import numpy as np
+ import onnxruntime as ort
  import torch
+ import torch.nn as nn
  from PIL import Image
  from transformers import pipeline
  from huggingface_hub import hf_hub_download

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # =============================================================================
+ # Configuration and Data Models
+ # =============================================================================
+
+ class ModelType(Enum):
+     """Enumeration of available model types."""
+     AESTHETIC_SHADOW = "aesthetic_shadow"
+     WAIFU_SCORER = "waifu_scorer"
+     AESTHETIC_PREDICTOR_V2_5 = "aesthetic_predictor_v2_5"
+     ANIME_AESTHETIC = "anime_aesthetic"
+
+
+ @dataclass
+ class ModelConfig:
+     """Configuration for individual models."""
+     name: str
+     display_name: str
+     enabled: bool = True
+     batch_supported: bool = True
+     model_path: Optional[str] = None
+     cache_dir: Optional[str] = None
+
+
+ @dataclass
+ class ProcessingConfig:
+     """Configuration for processing parameters."""
+     auto_batch: bool = False
+     manual_batch_size: int = 1
+     max_batch_size: int = 64
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     score_range: Tuple[float, float] = (0.0, 10.0)
+
+
+ @dataclass
+ class EvaluationResult:
+     """Data class for individual evaluation results."""
+     file_name: str
+     file_path: str
+     thumbnail_b64: str
+     model_scores: Dict[str, Optional[float]] = field(default_factory=dict)
+     final_score: Optional[float] = None
+     processing_time: float = 0.0
+     error: Optional[str] = None
+
+
+ @dataclass
+ class BatchResult:
+     """Data class for batch processing results."""
+     results: List[EvaluationResult]
+     logs: List[str]
+     processing_time: float
+     batch_size_used: int
+     success_count: int
+     error_count: int
+
+
+ # =============================================================================
+ # Model Interfaces and Implementations
+ # =============================================================================
+
+ class BaseModel(Protocol):
+     """Protocol defining the interface for all evaluation models."""
+
+     async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+         """Predict scores for a batch of images."""
+         ...
+
+     def is_available(self) -> bool:
+         """Check if the model is available and ready for inference."""
+         ...
+
+     def cleanup(self) -> None:
+         """Clean up model resources."""
+         ...
+
+
+ class ModernMLP(nn.Module):
+     """Modern implementation of MLP with improved architecture."""
+
+     def __init__(
+         self,
+         input_size: int,
+         hidden_dims: List[int] = None,
+         dropout_rates: List[float] = None,
+         use_batch_norm: bool = True,
+         activation: nn.Module = nn.ReLU
+     ):
          super().__init__()
+
+         if hidden_dims is None:
+             hidden_dims = [2048, 512, 256, 128, 32]
+         if dropout_rates is None:
+             dropout_rates = [0.3, 0.3, 0.2, 0.1, 0.0]
+
+         layers = []
+         prev_dim = input_size
+
+         for i, (hidden_dim, dropout_rate) in enumerate(zip(hidden_dims, dropout_rates)):
+             layers.append(nn.Linear(prev_dim, hidden_dim))
+             layers.append(activation())
+
+             if use_batch_norm and i < len(hidden_dims) - 1:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+
+             if dropout_rate > 0:
+                 layers.append(nn.Dropout(dropout_rate))
+
+             prev_dim = hidden_dim
+
+         # Final output layer
+         layers.append(nn.Linear(prev_dim, 1))
+         self.network = nn.Sequential(*layers)
+
      def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.network(x)


+ class WaifuScorerModel:
+     """Modernized WaifuScorer implementation with better error handling."""
+
+     def __init__(self, config: ModelConfig, device: str):
+         self.config = config
          self.device = device
          self.dtype = torch.float32
+         self._available = False
+         self._model = None
+         self._clip_model = None
+         self._preprocess = None
+
+         self._initialize_model()
+
+     def _initialize_model(self) -> None:
+         """Initialize the model with proper error handling."""
          try:
+             import clip
+
+             # Download model if needed
+             model_path = self._get_model_path()
+
+             # Initialize MLP
+             self._model = ModernMLP(input_size=768)
+
+             # Load weights
              if model_path.endswith(".safetensors"):
                  from safetensors.torch import load_file
                  state_dict = load_file(model_path)
              else:
+                 state_dict = torch.load(model_path, map_location=self.device)
+
+             self._model.load_state_dict(state_dict)
+             self._model.to(self.device)
+             self._model.eval()
+
+             # Load CLIP model
+             self._clip_model, self._preprocess = clip.load("ViT-L/14", device=self.device)
+             self._available = True
+
+             logger.info(f"WaifuScorer model loaded successfully on {self.device}")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize WaifuScorer: {e}")
+             self._available = False
+
+     def _get_model_path(self) -> str:
+         """Get or download the model path."""
+         if self.config.model_path and os.path.isfile(self.config.model_path):
+             return self.config.model_path
+
+         # Default download path
+         model_path = "Eugeoter/waifu-scorer-v3/model.pth"
+         username, repo_id, model_name = model_path.split("/")[-3:]
+         return hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=self.config.cache_dir)
+
+     async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+         """Predict scores for a batch of images."""
+         if not self._available:
+             return [None] * len(images)
+
+         try:
+             # Handle single image case for CLIP compatibility
+             batch_images = images * 2 if len(images) == 1 else images
+
+             # Preprocess images
+             image_tensors = [self._preprocess(img).unsqueeze(0) for img in batch_images]
+             image_batch = torch.cat(image_tensors).to(self.device)
+
+             # Extract features and predict
+             with torch.no_grad():
+                 image_features = self._clip_model.encode_image(image_batch)
+                 # Normalize features
+                 norm = image_features.norm(2, dim=-1, keepdim=True)
+                 norm[norm == 0] = 1
+                 normalized_features = (image_features / norm).to(device=self.device, dtype=self.dtype)
+
+                 predictions = self._model(normalized_features)
+                 scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
+
+             return scores[:len(images)]
+
+         except Exception as e:
+             logger.error(f"Error in WaifuScorer prediction: {e}")
+             return [None] * len(images)
+
+     def is_available(self) -> bool:
+         return self._available
+
+     def cleanup(self) -> None:
+         """Clean up model resources."""
+         if self._model is not None:
+             del self._model
+         if self._clip_model is not None:
+             del self._clip_model
+         torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+
+ class AestheticShadowModel:
+     """Wrapper for Aesthetic Shadow model using transformers pipeline."""
+
+     def __init__(self, config: ModelConfig, device: str):
+         self.config = config
+         self.device = device
+         self._available = False
+         self._model = None
+
+         self._initialize_model()
+
+     def _initialize_model(self) -> None:
+         """Initialize the model pipeline."""
+         try:
+             self._model = pipeline(
+                 "image-classification",
+                 model="NeoChen1024/aesthetic-shadow-v2-backup",
+                 device=self.device
+             )
+             self._available = True
+             logger.info("Aesthetic Shadow model loaded successfully")
+
          except Exception as e:
+             logger.error(f"Failed to initialize Aesthetic Shadow: {e}")
+             self._available = False
+
+     async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+         """Predict scores for a batch of images."""
+         if not self._available:
+             return [None] * len(images)
+
+         try:
+             results = self._model(images)
+             scores = []
+
+             for result in results:
+                 try:
+                     hq_score = next(p for p in result if p['label'] == 'hq')['score']
+                     score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
+                     scores.append(score)
+                 except (StopIteration, KeyError, TypeError):
+                     scores.append(None)
+
+             return scores
+
+         except Exception as e:
+             logger.error(f"Error in Aesthetic Shadow prediction: {e}")
+             return [None] * len(images)
+
+     def is_available(self) -> bool:
+         return self._available
+
+     def cleanup(self) -> None:
+         if self._model is not None:
+             del self._model
+
+
+ class AestheticPredictorV25Model:
+     """Wrapper for Aesthetic Predictor V2.5 model."""
+
+     def __init__(self, config: ModelConfig, device: str):
+         self.config = config
+         self.device = device
+         self._available = False
+         self._model = None
+         self._preprocessor = None
+
+         self._initialize_model()
+
+     def _initialize_model(self) -> None:
+         """Initialize the model."""
+         try:
+             from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
+
+             self._model, self._preprocessor = convert_v2_5_from_siglip(
                  low_cpu_mem_usage=True,
                  trust_remote_code=True,
              )
+
              if torch.cuda.is_available():
+                 self._model = self._model.to(torch.bfloat16).cuda()
+
+             self._available = True
+             logger.info("Aesthetic Predictor V2.5 loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize Aesthetic Predictor V2.5: {e}")
+             self._available = False
+
+     async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+         """Predict scores for a batch of images."""
+         if not self._available:
+             return [None] * len(images)
+
+         try:
+             rgb_images = [img.convert("RGB") for img in images]
+             pixel_values = self._preprocessor(images=rgb_images, return_tensors="pt").pixel_values
+
+             if torch.cuda.is_available():
+                 pixel_values = pixel_values.to(torch.bfloat16).cuda()
+
+             with torch.inference_mode():
+                 scores = self._model(pixel_values).logits.squeeze().float().cpu().numpy()
+
+             if scores.ndim == 0:
+                 scores = np.array([scores])
+
+             return [float(np.round(np.clip(s, 0.0, 10.0), 4)) for s in scores]
+
+         except Exception as e:
+             logger.error(f"Error in Aesthetic Predictor V2.5 prediction: {e}")
+             return [None] * len(images)
+
+     def is_available(self) -> bool:
+         return self._available
+
+     def cleanup(self) -> None:
+         if self._model is not None:
+             del self._model
+
+
+ class AnimeAestheticModel:
+     """ONNX-based Anime Aesthetic model."""
+
+     def __init__(self, config: ModelConfig, device: str):
+         self.config = config
+         self.device = device
+         self._available = False
+         self._session = None
+
+         self._initialize_model()
+
+     def _initialize_model(self) -> None:
+         """Initialize the ONNX model."""
+         try:
+             model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
+             self._session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
+             self._available = True
+             logger.info("Anime Aesthetic model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize Anime Aesthetic: {e}")
+             self._available = False
+
+     async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+         """Predict scores for images (single image processing for ONNX)."""
+         if not self._available:
+             return [None] * len(images)
+
+         scores = []
+         for img in images:
+             try:
+                 score = self._predict_single(img)
+                 scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
+             except Exception as e:
+                 logger.error(f"Error predicting anime aesthetic for image: {e}")
+                 scores.append(None)
+
+         return scores
+
+     def _predict_single(self, img: Image.Image) -> float:
+         """Predict score for a single image."""
+         img_np = np.array(img).astype(np.float32) / 255.0
+         s = 768
+         h, w = img_np.shape[:2]
+
+         # Resize while maintaining aspect ratio
+         if h > w:
+             new_h, new_w = s, int(s * w / h)
+         else:
+             new_h, new_w = int(s * h / w), s
+
+         resized = cv2.resize(img_np, (new_w, new_h))
+
+         # Center crop/pad to square
+         canvas = np.zeros((s, s, 3), dtype=np.float32)
+         pad_h = (s - new_h) // 2
+         pad_w = (s - new_w) // 2
+         canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
+
+         # Prepare input
+         input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
+         return self._session.run(None, {"img": input_tensor})[0].item()
+
+     def is_available(self) -> bool:
+         return self._available
+
+     def cleanup(self) -> None:
+         if self._session is not None:
+             del self._session
+
+
+ # =============================================================================
+ # Model Factory and Manager
+ # =============================================================================
+
+ class ModelFactory:
+     """Factory for creating model instances."""
+
+     _MODEL_CLASSES = {
+         ModelType.AESTHETIC_SHADOW: AestheticShadowModel,
+         ModelType.WAIFU_SCORER: WaifuScorerModel,
+         ModelType.AESTHETIC_PREDICTOR_V2_5: AestheticPredictorV25Model,
+         ModelType.ANIME_AESTHETIC: AnimeAestheticModel,
+     }
+
+     @classmethod
+     def create_model(cls, model_type: ModelType, config: ModelConfig, device: str) -> BaseModel:
+         """Create a model instance based on type."""
+         model_class = cls._MODEL_CLASSES.get(model_type)
+         if not model_class:
+             raise ValueError(f"Unknown model type: {model_type}")
+
+         return model_class(config, device)
+

  class ModelManager:
+     """Advanced model manager with async processing and resource management."""
+
+     def __init__(self, processing_config: ProcessingConfig):
+         self.config = processing_config
+         self.models: Dict[ModelType, BaseModel] = {}
+         self.model_configs = self._create_default_configs()
+         self._processing_queue = asyncio.Queue()
+         self._worker_task: Optional[asyncio.Task] = None
+         self._temp_dir = Path(tempfile.mkdtemp())
+
+         self._initialize_models()
+
+     def _create_default_configs(self) -> Dict[ModelType, ModelConfig]:
+         """Create default model configurations."""
+         return {
+             ModelType.AESTHETIC_SHADOW: ModelConfig(
+                 name="aesthetic_shadow",
+                 display_name="Aesthetic Shadow"
+             ),
+             ModelType.WAIFU_SCORER: ModelConfig(
+                 name="waifu_scorer",
+                 display_name="Waifu Scorer"
+             ),
+             ModelType.AESTHETIC_PREDICTOR_V2_5: ModelConfig(
+                 name="aesthetic_predictor_v2_5",
+                 display_name="Aesthetic V2.5"
+             ),
+             ModelType.ANIME_AESTHETIC: ModelConfig(
+                 name="anime_aesthetic",
+                 display_name="Anime Score",
+                 batch_supported=False
+             ),
          }
+
+     def _initialize_models(self) -> None:
+         """Initialize all models."""
+         logger.info("Initializing models...")
+
+         for model_type, config in self.model_configs.items():
+             if config.enabled:
+                 try:
+                     model = ModelFactory.create_model(model_type, config, self.config.device)
+                     if model.is_available():
+                         self.models[model_type] = model
+                         logger.info(f"✓ {config.display_name} loaded successfully")
+                     else:
+                         logger.warning(f"✗ {config.display_name} failed to load")
+                 except Exception as e:
+                     logger.error(f"✗ {config.display_name} initialization error: {e}")
+
+         logger.info(f"Initialized {len(self.models)} models successfully")
+
+     async def start_worker(self) -> None:
+         """Start the background processing worker."""
+         if self._worker_task is None:
+             self._worker_task = asyncio.create_task(self._worker_loop())
+             logger.info("Background worker started")
+
+     async def _worker_loop(self) -> None:
+         """Main worker loop for processing requests."""
          while True:
+             request = await self._processing_queue.get()
+
+             if request is None:  # Shutdown signal
                  break
+
              try:
+                 result = await self._process_request(request)
+                 request['future'].set_result(result)
              except Exception as e:
+                 request['future'].set_exception(e)
              finally:
+                 self._processing_queue.task_done()
+
+     async def process_images(
+         self,
+         file_paths: List[str],
+         selected_models: List[ModelType],
+         auto_batch: bool = False,
+         manual_batch_size: int = 1
+     ) -> BatchResult:
+         """Process images with selected models."""
+         future = asyncio.Future()
+         request = {
+             'file_paths': file_paths,
+             'selected_models': selected_models,
+             'auto_batch': auto_batch,
+             'manual_batch_size': manual_batch_size,
+             'future': future
+         }
+
+         await self._processing_queue.put(request)
+         return await future
+
+     async def _process_request(self, request: Dict) -> BatchResult:
+         """Process a single batch request."""
+         start_time = asyncio.get_event_loop().time()
+         logs = []
+         results = []
+
          file_paths = request['file_paths']
+         selected_models = request['selected_models']
          auto_batch = request['auto_batch']
          manual_batch_size = request['manual_batch_size']
+
+         # Load images
+         images, valid_paths = await self._load_images(file_paths, logs)
+
+         if not images:
+             return BatchResult([], logs, 0.0, 0, 0, len(file_paths))
+
+         # Determine batch size
+         batch_size = await self._determine_batch_size(images, auto_batch, manual_batch_size, logs)
+
+         # Process in batches
+         for i in range(0, len(images), batch_size):
+             batch_images = images[i:i+batch_size]
+             batch_paths = valid_paths[i:i+batch_size]
+
+             batch_results = await self._process_batch(batch_images, batch_paths, selected_models, logs)
+             results.extend(batch_results)
+
+         processing_time = asyncio.get_event_loop().time() - start_time
+         success_count = sum(1 for r in results if r.error is None)
+         error_count = len(results) - success_count
+
+         return BatchResult(
+             results=results,
+             logs=logs,
+             processing_time=processing_time,
+             batch_size_used=batch_size,
+             success_count=success_count,
+             error_count=error_count
+         )
+
+     async def _load_images(self, file_paths: List[str], logs: List[str]) -> Tuple[List[Image.Image], List[str]]:
+         """Load and validate images."""
          images = []
+         valid_paths = []
+
+         logs.append(f"Loading {len(file_paths)} images...")
+
+         for path in file_paths:
              try:
+                 img = Image.open(path).convert("RGB")
                  images.append(img)
+                 valid_paths.append(path)
              except Exception as e:
+                 logs.append(f"Failed to load {path}: {e}")
+
+         logs.append(f"Successfully loaded {len(images)} images")
+         return images, valid_paths
+
+     async def _determine_batch_size(
+         self,
+         images: List[Image.Image],
+         auto_batch: bool,
+         manual_batch_size: int,
+         logs: List[str]
+     ) -> int:
+         """Determine optimal batch size."""
+         if not auto_batch:
+             return min(manual_batch_size, len(images))
+
+         # Auto-tune batch size
          batch_size = 1
          test_image = images[0:1]
+
+         while batch_size <= min(len(images), self.config.max_batch_size):
              try:
+                 # Test with a sample of available models
+                 test_batch = test_image * batch_size
+                 for model_type, model in list(self.models.items())[:2]:  # Test with first 2 models
+                     await model.predict(test_batch)
+
                  batch_size *= 2
              except Exception:
                  break
+
+         optimal_batch = max(1, batch_size // 2)
+         logs.append(f"Auto-tuned batch size: {optimal_batch}")
+         return optimal_batch
+
+     async def _process_batch(
+         self,
+         images: List[Image.Image],
+         paths: List[str],
+         selected_models: List[ModelType],
+         logs: List[str]
+     ) -> List[EvaluationResult]:
+         """Process a single batch of images."""
+         batch_results = []
+
+         # Get predictions from all models
+         model_predictions = {}
+         for model_type in selected_models:
+             if model_type in self.models:
+                 try:
+                     predictions = await self.models[model_type].predict(images)
+                     model_predictions[model_type.value] = predictions
+                     logs.append(f"✓ {self.model_configs[model_type].display_name} processed batch")
+                 except Exception as e:
+                     logs.append(f"✗ {self.model_configs[model_type].display_name} error: {e}")
+                     model_predictions[model_type.value] = [None] * len(images)
+
+         # Create results
+         for i, (image, path) in enumerate(zip(images, paths)):
+             # Collect scores for this image
+             scores = {}
+             valid_scores = []
+
+             for model_type in selected_models:
+                 score = model_predictions.get(model_type.value, [None] * len(images))[i]
+                 scores[model_type.value] = score
+                 if score is not None:
+                     valid_scores.append(score)
+
+             # Calculate final score
+             final_score = np.mean(valid_scores) if valid_scores else None
+             if final_score is not None:
+                 final_score = float(np.clip(final_score, *self.config.score_range))
+
+             # Create thumbnail
+             thumbnail = image.copy()
+             thumbnail.thumbnail((200, 200), Image.Resampling.LANCZOS)
+             thumbnail_b64 = self._image_to_base64(thumbnail)
+
+             result = EvaluationResult(
+                 file_name=Path(path).name,
+                 file_path=path,
+                 thumbnail_b64=thumbnail_b64,
+                 model_scores=scores,
+                 final_score=final_score
+             )
+
+             batch_results.append(result)
+
+         return batch_results
+
+     def _image_to_base64(self, image: Image.Image) -> str:
+         """Convert PIL Image to base64 string."""
+         buffer = BytesIO()
+         image.save(buffer, format="JPEG", quality=85, optimize=True)
+         return base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+     def get_available_models(self) -> Dict[ModelType, str]:
+         """Get available models with their display names."""
+         return {
+             model_type: self.model_configs[model_type].display_name
+             for model_type in self.models.keys()
          }
+
+     async def cleanup(self) -> None:
+         """Clean up resources."""
+         # Shutdown worker
+         if self._worker_task:
+             await self._processing_queue.put(None)
+             await self._worker_task
+
+         # Clean up models
+         for model in self.models.values():
+             model.cleanup()
+
+         # Clean up temp directory
+         if self._temp_dir.exists():
+             shutil.rmtree(self._temp_dir)
+
+         logger.info("Model manager cleanup completed")
+
+
+ # =============================================================================
+ # Results Processing and Export
+ # =============================================================================
+
+ class ResultsProcessor:
+     """Handle result processing, sorting, and export functionality."""
+
+     @staticmethod
+     def sort_results(results: List[EvaluationResult], sort_by: str, reverse: bool = True) -> List[EvaluationResult]:
+         """Sort results by specified criteria."""
+         sort_key_map = {
+             "Final Score": lambda r: r.final_score if r.final_score is not None else -float('inf'),
+             "File Name": lambda r: r.file_name.lower(),
+             **{f"model_{model_type.value}": lambda r, mt=model_type.value: r.model_scores.get(mt) or -float('inf')
+                for model_type in ModelType}
+         }
+
+         sort_key = sort_key_map.get(sort_by, sort_key_map["Final Score"])
+         return sorted(results, key=sort_key, reverse=reverse and sort_by != "File Name")
+
+     @staticmethod
+     def generate_html_table(results: List[EvaluationResult], selected_models: List[ModelType]) -> str:
+         """Generate HTML table for results display."""
+         if not results:
+             return "<p>No results to display</p>"
+
+         # CSS styles
+         styles = """
          <style>
+         .results-table {
+             width: 100%; border-collapse: collapse; margin: 20px 0;
+             font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+         }
+         .results-table th, .results-table td {
+             border: 1px solid #ddd; padding: 12px; text-align: center;
+         }
+         .results-table th {
+             background-color: #f8f9fa; font-weight: 600; color: #495057;
+         }
+         .results-table tr:nth-child(even) { background-color: #f8f9fa; }
+         .results-table tr:hover { background-color: #e9ecef; }
+         .image-preview {
+             max-width: 120px; max-height: 120px; border-radius: 8px;
+             box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+         }
+         .score-excellent { color: #28a745; font-weight: bold; }
+         .score-good { color: #ffc107; font-weight: bold; }
+         .score-poor { color: #dc3545; font-weight: bold; }
+         .score-na { color: #6c757d; font-style: italic; }
          </style>
          """
+
+         # Table header
+         html = styles + '<table class="results-table"><thead><tr>'
+         html += '<th>Image</th><th>File Name</th>'
+
+         for model_type in selected_models:
+             model_name = ModelType(model_type).name.replace('_', ' ').title()
+             html += f'<th>{model_name}</th>'
+
+         html += '<th>Final Score</th></tr></thead><tbody>'
+
+         # Table rows
          for result in results:
+             html += '<tr>'
+             html += f'<td><img src="data:image/jpeg;base64,{result.thumbnail_b64}" class="image-preview" alt="{result.file_name}"></td>'
+             html += f'<td>{result.file_name}</td>'
+
+             # Model scores
+             for model_type in selected_models:
+                 score = result.model_scores.get(model_type.value)
+                 html += ResultsProcessor._format_score_cell(score)
+
+             # Final score
+             html += ResultsProcessor._format_score_cell(result.final_score)
+             html += '</tr>'
+
+         html += '</tbody></table>'
+         return html
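
For reference, a minimal sketch of how the refactored async API introduced in this commit might be driven. This is not part of the commit; it assumes the ModelManager, ProcessingConfig, ModelType, and BatchResult definitions from the new app.py above, and "example.jpg" is a placeholder path.

    import asyncio

    # Hypothetical driver, not part of this commit. Assumes the classes
    # defined in the new app.py are importable from the same module.
    async def main() -> None:
        manager = ModelManager(ProcessingConfig())
        await manager.start_worker()  # start the background queue worker
        try:
            batch = await manager.process_images(
                file_paths=["example.jpg"],  # placeholder path
                selected_models=[ModelType.AESTHETIC_SHADOW, ModelType.ANIME_AESTHETIC],
                auto_batch=False,
                manual_batch_size=1,
            )
            for result in batch.results:
                print(result.file_name, result.final_score)
        finally:
            await manager.cleanup()  # signals the worker to stop and frees models

    asyncio.run(main())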