VOIDER committed on
Commit f56b01d · verified · 1 Parent(s): 1bc1e75

Update app.py

Files changed (1)
  1. app.py +887 -412
app.py CHANGED
@@ -1,7 +1,9 @@
  import os
  import asyncio
- from typing import List, Dict, Optional, Tuple, Any
- from dataclasses import dataclass, field
  from pathlib import Path
  import logging

@@ -11,508 +13,981 @@ import torch
  import onnxruntime as rt
  from PIL import Image
  import gradio as gr
- from transformers import pipeline
  from huggingface_hub import hf_hub_download
- import pandas as pd

- # Configure logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

- # Import aesthetic predictor function
- from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
-
-
- @dataclass
- class EvaluationResult:
-     """Data class for storing image evaluation results"""
-     file_name: str
-     image_path: str
-     scores: Dict[str, Optional[float]] = field(default_factory=dict)
-     final_score: Optional[float] = None
-
-     def calculate_final_score(self, selected_models: List[str]) -> None:
-         """Calculate the average score from selected models"""
-         valid_scores = [
-             score for model, score in self.scores.items()
-             if model in selected_models and score is not None
          ]
-         self.final_score = np.mean(valid_scores) if valid_scores else None
-
-
- class BaseModel:
-     """Base class for all evaluation models"""
-     def __init__(self, name: str):
-         self.name = name
-         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         """Evaluate a batch of images"""
-         raise NotImplementedError
-
-
- class AestheticShadowModel(BaseModel):
-     """Aesthetic Shadow V2 model implementation"""
-     def __init__(self):
-         super().__init__("Aesthetic Shadow")
-         logger.info(f"Loading {self.name} model...")
-         self.model = pipeline(
-             "image-classification",
-             model="NeoChen1024/aesthetic-shadow-v2-backup",
-             device=0 if self.device == 'cuda' else -1
-         )
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         try:
-             results = self.model(images)
-             scores = []
-             for result in results:
-                 hq_score = next((p['score'] for p in result if p['label'] == 'hq'), 0)
-                 scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
-             return scores
-         except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
-             return [None] * len(images)


- class WaifuScorerModel(BaseModel):
-     """Waifu Scorer V3 model implementation"""
-     def __init__(self):
-         super().__init__("Waifu Scorer")
-         logger.info(f"Loading {self.name} model...")
-         self._load_model()
-
-     def _load_model(self):
          try:
-             import clip
-
-             # Load MLP model
-             self.mlp = self._create_mlp()
-             model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
-             state_dict = torch.load(model_path, map_location=self.device)
              self.mlp.load_state_dict(state_dict)
-             self.mlp.to(self.device).eval()
-
-             # Load CLIP model
              self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
              self.available = True
          except Exception as e:
-             logger.error(f"Failed to load {self.name}: {e}")
-             self.available = False
-
-     def _create_mlp(self) -> torch.nn.Module:
-         """Create the MLP architecture"""
-         return torch.nn.Sequential(
-             torch.nn.Linear(768, 2048),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(2048),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(2048, 512),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(512),
-             torch.nn.Dropout(0.3),
-             torch.nn.Linear(512, 256),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(256),
-             torch.nn.Dropout(0.2),
-             torch.nn.Linear(256, 128),
-             torch.nn.ReLU(),
-             torch.nn.BatchNorm1d(128),
-             torch.nn.Dropout(0.1),
-             torch.nn.Linear(128, 32),
-             torch.nn.ReLU(),
-             torch.nn.Linear(32, 1)
-         )
-
      @torch.no_grad()
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
          if not self.available:
              return [None] * len(images)

          try:
-             # Process images
-             image_tensors = torch.cat([self.preprocess(img).unsqueeze(0) for img in images])
-             image_tensors = image_tensors.to(self.device)
-
-             # Extract features and predict
-             features = self.clip_model.encode_image(image_tensors)
-             features = features / features.norm(dim=-1, keepdim=True)
-             predictions = self.mlp(features)
-
-             scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
-             return scores
          except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
-             return [None] * len(images)


- class AestheticPredictorV25Model(BaseModel):
-     """Aesthetic Predictor V2.5 model implementation"""
-     def __init__(self):
-         super().__init__("Aesthetic V2.5")
-         logger.info(f"Loading {self.name} model...")
          self.model, self.preprocessor = convert_v2_5_from_siglip(
-             low_cpu_mem_usage=True,
-             trust_remote_code=True,
          )
-         if self.device == 'cuda':
              self.model = self.model.to(torch.bfloat16).cuda()
-
      @torch.no_grad()
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
          try:
              images_rgb = [img.convert("RGB") for img in images]
              pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
-
-             if self.device == 'cuda':
                  pixel_values = pixel_values.to(torch.bfloat16).cuda()

-             scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
-             if scores.ndim == 0:
-                 scores = np.array([scores])
-
-             return [float(np.clip(s, 0.0, 10.0)) for s in scores]
          except Exception as e:
-             logger.error(f"Error in {self.name}: {e}")
              return [None] * len(images)


- class AnimeAestheticModel(BaseModel):
-     """Anime Aesthetic model implementation"""
-     def __init__(self):
-         super().__init__("Anime Score")
-         logger.info(f"Loading {self.name} model...")
-         model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
-         self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-
-     async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
-         scores = []
-         for img in images:
-             try:
-                 score = self._process_single_image(img)
-                 scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
-             except Exception as e:
-                 logger.error(f"Error in {self.name} for single image: {e}")
-                 scores.append(None)
-         return scores
-
-     def _process_single_image(self, img: Image.Image) -> float:
-         """Process a single image through the model"""
-         img_np = np.array(img).astype(np.float32) / 255.0
-         size = 768
-         h, w = img_np.shape[:2]
-
-         # Calculate new dimensions
-         if h > w:
-             new_h, new_w = size, int(size * w / h)
-         else:
-             new_h, new_w = int(size * h / w), size
-
-         # Resize and center
-         resized = cv2.resize(img_np, (new_w, new_h))
-         canvas = np.zeros((size, size, 3), dtype=np.float32)
-         pad_h = (size - new_h) // 2
-         pad_w = (size - new_w) // 2
-         canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized

-         # Prepare input
-         input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
-         return self.session.run(None, {"img": input_tensor})[0].item()


- class ImageEvaluator:
-     """Main class for managing image evaluation"""
-     def __init__(self):
-         self.models: Dict[str, BaseModel] = {}
-         self._initialize_models()
-         self.results: List[EvaluationResult] = []
-
-     def _initialize_models(self):
-         """Initialize all evaluation models"""
-         model_classes = [
-             ("aesthetic_shadow", AestheticShadowModel),
-             ("waifu_scorer", WaifuScorerModel),
-             ("aesthetic_predictor_v2_5", AestheticPredictorV25Model),
-             ("anime_aesthetic", AnimeAestheticModel),
-         ]
-
-         for key, model_class in model_classes:
              try:
-                 self.models[key] = model_class()
-                 logger.info(f"Successfully loaded {key}")
              except Exception as e:
-                 logger.error(f"Failed to load {key}: {e}")
-
-     async def evaluate_images(
-         self,
-         file_paths: List[str],
-         selected_models: List[str],
-         batch_size: int = 8,
-         progress_callback = None
-     ) -> Tuple[List[EvaluationResult], List[str]]:
-         """Evaluate images with selected models"""
-         logs = []
-         results = []
-
-         # Load images
-         images = []
-         valid_paths = []
-         for path in file_paths:
              try:
-                 img = Image.open(path).convert("RGB")
-                 images.append(img)
-                 valid_paths.append(path)
-             except Exception as e:
-                 logs.append(f"Failed to load {Path(path).name}: {e}")
-
-         if not images:
-             logs.append("No valid images to process")
-             return results, logs
-
-         logs.append(f"Loaded {len(images)} images")
-
-         # Process in batches
-         total_batches = (len(images) + batch_size - 1) // batch_size
-
-         for batch_idx in range(0, len(images), batch_size):
-             batch_images = images[batch_idx:batch_idx + batch_size]
-             batch_paths = valid_paths[batch_idx:batch_idx + batch_size]
-
-             # Evaluate with each selected model
-             batch_results = {}
-             for model_key in selected_models:
-                 if model_key in self.models:
-                     scores = await self.models[model_key].evaluate_batch(batch_images)
-                     batch_results[model_key] = scores
-                     logs.append(f"Processed batch {batch_idx//batch_size + 1}/{total_batches} with {self.models[model_key].name}")
-
-             # Create results
-             for i, (path, img) in enumerate(zip(batch_paths, batch_images)):
-                 result = EvaluationResult(
-                     file_name=Path(path).name,
-                     image_path=path
-                 )
-
-                 for model_key in selected_models:
-                     if model_key in batch_results:
-                         result.scores[model_key] = batch_results[model_key][i]
-
-                 result.calculate_final_score(selected_models)
-                 results.append(result)
-
-             # Update progress
-             if progress_callback:
-                 progress = (batch_idx + batch_size) / len(images) * 100
-                 progress_callback(min(progress, 100))
-
-         self.results = results
-         return results, logs
-
-     def get_results_dataframe(self, selected_models: List[str]) -> pd.DataFrame:
-         """Convert results to pandas DataFrame"""
-         if not self.results:
-             return pd.DataFrame()
-
-         data = []
-         for result in self.results:
-             row = {
-                 'File Name': result.file_name,
-                 'Image': result.image_path,
-             }
-
-             # Add model scores
-             for model_key in selected_models:
-                 if model_key in self.models:
-                     score = result.scores.get(model_key)
-                     row[self.models[model_key].name] = f"{score:.4f}" if score is not None else "N/A"
-
-             row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
-             data.append(row)
-
-         return pd.DataFrame(data)
  def create_interface():
-     """Create the Gradio interface"""
-     evaluator = ImageEvaluator()
-
-     # Model options for checkbox
-     model_options = [
-         ("Aesthetic Shadow", "aesthetic_shadow"),
-         ("Waifu Scorer", "waifu_scorer"),
-         ("Aesthetic V2.5", "aesthetic_predictor_v2_5"),
-         ("Anime Score", "anime_aesthetic")
-     ]
-
-     with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
          gr.Markdown("""
-         # 🎨 Advanced Image Evaluation Tool
-
-         Evaluate images using state-of-the-art aesthetic and quality prediction models.
-         Upload your images and select the models you want to use for evaluation.
          """)
-
          with gr.Row():
-             with gr.Column(scale=1):
-                 input_files = gr.File(
-                     label="Upload Images",
-                     file_count="multiple",
-                     file_types=["image"]
-                 )
-
-                 model_checkboxes = gr.CheckboxGroup(
-                     choices=[label for label, _ in model_options],
-                     value=[label for label, _ in model_options],
-                     label="Select Models",
-                     info="Choose which models to use for evaluation"
-                 )
-
-                 with gr.Row():
-                     batch_size = gr.Slider(
-                         minimum=1,
-                         maximum=64,
-                         value=8,
-                         step=1,
-                         label="Batch Size",
-                         info="Number of images to process at once"
                      )
-
-                 with gr.Row():
-                     evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", scale=2)
-                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
-
-             with gr.Column(scale=2):
-                 progress = gr.Progress()
-                 logs = gr.Textbox(
-                     label="Processing Logs",
-                     lines=10,
-                     max_lines=10,
-                     autoscroll=True
-                 )
-
-                 results_df = gr.Dataframe(
                      label="Evaluation Results",
-                     interactive=False,
-                     wrap=True
                  )
-
-                 download_btn = gr.Button("📥 Download Results (CSV)", variant="secondary")
-                 download_file = gr.File(visible=False)
-
-         # State for storing results
-         results_state = gr.State([])
-
-         async def process_images(files, selected_model_labels, batch_size, progress=gr.Progress()):
-             """Process uploaded images"""
-             if not files:
-                 return "Please upload images first", pd.DataFrame(), []
-
-             # Convert labels to keys
-             selected_models = [key for label, key in model_options if label in selected_model_labels]
-
-             # Get file paths
-             file_paths = [f.name for f in files]
-
-             # Progress callback
-             def update_progress(value):
-                 progress(value / 100, desc=f"Processing images... {value:.0f}%")
-
-             # Evaluate images
-             results, logs = await evaluator.evaluate_images(
-                 file_paths,
-                 selected_models,
-                 batch_size,
-                 update_progress
              )

-             # Create DataFrame
-             df = evaluator.get_results_dataframe(selected_models)
-
-             # Format logs
-             log_text = "\n".join(logs[-10:])  # Show last 10 logs
-
-             return log_text, df, results
-
-         def update_results_on_model_change(selected_model_labels, results):
-             """Update results when model selection changes"""
-             if not results:
-                 return pd.DataFrame()
-
-             # Convert labels to keys
-             selected_models = [key for label, key in model_options if label in selected_model_labels]
-
-             # Recalculate final scores
-             for result in results:
-                 result.calculate_final_score(selected_models)
-
-             # Update evaluator results
-             evaluator.results = results
-
-             # Create updated DataFrame
-             return evaluator.get_results_dataframe(selected_models)
-
-         def clear_interface():
-             """Clear all results"""
-             return "", pd.DataFrame(), [], None
-
-         def prepare_download(selected_model_labels, results):
-             """Prepare CSV file for download"""
-             if not results:
-                 return None
-
-             # Convert labels to keys
-             selected_models = [key for label, key in model_options if label in selected_model_labels]
-
-             # Get DataFrame
-             df = evaluator.get_results_dataframe(selected_models)
-
-             # Save to temporary file
-             import tempfile
-             with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
-                 df.to_csv(f, index=False)
-                 return f.name
-
-         # Event handlers
-         evaluate_btn.click(
-             fn=process_images,
-             inputs=[input_files, model_checkboxes, batch_size],
-             outputs=[logs, results_df, results_state]
-         )
-
-         model_checkboxes.change(
-             fn=update_results_on_model_change,
-             inputs=[model_checkboxes, results_state],
-             outputs=[results_df]
          )
-
          clear_btn.click(
-             fn=clear_interface,
-             outputs=[logs, results_df, results_state, download_file]
          )

-         download_btn.click(
-             fn=prepare_download,
-             inputs=[model_checkboxes, results_state],
-             outputs=[download_file]
          )
-
          gr.Markdown("""
-         ### 📝 Notes
-         - **Model Selection**: Choose which models to use for evaluation. Final score is the average of selected models.
-         - **Batch Size**: Adjust based on your GPU memory. Larger batches process faster.
-         - **Results Table**: Click column headers to sort. The table updates automatically when models are selected/deselected.
-         - **Download**: Export results as CSV for further analysis.
-
-         ### 🎯 Score Interpretation
-         - **7-10**: High quality/aesthetic appeal
-         - **5-7**: Medium quality
-         - **0-5**: Lower quality
          """)
-
      return demo


  if __name__ == "__main__":
-     # Create and launch the interface
-     demo = create_interface()
-     demo.queue().launch()
  import os
+ import shutil
+ import tempfile
  import asyncio
+ from io import BytesIO, StringIO
+ import csv
+ import pandas as pd  # used by the DataFrame callbacks below
  from pathlib import Path
  import logging

  import onnxruntime as rt
  from PIL import Image
  import gradio as gr
+ from transformers import pipeline, AutoProcessor, AutoModelForImageClassification
  from huggingface_hub import hf_hub_download

+ # Configure basic logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

+ # --- Dependency: aesthetic_predictor_v2_5.py ---
+ # This file should exist in the same directory or be on PYTHONPATH.
+ # For demonstration, a stub is provided; replace it with the actual implementation.
+ # aesthetic_predictor_v2_5.py STUB START
+ # (Normally this would live in its own file: aesthetic_predictor_v2_5.py)
+ def convert_v2_5_from_siglip(repo_id="unum-cloud/siglip-base-patch16-224-aesthetic-v2.5", low_cpu_mem_usage=True, trust_remote_code=True):
+     logger.info(f"Loading model and preprocessor from Hugging Face Hub: {repo_id}")
+     try:
+         # Attempt to load the actual models if available and the network permits
+         processor = AutoProcessor.from_pretrained(repo_id, low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code)
+         model = AutoModelForImageClassification.from_pretrained(repo_id, low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code)
+         logger.info("Successfully loaded model and preprocessor from Hugging Face Hub.")
+     except Exception as e:
+         logger.warning(f"Failed to load from {repo_id} due to: {e}. Using fallback mock objects.")
+         # Fall back to simple mock objects if the HF download fails or for offline use
+         from types import SimpleNamespace
+         class MockProcessor:
+             def __call__(self, images, return_tensors="pt"):
+                 num_images = len(images) if isinstance(images, list) else 1
+                 # Expose .pixel_values as an attribute, matching the HF BatchFeature API
+                 return SimpleNamespace(pixel_values=torch.randn(num_images, 3, 224, 224))
+         class MockModel:
+             def __init__(self): self._parameters = {"dummy": torch.nn.Parameter(torch.empty(0))}
+             def __call__(self, pixel_values):
+                 bs = pixel_values.shape[0]
+                 class Output:
+                     def __init__(self, logits_val): self.logits = logits_val
+                 return Output(logits_val=torch.rand(bs, 1) * 10)  # Simulate scores in [0, 10]
+             def to(self, *args, **kwargs): return self
+             def cuda(self, *args, **kwargs): return self
+             def bfloat16(self, *args, **kwargs): return self
+         processor = MockProcessor()
+         model = MockModel()
+         logger.info("Using fallback mock model and preprocessor for Aesthetic Predictor V2.5.")
+     return model, processor
+ # aesthetic_predictor_v2_5.py STUB END
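The stub keeps the rest of the app importable even when the Hub is unreachable. A quick smoke test of its contract (a hedged sketch; the `(1, 1)` logits shape is guaranteed only for the mock fallback, real checkpoints may differ):

```python
from PIL import Image

model, processor = convert_v2_5_from_siglip()
batch = processor(images=[Image.new("RGB", (224, 224))], return_tensors="pt")
logits = model(batch.pixel_values).logits  # one aesthetic logit per image
print(logits.shape)  # torch.Size([1, 1]) with the mock objects
```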


+ #####################################
+ #         Model Definitions         #
+ #####################################
+
+ class MLP(torch.nn.Module):
+     def __init__(self, input_size: int, batch_norm: bool = True):
+         super().__init__()
+         self.input_size = input_size
+         layers = [
+             torch.nn.Linear(self.input_size, 2048), torch.nn.ReLU(),
+             torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.3),
+             torch.nn.Linear(2048, 512), torch.nn.ReLU(),
+             torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.3),
+             torch.nn.Linear(512, 256), torch.nn.ReLU(),
+             torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.2),
+             torch.nn.Linear(256, 128), torch.nn.ReLU(),
+             torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.1),
+             torch.nn.Linear(128, 32), torch.nn.ReLU(),
+             torch.nn.Linear(32, 1)
          ]
+         self.layers = torch.nn.Sequential(*layers)

+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.layers(x)
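This head is the same stack the old `WaifuScorerModel._create_mlp` built inline: a 768-dimensional CLIP ViT-L/14 embedding in, one scalar score out. A minimal shape check (hedged sketch; `eval()` matters because `BatchNorm1d` rejects a batch of one in training mode):

```python
import torch

mlp = MLP(input_size=768)
mlp.eval()  # inference mode: BatchNorm uses running stats, Dropout is disabled
fake_clip_features = torch.randn(4, 768)  # stand-in for CLIP image embeddings
print(mlp(fake_clip_features).shape)  # torch.Size([4, 1]) -- one raw score per image
```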

+ class WaifuScorer:
+     def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
+         self.verbose = verbose
+         self.device = device
+         self.dtype = torch.float32
+         self.available = False
+         self.clip_model = None
+         self.preprocess = None
+         self.mlp = None
          try:
+             import clip  # dynamically import clip
+             if model_path is None:
+                 model_path = "Eugeoter/waifu-scorer-v3/model.pth"
+                 if self.verbose: logger.info(f"WaifuScorer model path not provided. Using default: {model_path}")
+
+             if not Path(model_path).is_file():
+                 try:
+                     # Assuming model_path like "user/repo/file.pth" for hf_hub_download
+                     parts = model_path.split("/")
+                     if len(parts) >= 3:
+                         repo_id_str = "/".join(parts[:-1])
+                         filename = parts[-1]
+                         model_path_resolved = hf_hub_download(repo_id=repo_id_str, filename=filename, cache_dir=cache_dir)
+                     else:  # try as a repo_id and assume a common filename
+                         model_path_resolved = hf_hub_download(repo_id=model_path, filename="model.pth", cache_dir=cache_dir)  # fallback filename
+                 except Exception as e:
+                     logger.error(f"Failed to download WaifuScorer model from HF Hub ({model_path}): {e}")
+                     # Try a more specific default if the generic one failed
+                     logger.info("Attempting to download specific WaifuScorer model Eugeoter/waifu-scorer-v3/model.pth")
+                     model_path_resolved = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth", cache_dir=cache_dir)
+                 model_path = model_path_resolved
+
+             if self.verbose: logger.info(f"Loading WaifuScorer model from: {model_path}")
+
+             self.mlp = MLP(input_size=768)
+             if str(model_path).endswith(".safetensors"):
+                 from safetensors.torch import load_file
+                 state_dict = load_file(model_path, device=device)
+             else:
+                 state_dict = torch.load(model_path, map_location=device)
+
+             # Adjust keys if necessary (e.g. if saved from DataParallel)
+             if any(key.startswith("module.") for key in state_dict.keys()):
+                 state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
              self.mlp.load_state_dict(state_dict)
+             self.mlp.to(device=self.device, dtype=self.dtype)
+             self.mlp.eval()
+
              self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
              self.available = True
+             logger.info("WaifuScorer initialized successfully.")
+         except ImportError:
+             logger.error("OpenAI CLIP library not found. WaifuScorer will be unavailable. Please install it with 'pip install openai-clip'")
          except Exception as e:
+             logger.error(f"Unable to initialize WaifuScorer: {e}")

      @torch.no_grad()
+     def __call__(self, images: list[Image.Image]) -> list[float | None]:
          if not self.available:
              return [None] * len(images)

+         if not images:
+             return []
+
+         original_n = len(images)
+         # Handle single image case for CLIP if it has issues with batch_size=1 (some versions might)
+         processed_images = images if len(images) > 1 else images * 2
+
          try:
+             image_tensors = [self.preprocess(img).unsqueeze(0) for img in processed_images]
+             image_batch = torch.cat(image_tensors).to(self.device)
+             image_features = self.clip_model.encode_image(image_batch)

+             norm = image_features.norm(p=2, dim=-1, keepdim=True)
+             norm = torch.where(norm == 0, torch.tensor(1.0, device=norm.device, dtype=norm.dtype), norm)  # avoid division by zero
+             im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)

+             predictions = self.mlp(im_emb)
+             scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
+             return scores[:original_n]
          except Exception as e:
+             logger.error(f"Error during WaifuScorer prediction: {e}")
+             return [None] * original_n
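Assuming the `clip` package and the checkpoint download both succeed, the scorer is directly callable on PIL images (hedged sketch; file names are hypothetical):

```python
import torch
from PIL import Image

scorer = WaifuScorer(device='cuda' if torch.cuda.is_available() else 'cpu', verbose=True)
if scorer.available:
    imgs = [Image.open(p).convert("RGB") for p in ("a.png", "b.png")]  # hypothetical files
    print(scorer(imgs))  # e.g. [6.91, 4.37]; a None entry marks a failed batch
```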


+ class AestheticPredictorV2_5_Wrapper:
+     def __init__(self, device: str):
+         self.device = device
          self.model, self.preprocessor = convert_v2_5_from_siglip(
+             low_cpu_mem_usage=True, trust_remote_code=True
          )
+         if self.device == 'cuda' and torch.cuda.is_available():
              self.model = self.model.to(torch.bfloat16).cuda()
+         logger.info("Aesthetic Predictor V2.5 Wrapper initialized.")
+
      @torch.no_grad()
+     def inference(self, images: list[Image.Image]) -> list[float | None]:
+         if not images:
+             return []
          try:
              images_rgb = [img.convert("RGB") for img in images]
              pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
+             if self.device == 'cuda' and torch.cuda.is_available():
                  pixel_values = pixel_values.to(torch.bfloat16).cuda()

+             scores_tensor = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
+             if scores_tensor.ndim == 0:  # single-image result
+                 scores = [scores_tensor.item()]
+             else:
+                 scores = scores_tensor.tolist()
+             return [round(max(0.0, min(s, 10.0)), 4) for s in scores]  # clip and round
          except Exception as e:
+             logger.error(f"Error during Aesthetic Predictor V2.5 inference: {e}")
              return [None] * len(images)
+ def load_anime_aesthetic_onnx_model(cache_dir: str = None) -> rt.InferenceSession | None:
+     try:
+         model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir=cache_dir)
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
+         session = rt.InferenceSession(model_path, providers=providers)
+         logger.info(f"Anime Aesthetic ONNX model loaded with providers: {session.get_providers()}")
+         return session
+     except Exception as e:
+         logger.error(f"Failed to load Anime Aesthetic ONNX model: {e}")
+         return None

+ def preprocess_anime_aesthetic_batch(images_pil: list[Image.Image], target_size: int = 768) -> np.ndarray | None:
+     if not images_pil:
+         return None
+     batch_canvases = []
+     try:
+         for img_pil in images_pil:
+             img_np = np.array(img_pil.convert("RGB")).astype(np.float32) / 255.0
+             h, w = img_np.shape[:2]
+             if h > w:
+                 new_h, new_w = target_size, int(target_size * w / h)
+             else:
+                 new_h, new_w = int(target_size * h / w), target_size
+
+             resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
+             canvas = np.zeros((target_size, target_size, 3), dtype=np.float32)
+             pad_h = (target_size - new_h) // 2
+             pad_w = (target_size - new_w) // 2
+             canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
+             batch_canvases.append(canvas)

+         input_tensor_batch = np.array(batch_canvases, dtype=np.float32)  # (N, H, W, C)
+         input_tensor_batch = np.transpose(input_tensor_batch, (0, 3, 1, 2))  # (N, C, H, W)
+         return input_tensor_batch
+     except Exception as e:
+         logger.error(f"Error during Anime Aesthetic preprocessing: {e}")
+         return None
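Worked example of the letterboxing above: a 1024×512 (w×h) landscape image takes the `else` branch, resizes to 768×384, and is centered on a 768×768 black canvas with `pad_h = (768 - 384) // 2 = 192` and `pad_w = 0`, so aspect ratio is preserved and the ONNX input is always `(N, 3, 768, 768)`:

```python
from PIL import Image

batch = preprocess_anime_aesthetic_batch([Image.new("RGB", (1024, 512))])
print(batch.shape)  # (1, 3, 768, 768)
```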

+ #####################################
+ #       Image Evaluation Tool       #
+ #####################################

+ class ModelManager:
+     def __init__(self, cache_dir: str = None):
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         logger.info(f"Using device: {self.device}")
+         self.cache_dir = cache_dir
+         self.models = {}
+         self.model_configs = {}
+         self._load_all_models()
+
+         self.processing_queue: asyncio.Queue = asyncio.Queue()
+         self.worker_task = None
+         self._temp_files_to_clean = []  # for CSV files
+
+     def _load_all_models(self):
+         logger.info("Loading Aesthetic Shadow model...")
+         try:
+             self.models["aesthetic_shadow"] = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=0 if self.device == 'cuda' else -1)
+             self.model_configs["aesthetic_shadow"] = {"name": "Aesthetic Shadow", "process_func": self._process_aesthetic_shadow}
+             logger.info("Aesthetic Shadow model loaded.")
+         except Exception as e:
+             logger.error(f"Failed to load Aesthetic Shadow model: {e}")
+
+         logger.info("Loading Waifu Scorer model...")
+         try:
+             ws = WaifuScorer(device=self.device, cache_dir=self.cache_dir, verbose=True)
+             if ws.available:
+                 self.models["waifu_scorer"] = ws
+                 self.model_configs["waifu_scorer"] = {"name": "Waifu Scorer", "process_func": self._process_waifu_scorer}
+                 logger.info("Waifu Scorer model loaded.")
+             else:
+                 logger.warning("Waifu Scorer model is not available.")
+         except Exception as e:
+             logger.error(f"Failed to load Waifu Scorer model: {e}")
+
+         logger.info("Loading Aesthetic Predictor V2.5...")
+         try:
+             ap_v25 = AestheticPredictorV2_5_Wrapper(device=self.device)
+             self.models["aesthetic_predictor_v2_5"] = ap_v25
+             self.model_configs["aesthetic_predictor_v2_5"] = {"name": "Aesthetic V2.5", "process_func": self._process_aesthetic_predictor_v2_5}
+             logger.info("Aesthetic Predictor V2.5 loaded.")
+         except Exception as e:
+             logger.error(f"Failed to load Aesthetic Predictor V2.5: {e}")
+
+         logger.info("Loading Anime Aesthetic model...")
+         try:
+             aa_model = load_anime_aesthetic_onnx_model(cache_dir=self.cache_dir)
+             if aa_model:
+                 self.models["anime_aesthetic"] = aa_model
+                 self.model_configs["anime_aesthetic"] = {"name": "Anime Score", "process_func": self._process_anime_aesthetic}
+                 logger.info("Anime Aesthetic model loaded.")
+             else:
+                 logger.warning("Anime Aesthetic ONNX model failed to load and will be unavailable.")
+         except Exception as e:
+             logger.error(f"Failed to load Anime Aesthetic model: {e}")

+         logger.info(f"Available models for processing: {list(self.model_configs.keys())}")
+
+     async def start_worker_if_not_running(self):
+         if self.worker_task is None or self.worker_task.done():
+             self.worker_task = asyncio.create_task(self._worker())
+             logger.info("Async worker started.")
+
+     async def _worker(self):
+         while True:
+             request = await self.processing_queue.get()
+             if request is None:  # shutdown signal
+                 self.processing_queue.task_done()
+                 logger.info("Async worker received shutdown signal.")
+                 break
+
+             future = request.get('future')
              try:
+                 if request['type'] == 'run_evaluation_generator':
+                     # The generator is created here and handed back via the future;
+                     # the Gradio callback then iterates over it.
+                     gen = self.run_evaluation_generator(**request['params'])
+                     future.set_result(gen)
+                 else:
+                     logger.warning(f"Unknown request type in worker: {request.get('type')}")
+                     if future: future.set_exception(ValueError("Unknown request type"))
              except Exception as e:
+                 logger.error(f"Error in worker processing request: {e}", exc_info=True)
+                 if future: future.set_exception(e)
+             finally:
+                 self.processing_queue.task_done()
+
+     async def submit_evaluation_request(self, file_paths, auto_batch, manual_batch_size, selected_model_keys):
+         await self.start_worker_if_not_running()
+         future = asyncio.Future()
+         request_item = {
+             'type': 'run_evaluation_generator',
+             'params': {
+                 'file_paths': file_paths,
+                 'auto_batch': auto_batch,
+                 'manual_batch_size': manual_batch_size,
+                 'selected_model_keys': selected_model_keys,
+             },
+             'future': future
+         }
+         await self.processing_queue.put(request_item)
+         return await future  # the future resolves to the async generator
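The worker pattern above decouples submission from execution: callers enqueue a dict carrying an `asyncio.Future`, and the single worker resolves it. A condensed sketch of the same hand-off, independent of the app:

```python
import asyncio

async def worker(queue: asyncio.Queue) -> None:
    # Mirrors ModelManager._worker: pull requests until the None shutdown signal.
    while (request := await queue.get()) is not None:
        request["future"].set_result(f"handled {request['params']}")  # resolve the caller's future
        queue.task_done()
    queue.task_done()  # account for the shutdown signal itself

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(worker(queue))
    fut = asyncio.get_running_loop().create_future()
    await queue.put({"params": 42, "future": fut})
    print(await fut)       # -> "handled 42"
    await queue.put(None)  # shutdown signal
    await task

asyncio.run(main())
```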
+
+     def auto_tune_batch_size(self, images: list[Image.Image], selected_model_keys: list[str]) -> int:
+         if not images or not selected_model_keys:
+             return 1
+
+         batch_size = 1
+         max_possible_batch = len(images)
+         test_image_pil = [images[0].copy()]  # a list holding one PIL image, copied so transforms cannot mutate the original
+
+         logger.info(f"Auto-tuning batch size with selected models: {selected_model_keys}")

+         optimal_batch_size = 1
+         while batch_size <= max_possible_batch:
+             current_test_batch = test_image_pil * batch_size
              try:
+                 logger.debug(f"Testing batch size: {batch_size}")
+                 if "aesthetic_shadow" in selected_model_keys and "aesthetic_shadow" in self.models:
+                     _ = self.models["aesthetic_shadow"](current_test_batch, batch_size=batch_size)
+                 if "waifu_scorer" in selected_model_keys and "waifu_scorer" in self.models:
+                     _ = self.models["waifu_scorer"](current_test_batch)
+                 if "aesthetic_predictor_v2_5" in selected_model_keys and "aesthetic_predictor_v2_5" in self.models:
+                     _ = self.models["aesthetic_predictor_v2_5"].inference(current_test_batch)
+                 if "anime_aesthetic" in selected_model_keys and "anime_aesthetic" in self.models:
+                     processed_input = preprocess_anime_aesthetic_batch(current_test_batch)
+                     if processed_input is None: raise ValueError("Anime aesthetic preprocessing failed for test batch")
+                     _ = self.models["anime_aesthetic"].run(None, {"img": processed_input})
+
+                 optimal_batch_size = batch_size  # this batch size worked
+                 if batch_size * 2 > max_possible_batch:  # the next doubling would exceed the image count
+                     break
+                 batch_size *= 2
+
+             except Exception as e:  # typically torch.cuda.OutOfMemoryError or similar
+                 logger.warning(f"Auto-tune failed at batch size {batch_size} for at least one model: {e}")
+                 break  # keep the largest batch size that worked before this failure

+         # Cap the batch size for very large numbers of images / powerful GPUs
+         final_optimal_batch = min(optimal_batch_size, max_possible_batch, 64)
+         logger.info(f"Optimal batch size determined: {final_optimal_batch}")
+         return max(1, final_optimal_batch)
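The tuner probes powers of two (1, 2, 4, ...) on copies of a single test image until a probe raises (typically a CUDA out-of-memory error), keeps the last size that worked, and caps the result at 64. The same search in isolation (hedged sketch; `try_batch` is a hypothetical predicate standing in for the model calls above):

```python
def find_batch_size(try_batch, max_possible: int, cap: int = 64) -> int:
    size, best = 1, 1
    while size <= max_possible:
        if not try_batch(size):  # e.g. returns False on an out-of-memory error
            break                # keep the largest size that succeeded
        best = size
        size *= 2
    return max(1, min(best, max_possible, cap))

print(find_batch_size(lambda n: n <= 20, max_possible=100))  # -> 16
```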
+
+
+     async def run_evaluation_generator(self, file_paths: list[str], auto_batch: bool,
+                                        manual_batch_size: int, selected_model_keys: list[str]):

+         log_messages = []
+         def _log(msg):
+             log_messages.append(msg)
+             logger.info(msg)
+
+         _log("Starting image evaluation...")
+         yield {"type": "log_update", "messages": log_messages[-20:]}  # show the last 20 log lines
+         yield {"type": "progress", "value": 0.0, "desc": "Initiating..."}
+
+         images_pil = []
+         file_names = []
+         for f_path_str in file_paths:
+             try:
+                 p = Path(f_path_str)
+                 img = Image.open(p).convert("RGB")
+                 images_pil.append(img)
+                 file_names.append(p.name)
+                 _log(f"Loaded image: {p.name}")
+             except Exception as e:
+                 _log(f"Error opening {f_path_str}: {e}")

+         yield {"type": "log_update", "messages": log_messages[-20:]}
+
+         if not images_pil:
+             _log("No valid images loaded. Aborting.")
+             yield {"type": "log_update", "messages": log_messages[-20:]}
+             yield {"type": "progress", "value": 1.0, "desc": "No images loaded"}
+             yield {"type": "final_results_state", "data": []}  # ensure the state is empty
+             return
+
+         actual_batch_size = 1
+         if auto_batch:
+             _log("Auto-tuning batch size...")
+             yield {"type": "log_update", "messages": log_messages[-20:]}
+             yield {"type": "progress", "value": 0.05, "desc": "Auto-tuning batch size..."}
+             actual_batch_size = self.auto_tune_batch_size(images_pil, selected_model_keys)
+             _log(f"Auto-detected batch size: {actual_batch_size}")
+         else:
+             actual_batch_size = int(manual_batch_size) if manual_batch_size > 0 else 1
+             _log(f"Using manual batch size: {actual_batch_size}")

+         yield {"type": "batch_size_update", "value": actual_batch_size}
+         yield {"type": "log_update", "messages": log_messages[-20:]}
+
+         all_results_for_state = []   # full data for gr.State
+         dataframe_rows_so_far = []   # rows for gr.DataFrame (PIL thumbnails, strings, numbers)
+
+         total_images = len(images_pil)
+         processed_count = 0
+
+         for i in range(0, total_images, actual_batch_size):
+             batch_images_pil = images_pil[i:i+actual_batch_size]
+             batch_file_names = file_names[i:i+actual_batch_size]
+             num_in_batch = len(batch_images_pil)
+             _log(f"Processing batch {i//actual_batch_size + 1}/{(total_images + actual_batch_size - 1) // actual_batch_size}: images {i+1} to {i+num_in_batch}")
+             yield {"type": "log_update", "messages": log_messages[-20:]}
+
+             batch_model_scores = {key: [None] * num_in_batch for key in self.model_configs.keys()}
+
+             for model_key in selected_model_keys:
+                 if model_key in self.models and model_key in self.model_configs:
+                     _log(f"  Running {self.model_configs[model_key]['name']} for batch...")
+                     yield {"type": "log_update", "messages": log_messages[-20:]}
+                     try:
+                         scores = await self.model_configs[model_key]['process_func'](batch_images_pil)
+                         batch_model_scores[model_key] = scores
+                         _log(f"  {self.model_configs[model_key]['name']} scores: {scores}")
+                     except Exception as e:
+                         _log(f"  Error processing batch with {self.model_configs[model_key]['name']}: {e}")
+                         batch_model_scores[model_key] = [None] * num_in_batch  # ensure a list of Nones
+                     yield {"type": "log_update", "messages": log_messages[-20:]}

+             # Assemble results for this batch
+             current_batch_df_rows = []
+             for j in range(num_in_batch):
+                 result_item_state = {'file_name': batch_file_names[j]}  # for gr.State

+                 # For the DataFrame: [PIL.Image, filename, score1, score2, ..., final_score]
+                 thumbnail = batch_images_pil[j].copy()
+                 thumbnail.thumbnail((150, 150))  # create a thumbnail
+                 result_item_df_row = [thumbnail, batch_file_names[j]]
+
+                 current_image_scores = []
+                 for model_key in self.model_configs.keys():  # iterate in defined order for consistency
+                     score = batch_model_scores[model_key][j]
+                     result_item_state[model_key] = score  # for gr.State
+                     if model_key in selected_model_keys:  # only add to the DataFrame if selected
+                         result_item_df_row.append(f"{score:.4f}" if isinstance(score, (float, int)) else "N/A")
+                     if isinstance(score, (float, int)) and model_key in selected_model_keys:
+                         current_image_scores.append(score)

+                 final_score = None
+                 if current_image_scores:
+                     final_score_val = float(np.mean([s for s in current_image_scores if s is not None]))
+                     final_score = float(np.clip(final_score_val, 0.0, 10.0))
+
+                 result_item_state['final_score'] = final_score
+                 result_item_df_row.append(f"{final_score:.4f}" if final_score is not None else "N/A")
+
+                 all_results_for_state.append(result_item_state)
+                 current_batch_df_rows.append(result_item_df_row)

+             dataframe_rows_so_far.extend(current_batch_df_rows)

+             processed_count += num_in_batch
+             progress_value = processed_count / total_images
+             yield {"type": "partial_results_df_rows", "data": dataframe_rows_so_far, "selected_model_keys": selected_model_keys}
+             yield {"type": "progress", "value": progress_value, "desc": f"Processed {processed_count}/{total_images}"}
+
+         _log("All images processed.")
+         yield {"type": "log_update", "messages": log_messages[-20:]}
+         yield {"type": "progress", "value": 1.0, "desc": "Completed!"}
+         yield {"type": "final_results_state", "data": all_results_for_state}
+
+
+     async def _process_aesthetic_shadow(self, batch_images: list[Image.Image]) -> list[float | None]:
+         model = self.models.get("aesthetic_shadow")
+         if not model: return [None] * len(batch_images)
+         results = model(batch_images, batch_size=len(batch_images))  # assuming the pipeline accepts a batch_size hint
+         scores = []
+         for res_group in results:  # results may be List[List[Dict]] or List[Dict]
+             # Handle both single-image and batch results from the pipeline
+             current_res_list = res_group if isinstance(res_group, list) else [res_group]
+             try:
+                 hq_score_item = next(p for p in current_res_list if p['label'] == 'hq')
+                 score = float(np.clip(hq_score_item['score'] * 10.0, 0.0, 10.0))
+             except (StopIteration, KeyError, TypeError):
+                 score = None
+             scores.append(score)
+         return scores
+
+     async def _process_waifu_scorer(self, batch_images: list[Image.Image]) -> list[float | None]:
+         model = self.models.get("waifu_scorer")
+         if not model: return [None] * len(batch_images)
+         raw_scores = model(batch_images)
+         return [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in raw_scores]
+
+     async def _process_aesthetic_predictor_v2_5(self, batch_images: list[Image.Image]) -> list[float | None]:
+         model = self.models.get("aesthetic_predictor_v2_5")
+         if not model: return [None] * len(batch_images)
+         # Already returns clipped & rounded scores or Nones
+         return model.inference(batch_images)
+
+     async def _process_anime_aesthetic(self, batch_images: list[Image.Image]) -> list[float | None]:
+         model = self.models.get("anime_aesthetic")
+         if not model: return [None] * len(batch_images)

+         input_data = preprocess_anime_aesthetic_batch(batch_images)
+         if input_data is None:
+             return [None] * len(batch_images)
+
+         try:
+             preds = model.run(None, {"img": input_data})[0]  # assuming output shape (N, 1) or (N,)
+             scores = [float(np.clip(p.item() * 10.0, 0.0, 10.0)) for p in preds]
+             return scores
+         except Exception as e:
+             logger.error(f"Error during Anime Aesthetic ONNX prediction: {e}")
+             return [None] * len(batch_images)

+     def add_temp_file_for_cleanup(self, file_path: str):
+         self._temp_files_to_clean.append(file_path)
+
+     async def shutdown_worker(self):
+         if self.worker_task and not self.worker_task.done():
+             logger.info("Attempting to shut down worker...")
+             await self.processing_queue.put(None)  # send the shutdown signal
+             try:
+                 await asyncio.wait_for(self.worker_task, timeout=5.0)
+                 logger.info("Worker task finished.")
+             except asyncio.TimeoutError:
+                 logger.warning("Worker task did not finish in time. Cancelling...")
+                 self.worker_task.cancel()
+             except Exception as e:
+                 logger.error(f"Exception during worker shutdown: {e}")
+         await self.processing_queue.join()  # wait for the queue to be fully processed
+         logger.info("Processing queue joined.")
+         self.worker_task = None
+
+
+     def cleanup(self):
+         logger.info("Running cleanup...")
+         # Shut down the asyncio worker
+         if self.worker_task:
+             # If running in a context where an event loop is already running
+             if asyncio.get_event_loop().is_running():
+                 asyncio.create_task(self.shutdown_worker())  # schedule it
+             else:  # if no loop is running, run it to completion
+                 try:
+                     asyncio.run(self.shutdown_worker())
+                 except RuntimeError as e:  # handles "cannot be called when another loop is running"
+                     logger.error(f"RuntimeError during cleanup's shutdown_worker: {e}. May need manual loop management.")
+
+         # Clean up temporary CSV files
+         for f_path in self._temp_files_to_clean:
+             try:
+                 os.remove(f_path)
+                 logger.info(f"Removed temp file: {f_path}")
+             except OSError as e:
+                 logger.error(f"Error removing temp file {f_path}: {e}")
+         self._temp_files_to_clean.clear()
+         logger.info("Cleanup finished.")
+
+
+ #####################################
+ #             Interface             #
+ #####################################
+
+ # Initialize the ModelManager once
+ model_manager = ModelManager(cache_dir=".model_cache")

  def create_interface():
+     # Define model choices from the ModelManager's loaded models,
+     # filtering out any model that failed to load.
+     AVAILABLE_MODEL_KEYS = [k for k in model_manager.model_configs.keys() if k in model_manager.models]
+     AVAILABLE_MODEL_NAMES_MAP = {k: model_manager.model_configs[k]['name'] for k in AVAILABLE_MODEL_KEYS}

+     # [(display_name, value_key), ...] for the CheckboxGroup
+     MODEL_CHOICES_FOR_CHECKBOX = [(AVAILABLE_MODEL_NAMES_MAP[k], k) for k in AVAILABLE_MODEL_KEYS]
+
+     with gr.Blocks(theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
          gr.Markdown("""
+         # Comprehensive Image Evaluation Tool (Refactored)
+         Upload images to evaluate them using multiple aesthetic and quality prediction models.
+         Results are displayed in a sortable table with image previews.
          """)
+
+         # Stores the full processing results (a list of dicts with keys
+         # 'file_name', 'final_score', and one entry per model key).
+         # This state is the source of truth for regenerating the table and CSV.
+         results_state = gr.State([])
+         # Stores the currently selected model keys (e.g. ['waifu_scorer', 'anime_aesthetic'])
+         selected_models_state = gr.State(AVAILABLE_MODEL_KEYS)
+         # Stores the current log messages as a list
+         log_messages_state = gr.State([])
+
          with gr.Row():
+             with gr.Column(scale=1):  # inputs
+                 input_images = gr.Files(label="Upload Images", file_count="multiple", type="filepath")

+                 if not MODEL_CHOICES_FOR_CHECKBOX:
+                     gr.Markdown("## No models loaded successfully. Please check logs.")
+                     model_checkboxes = None  # no models, no checkbox group
+                 else:
+                     model_checkboxes = gr.CheckboxGroup(
+                         choices=MODEL_CHOICES_FOR_CHECKBOX,
+                         label="Select Models",
+                         value=AVAILABLE_MODEL_KEYS,  # default: all available models selected
+                         info="Choose models for evaluation. Final score is an average of selected model scores."
                      )
+
+                 auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=True)
+                 batch_size_input = gr.Number(label="Manual Batch Size", value=8, minimum=1, precision=0, interactive=False)  # interactivity follows auto_batch_checkbox

+                 process_btn = gr.Button("Evaluate Images", variant="primary", interactive=bool(MODEL_CHOICES_FOR_CHECKBOX))
+                 clear_btn = gr.Button("Clear Results")
+                 download_csv_btn = gr.Button("Download Results as CSV", variant="secondary")
+
+             with gr.Column(scale=3):  # outputs
+                 progress_tracker = gr.Progress(label="Processing Progress")
+                 log_output = gr.Textbox(label="Logs", lines=10, max_lines=20, interactive=False, autoscroll=True)

+                 # Initial headers for the DataFrame; updated dynamically later
+                 initial_df_headers = ['Image', 'File Name'] + [AVAILABLE_MODEL_NAMES_MAP[k] for k in AVAILABLE_MODEL_KEYS] + ['Final Score']
+                 results_dataframe = gr.DataFrame(
+                     headers=initial_df_headers,
+                     datatype=['pil'] + ['str'] * (len(initial_df_headers) - 1),  # image column + string columns
                      label="Evaluation Results",
+                     interactive=True,  # enables sorting by clicking headers
+                     row_count=(10, "dynamic"),
+                     col_count=(len(initial_df_headers), "fixed"),
+                     wrap=True,
                  )
+                 # Hidden file component used to trigger downloads
+                 download_file_provider = gr.File(label="Download Link", visible=False)
+
+         # --- Callback Functions ---
+         def update_batch_size_interactive(auto_detect_enabled: bool):
+             return gr.Number.update(interactive=not auto_detect_enabled)
+
+         async def handle_process_images_ui(
+             files_list: list | None,  # Gradio file objects for the uploaded images
+             auto_batch_flag: bool,
+             manual_batch_val: int,
+             selected_model_keys_from_ui: list[str],
+             # Gradio passes the gr.Progress instance automatically via the type hint
+             progress_instance: gr.Progress
+         ):
+             if not files_list:
+                 yield {
+                     log_output: "No files uploaded. Please select images first.",
+                     progress_tracker: gr.Progress(0.0, "Idle. No files."),
+                     results_dataframe: gr.DataFrame.update(value=None),  # clear the table
+                     results_state: [],
+                     selected_models_state: selected_model_keys_from_ui,
+                     log_messages_state: ["No files uploaded. Please select images first."]
+                 }
+                 return
+
+             # Update selected_models_state right away and clear the log state
+             yield {selected_models_state: selected_model_keys_from_ui, log_messages_state: []}
+
+             # Convert Gradio file objects to string paths
+             actual_file_paths = [f.name for f in files_list]

+             current_log_list = []  # local log accumulator for this run
+
+             # Ask the ModelManager for the evaluation generator; the worker
+             # resolves the future with an async generator we iterate here.
+             evaluation_generator = await model_manager.submit_evaluation_request(
+                 actual_file_paths, auto_batch_flag, manual_batch_val, selected_model_keys_from_ui
              )

+             dataframe_update_value = None
+             final_results_for_app_state = []
+
+             async for event in evaluation_generator:
+                 outputs_to_yield = {}
+                 if event["type"] == "log_update":
+                     current_log_list = event["messages"]
+                     outputs_to_yield[log_output] = "\n".join(current_log_list)
+                 elif event["type"] == "progress":
+                     # Update the progress bar directly via the passed instance
+                     progress_instance(event["value"], desc=event.get("desc"))
+                 elif event["type"] == "batch_size_update":
+                     outputs_to_yield[batch_size_input] = gr.Number.update(value=event["value"])
+                 elif event["type"] == "partial_results_df_rows":
+                     # event["data"] is a list of lists (DataFrame rows);
+                     # selected_model_keys determines the current headers
+                     dynamic_headers = ['Image', 'File Name'] + \
+                                       [AVAILABLE_MODEL_NAMES_MAP[k] for k in event["selected_model_keys"] if k in AVAILABLE_MODEL_NAMES_MAP] + \
+                                       ['Final Score']
+                     dataframe_update_value = pd.DataFrame(event["data"], columns=dynamic_headers) if event["data"] else None
+                     outputs_to_yield[results_dataframe] = gr.DataFrame.update(value=dataframe_update_value, headers=dynamic_headers)
+
+                 elif event["type"] == "final_results_state":
+                     final_results_for_app_state = event["data"]
+
+                 if outputs_to_yield:  # only yield when there is something to update
+                     yield outputs_to_yield

+             # Final updates after the generator is exhausted
+             yield {
+                 results_state: final_results_for_app_state,
+                 log_messages_state: current_log_list,  # save the final logs
+                 # The DataFrame is already up to date from the last partial_results_df_rows event
+             }
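The callback streams UI updates by yielding `{component: value}` dicts from an async generator, which Gradio applies incrementally as each batch finishes (this is the Gradio 3.x generator-handler behavior the code above relies on). A condensed sketch of the same shape, with a hypothetical `textbox` output:

```python
import asyncio
import gradio as gr

with gr.Blocks() as sketch_demo:
    btn = gr.Button("Run")
    textbox = gr.Textbox(label="Status")

    async def stream_status():
        for pct in (25, 50, 100):
            await asyncio.sleep(0.5)         # stand-in for one model batch
            yield {textbox: f"{pct}% done"}  # partial update; the UI refreshes per yield

    btn.click(fn=stream_status, inputs=None, outputs=[textbox])
```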
752
+
753
+
754
+ def handle_clear_results_ui():
755
+ # Clear files, logs, table, progress, and internal states
756
+ return {
757
+ input_images: None,
758
+ log_output: "Results cleared.",
759
+ results_dataframe: gr.DataFrame.update(value=None, headers=initial_df_headers), # Reset with initial headers
760
+ progress_tracker: gr.Progress(0.0, "Idle"),
761
+ results_state: [],
762
+ # selected_models_state: AVAILABLE_MODEL_KEYS, # Optionally reset model selection
763
+ batch_size_input: gr.Number.update(value=8), # Reset batch size
764
+ log_messages_state: ["Results cleared."]
765
+ }
766
 
767
+ # Function to re-render DataFrame and update states when model selection changes
768
+ def handle_model_selection_or_state_change_ui(
769
+ current_selected_keys: list[str],
770
+ current_full_results: list[dict]
771
+ ):
772
+ if not current_full_results: # No data to process
773
+ dynamic_headers = ['Image', 'File Name'] + \
774
+ [AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \
775
+ ['Final Score']
776
+ return {
777
+ results_dataframe: gr.DataFrame.update(value=None, headers=dynamic_headers),
778
+ selected_models_state: current_selected_keys,
779
+ results_state: current_full_results # pass through if empty
780
+ }
781
+
782
+ new_df_rows = []
783
+ updated_full_results = []
784
+
785
+ for res_item_dict in current_full_results:
786
+ # Recalculate final score based on new selection
787
+ scores_to_avg = []
788
+ for mk in current_selected_keys:
789
+ if mk in res_item_dict and isinstance(res_item_dict[mk], (float, int)):
790
+ scores_to_avg.append(res_item_dict[mk])
791
+
792
+ new_final_score = None
793
+ if scores_to_avg:
794
+ new_final_score_val = float(np.mean(scores_to_avg))
795
+ new_final_score = float(np.clip(new_final_score_val, 0.0, 10.0))
796
+
797
+ # Update the item in results_state
798
+ res_item_dict['final_score'] = new_final_score
799
+ updated_full_results.append(res_item_dict.copy()) # Store updated item
800
+
+                # Rebuild the DataFrame row. Note: run_evaluation_generator stores plain score
+                # dicts in results_state, without PIL image objects, so the 'Image' column
+                # cannot be reconstructed here. A robust fix is for each results_state item to
+                # carry the original file path or a cached thumbnail (see the sketch after this
+                # function); until then a placeholder is shown. Gradio's DataFrame also cannot
+                # hide or show individual columns, so headers and data are rebuilt together.
+                df_row = [res_item_dict.get('thumbnail_pil_placeholder', "N/A"), res_item_dict['file_name']]
+                for mk_cfg in AVAILABLE_MODEL_KEYS:  # Iterate over all models to keep a stable column order
+                    if mk_cfg in current_selected_keys:  # Only selected models appear as columns
+                        score = res_item_dict.get(mk_cfg)
+                        df_row.append(f"{score:.4f}" if isinstance(score, (float, int)) else "N/A")
+                    # Unselected models are skipped; their columns are absent from dynamic_headers.
+                df_row.append(f"{new_final_score:.4f}" if new_final_score is not None else "N/A")
+                new_df_rows.append(df_row)
 
+            dynamic_headers = ['Image', 'File Name'] + \
+                              [AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \
+                              ['Final Score']
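+            # For illustration: with hypothetical keys ["aesthetic_shadow", "waifu_scorer"]
+            # selected (and matching display names in AVAILABLE_MODEL_NAMES_MAP),
+            # dynamic_headers would be:
+            #     ['Image', 'File Name', 'Aesthetic Shadow', 'Waifu Scorer', 'Final Score']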
 
+            # pandas is already imported at module level as pd.
+            df_value = pd.DataFrame(new_df_rows, columns=dynamic_headers) if new_df_rows else None
 
+            return {
+                results_dataframe: gr.DataFrame.update(value=df_value, headers=dynamic_headers),
+                selected_models_state: current_selected_keys,  # Persist the new selection
+                results_state: updated_full_results  # Persist the recalculated scores
+            }
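+
+        # A minimal sketch of the thumbnail-caching fix mentioned above. It is illustrative
+        # only (not wired up): it assumes run_evaluation_generator is changed to store a
+        # 'thumbnail_pil' entry per result item (the key name and helper are hypothetical),
+        # which would let this handler rebuild the 'Image' column from results_state alone:
+        #
+        #     def make_result_item(image_path: str, file_name: str, scores: dict) -> dict:
+        #         thumb = Image.open(image_path).convert("RGB")
+        #         thumb.thumbnail((96, 96))  # small in-memory copy for DataFrame display
+        #         return {"file_name": file_name, "thumbnail_pil": thumb, **scores}
+        #
+        # The first df_row element would then become res_item_dict.get("thumbnail_pil", "N/A").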
+
+
+        def handle_download_csv_ui(current_full_results: list[dict], current_selected_keys: list[str]):
+            if not current_full_results:
+                # Nothing to export; keep the file component hidden.
+                return gr.File.update(value=None, visible=False)
+
+            # Build the CSV in memory with StringIO.
+            csv_output = StringIO()
+            # Field names: file name, then the selected models' scores, then the final score.
+            fieldnames = ['File Name'] + \
+                         [AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \
+                         ['Final Score']
 
+            writer = csv.DictWriter(csv_output, fieldnames=fieldnames, extrasaction='ignore')
+            writer.writeheader()
+
+            for res_item in current_full_results:
+                row_to_write = {'File Name': res_item['file_name']}
+                final_score_val = res_item.get('final_score')  # Kept up to date via results_state
+                row_to_write['Final Score'] = f"{final_score_val:.4f}" if final_score_val is not None else "N/A"
+
+                for key in current_selected_keys:
+                    if key in AVAILABLE_MODEL_NAMES_MAP:  # Skip anything that is not a valid model key
+                        model_display_name = AVAILABLE_MODEL_NAMES_MAP[key]
+                        score_val = res_item.get(key)
+                        row_to_write[model_display_name] = f"{score_val:.4f}" if isinstance(score_val, (float, int)) else "N/A"
+                writer.writerow(row_to_write)
 
+            csv_content = csv_output.getvalue()
+            csv_output.close()
+
+            # Write to a temporary file that Gradio can serve for download.
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", encoding='utf-8') as tmp_file:
+                tmp_file.write(csv_content)
+                temp_file_path = tmp_file.name
 
+            model_manager.add_temp_file_for_cleanup(temp_file_path)  # Register for cleanup on shutdown
+
+            return gr.File.update(value=temp_file_path, visible=True, label="results.csv")
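+
+        # For illustration, with two hypothetical models selected the exported CSV would
+        # look like this (file names and scores below are made-up, not real output):
+        #
+        #     File Name,Aesthetic Shadow,Waifu Scorer,Final Score
+        #     example_01.png,8.1234,7.5000,7.8117
+        #     example_02.png,N/A,6.2000,6.2000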
+
+
+        # --- Wire up components ---
+        auto_batch_checkbox.change(
+            fn=update_batch_size_interactive,
+            inputs=[auto_batch_checkbox],
+            outputs=[batch_size_input]
        )
+
+        # Only wire model-dependent events if the models loaded and checkboxes exist.
+        if model_checkboxes:
+            process_btn.click(
+                fn=handle_process_images_ui,
+                inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes],
+                outputs=[
+                    log_output, progress_tracker, results_dataframe, batch_size_input,
+                    results_state, selected_models_state, log_messages_state  # Every yielded component must be listed
+                ]
+            )
+            # When the model selection changes, update the displayed table and internal states.
+            model_checkboxes.change(
+                fn=handle_model_selection_or_state_change_ui,
+                inputs=[model_checkboxes, results_state],  # Current selection plus the full results data
+                outputs=[results_dataframe, selected_models_state, results_state]
+            )
+
        clear_btn.click(
+            fn=handle_clear_results_ui,
+            outputs=[
+                input_images, log_output, results_dataframe, progress_tracker,
+                results_state, batch_size_input, log_messages_state  # model_checkboxes could be reset here as well
+            ]
        )
 
+        download_csv_btn.click(
+            fn=handle_download_csv_ui,
+            inputs=[results_state, selected_models_state],  # Export respects the current model selection
+            outputs=[download_file_provider]
        )
+
+        # Initial setup when the demo loads.
+        async def initial_load_setup():
+            await model_manager.start_worker_if_not_running()  # Start the async evaluation worker
+            # Seed selected_models_state from the default checkbox values; the State component
+            # cannot be bound directly to the checkbox group's initial value here.
+            return {selected_models_state: AVAILABLE_MODEL_KEYS, log_messages_state: ["Application loaded. Ready."]}
+
+        demo.load(
+            fn=initial_load_setup,
+            outputs=[selected_models_state, log_messages_state]
+        )
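+        # Alternative to the seeding workaround above (a sketch, assuming a recent Gradio
+        # where gr.State accepts an initial value): construct the state pre-seeded instead,
+        #
+        #     selected_models_state = gr.State(value=AVAILABLE_MODEL_KEYS)
+        #
+        # and let initial_load_setup only start the worker and set the initial log message.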
+        # Register the cleanup function for when the page unloads.
+        demo.unload(model_manager.cleanup)
+
+
        gr.Markdown("""
+        ### Notes
+        - **Model Selection**: Dynamically choose which models are used for evaluation; the 'Final Score' and the displayed columns update accordingly.
+        - **Native Table**: Results are shown in a native Gradio DataFrame; click a column header to sort.
+        - **Batching**: Automatic batch size detection is enabled by default, with an option to switch to manual batch sizing.
+        - **CSV Export**: Download the current results (columns follow the selected models) as a CSV file.
+        - **Asynchronous Processing**: Image evaluation runs in the background, with live log and progress updates.
        """)

    return demo


if __name__ == "__main__":
+    # Runtime dependencies:
+    # - 'safetensors' for loading the MLP weights with adjusted keys.
+    # - 'openai-clip' for WaifuScorer.
+    # - 'onnxruntime' (or 'onnxruntime-gpu') for the ONNX models.
+    # - 'transformers' for general Hugging Face model loading.
+    # - 'opencv-python' (cv2) for image manipulation.
+    # - Plus 'torch', 'numpy', 'Pillow', and 'gradio'.
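+    # A possible requirements.txt covering the above (a sketch; versions unpinned, and the
+    # exact set depends on which models are enabled):
+    #
+    #     torch
+    #     numpy
+    #     Pillow
+    #     gradio
+    #     pandas
+    #     transformers
+    #     huggingface_hub
+    #     onnxruntime
+    #     safetensors
+    #     openai-clip
+    #     opencv-python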
+
+    # 'aesthetic_predictor_v2_5.py' must provide the real 'convert_v2_5_from_siglip'
+    # function. Nothing is written to disk here: if the file is missing, the script falls
+    # back to its internal stub, so the check below only warns about the dependency.
+    if not Path("aesthetic_predictor_v2_5.py").exists():
+        logger.warning("aesthetic_predictor_v2_5.py not found; falling back to the internal stub.")
+
+    # ModelManager must be initialized before create_interface() is called, since
+    # create_interface() relies on model_manager.model_configs. This is handled by the
+    # module-level `model_manager = ModelManager()`.
+
+    # Register cleanup via atexit *before* launching, as a fallback in case demo.unload is
+    # not effective in every environment; launch(debug=True) blocks until the server stops.
+    import atexit
+    atexit.register(model_manager.cleanup)
+
+    app_interface = create_interface()
+    app_interface.queue().launch(debug=True, share=False)  # The queue enables async and generator handlers