Update app.py
app.py CHANGED
@@ -1,710 +1,824 @@
 import os
-import shutil
 import tempfile

 import cv2
 import numpy as np
 import torch
 from PIL import Image
-import gradio as gr
 from transformers import pipeline
 from huggingface_hub import hf_hub_download

         super().__init__()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.

-    """WaifuScorer
         self.device = device
         self.dtype = torch.float32
         try:
-            import clip
-            username, repo_id, model_name = model_path.split("/")[-3:]
-            model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)
-            if self.verbose:
-                print(f"Loading WaifuScorer model from: {model_path}")
-            # Initialize MLP model
-            self.mlp = MLP(input_size=768)
-            # Load state dict
             if model_path.endswith(".safetensors"):
                 from safetensors.torch import load_file
                 state_dict = load_file(model_path)
             else:
-                state_dict = torch.load(model_path, map_location=device)
         except Exception as e:

             low_cpu_mem_usage=True,
             trust_remote_code=True,
         )
         if torch.cuda.is_available():

 class ModelManager:
         }
         while True:
-            request = await self.
                 break
             try:
             except Exception as e:
             finally:

         file_paths = request['file_paths']
         auto_batch = request['auto_batch']
         manual_batch_size = request['manual_batch_size']
         images = []
-        log_events.append(f"Starting to load {total_files} images...")
-        for f in file_paths:
             try:
-                img = Image.open(
                 images.append(img)
             except Exception as e:
-        total_images = len(images)
-        for i in range(0, total_images, optimal_batch):
-            batch_images = images[i:i+optimal_batch]
-            batch_file_names = file_names[i:i+optimal_batch]
-            batch_index = i // optimal_batch + 1
-            log_events.append(f"Processing batch {batch_index}: images {i+1} to {min(i+optimal_batch, total_images)}")
-            batch_results = {}
-            # Process selected models
-            for model_key in selected_models:
-                if self.available_models[model_key]['selected']:  # Ensure model is selected
-                    batch_results[model_key] = await self.available_models[model_key]['process'](batch_images, log_events)  # Removed 'self' here
-                else:
-                    batch_results[model_key] = [None] * len(batch_images)
-            # Combine results and create final results list
-            for j in range(len(batch_images)):
-                scores_to_average = []
-                for model_key in selected_models:
-                    if self.available_models[model_key]['selected']:  # Ensure model is selected
-                        score = batch_results[model_key][j]
-                        if score is not None:
-                            scores_to_average.append(score)
-                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
-                thumbnail = batch_images[j].copy()
-                thumbnail.thumbnail((200, 200))
-                result = {
-                    'file_name': batch_file_names[j],
-                    'img_data': self.image_to_base64(thumbnail),  # Keep this for the HTML display
-                    'final_score': final_score,
-                }
-                for model_key in selected_models:  # Add model scores to result
-                    if self.available_models[model_key]['selected']:
-                        result[model_key] = batch_results[model_key][j]
-                final_results.append(result)
-
-        log_events.append("All images processed.")
-        return final_results, log_events, 100, optimal_batch
-
-    def image_to_base64(self, image: Image.Image) -> str:
-        """Convert PIL Image to base64 encoded JPEG string."""
-        buffered = BytesIO()
-        image.save(buffered, format="JPEG")
-        return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-    def auto_tune_batch_size(self, images: list) -> int:
-        """Automatically determine the optimal batch size for processing."""
         batch_size = 1
-        max_batch = len(images)
         test_image = images[0:1]
         try:
-            _ = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(test_image * batch_size)
             batch_size *= 2
-            if batch_size > max_batch:
-                break
         except Exception:
             break

-    def sort_results(self, results, sort_by: str = "Final Score") -> list:
-        """Sort results based on the specified column."""
-        key_map = {
-            "Final Score": "final_score",
-            "File Name": "file_name",
-            "Aesthetic Shadow": "aesthetic_shadow",
-            "Waifu Scorer": "waifu_scorer",
-            "Aesthetic V2.5": "aesthetic_predictor_v2_5",
-            "Anime Score": "anime_aesthetic"
         }

         <style>
-        .results-table {
         </style>
-        <table class="results-table">
-        <thead>
-        <tr>
-        <th>Image</th>
-        <th>File Name</th>
         """
-        visible_models.append("anime_aesthetic")
-        table_html += "<th>Final Score</th>"
-        table_html += "</tr></thead><tbody>"
         for result in results:
-            score_class = ""
-            if isinstance(score, (int, float)):
-                if score >= 7:
-                    score_class = "good-score"
-                elif score >= 5:
-                    score_class = "medium-score"
-                else:
-                    score_class = "bad-score"
-            return f'<td class="{score_class}">{score_str}</td>'
-
-    def cleanup(self):
-        """Clean up temporary directories and shutdown worker."""
-        if os.path.exists(self.temp_dir):
-            shutil.rmtree(self.temp_dir)
-        if self.worker_task is not None:  # Check if worker_task was started
-            asyncio.run(self.shutdown())  # Shutdown worker gracefully
-
-    async def shutdown(self):
-        """Send shutdown signal to worker and wait for it to finish."""
-        if self.worker_task is not None:  # Check if worker_task was started
-            await self.processing_queue.put(None)  # Send shutdown signal
-            await self.worker_task  # Wait for worker task to complete
-            await self.processing_queue.join()  # Wait for queue to be empty
-
-#####################################
-#             Interface             #
-#####################################
-
-model_manager = ModelManager()  # Initialize ModelManager once outside the interface function
-
-def create_interface():
-    sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
-    model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]
-
-    with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("""
-        # Comprehensive Image Evaluation Tool
-
-        Upload images to evaluate them using multiple aesthetic and quality prediction models.
-
-        **New features:**
-        - **Dynamic Final Score:** Final score recalculates on model selection changes.
-        - **Model Selection:** Choose which models to use for evaluation.
-        - **Dynamic Table Updates:** Table updates automatically based on model selection.
-        - **Automatic Sorting:** Table is automatically sorted by 'Final Score'.
-        - **Detailed Logs:** See major processing events (limited to the last 10).
-        - **Progress Bar:** Visual indication of processing status.
-        - **Asynchronous Updates:** Streaming status and logs during processing.
-        - **Batch Size Controls:** Choose manual batch size or let the tool auto-detect it.
-        - **Download Results:** Export the evaluation results as CSV.
-        """)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                input_images = gr.Files(label="Upload Images", file_count="multiple")
-                model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
-                auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
-                batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
-                sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
-                process_btn = gr.Button("Evaluate Images", variant="primary")
-                clear_btn = gr.Button("Clear Results")
-                download_csv = gr.Button("Download CSV", variant="secondary")
-
-            with gr.Column(scale=2):
-                progress_bar = gr.HTML(label="Progress Bar", value="""
-                <div style='width:100%;background-color:#ddd;'>
-                  <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
-                </div>
-                """)
-                log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
-                status_html = gr.HTML(label="Status")
-                output_html = gr.HTML(label="Evaluation Results")
-                download_file_output = gr.File()  # Initialize gr.File component without filename
-                global_results_state = gr.State([])  # Initialize a global state to hold results
-
-        # Function to convert results to CSV format, excluding 'img_data'.
-        def results_to_csv(results, selected_models):  # Take results as input
-            import csv
-            import io
-            if not results:
-                return None  # Return None when no results are available
-            output = io.StringIO()
-            fieldnames = ['file_name', 'final_score']  # Base fieldnames
-            for model_key in selected_models:  # Add selected model names as fieldnames
-                if model_key in selected_models:  # Double check if model_key is indeed in selected_models list
-                    fieldnames.append(model_key)
-
-            writer = csv.DictWriter(output, fieldnames=fieldnames)
-            writer.writeheader()
-            for res in results:
-                row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']}  # Base data
-                for model_key in selected_models:  # Add selected model scores
-                    if model_key in selected_models:  # Double check before accessing res[model_key]
-                        row_dict[model_key] = res.get(model_key, 'N/A')  # Use get with default 'N/A' if model not in result (shouldn't happen but for safety)
-                writer.writerow(row_dict)
-            return output.getvalue()
-
-        def update_batch_size_interactivity(auto_batch):
-            return gr.update(interactive=not auto_batch)
-
-        async def process_images_and_update(files, auto_batch, manual_batch, selected_models, current_results):
-            file_paths = [f.name for f in files]
-
-            # Prepare request data for the ModelManager
-            request_data = {
-                'file_paths': file_paths,
-                'auto_batch': auto_batch,
-                'manual_batch_size': manual_batch,
-                'selected_models': {model: {'selected': model in selected_models} for model in model_options}  # Pass model selections
-            }
-            # Submit request and get results from ModelManager
-            results, logs, progress_percent, updated_batch = await model_manager.submit_request(request_data)
-
-            updated_results = current_results + results  # Append new results to current results
-
-            html_table = model_manager.generate_html_table(updated_results, selected_models)
-            progress_html = model_manager._generate_progress_html(progress_percent)
-            log_html = model_manager._format_logs(logs[-10:])
-
-            return status_html, html_table, log_html, progress_html, gr.update(value=updated_batch, interactive=not auto_batch), updated_results
-
-        def update_table_sort(sort_by_column, selected_models, current_results):
-            sorted_results = model_manager.sort_results(current_results, sort_by_column)
-            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results  # Return sorted results
-
-        def update_table_model_selection(selected_models, current_results):
-            # Recalculate final scores based on selected models
-            for result in current_results:
-                scores_to_average = []
-                for model_key in model_options:  # Use model_options here, not available_models from manager in UI context
-                    if model_key in selected_models and model_key in model_manager.available_models and model_manager.available_models[model_key]['selected']:  # consider only selected models from checkbox group and available_models
-                        score = result.get(model_key)
-                        if score is not None:
-                            scores_to_average.append(score)
-                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
-                result['final_score'] = final_score
-
-            sorted_results = model_manager.sort_results(current_results, "Final Score")  # Keep sorting by Final Score when models change
-            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results
-
-        def clear_results():
-            return (gr.update(value=""),
-                    gr.update(value=""),
-                    gr.update(value=""),
-                    gr.update(value="""
-                    <div style='width:100%;background-color:#ddd;'>
-                      <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
-                    </div>
-                    """),
-                    gr.update(value=1),
-                    [])  # Clear results state
-
-        def download_results_csv_trigger(selected_models, current_results):  # Changed function name to avoid conflict and clarify purpose
-            csv_content = results_to_csv(current_results, selected_models)
-            if csv_content is None:
-                return None  # Indicate no file to download
-
-            # Create a temporary file to save the CSV data
-            with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
-                tmp_file.write(csv_content.encode())
-                temp_file_path = tmp_file.name  # Get the path to the temporary file
-
-            return temp_file_path  # Return the path to the temporary file
-
-        # Set initial selection state for models in ModelManager (important!)
-        for model_key in model_options:
-            model_manager.available_models[model_key]['selected'] = True  # Default to all selected initially
-
-        auto_batch_checkbox.change(
-            update_batch_size_interactivity,
-            inputs=[auto_batch_checkbox],
-            outputs=[batch_size_input]
-        )
-
-        process_btn.click(
-            process_images_and_update,
-            inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes, global_results_state],
-            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
-        )
-        sort_dropdown.change(
-            update_table_sort,
-            inputs=[sort_dropdown, model_checkboxes, global_results_state],
-            outputs=[output_html, global_results_state]
-        )
-        model_checkboxes.change(  # Added change event for model checkboxes
-            update_table_model_selection,
-            inputs=[model_checkboxes, global_results_state],
-            outputs=[output_html, global_results_state]
-        )
-        clear_btn.click(
-            clear_results,
-            inputs=[],
-            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
-        )
-        download_csv.click(
-            download_results_csv_trigger,  # Call the trigger function
-            inputs=[model_checkboxes, global_results_state],
-            outputs=[download_file_output]  # Output is now the gr.File component
-        )
-        demo.load(lambda: update_table_sort("Final Score", model_options, []), inputs=None, outputs=[output_html, global_results_state])  # Initial sort and table render, pass empty initial results
-        demo.load(model_manager.start_worker)  # Start the worker task on demo load
-
-        gr.Markdown("""
-        ### Notes
-        - Select models to use for evaluation using the checkboxes.
-        - The 'Final Score' recalculates dynamically when models are selected/deselected.
-        - The table updates automatically when models are selected/deselected and is always sorted by 'Final Score'.
-        - The log window displays the most recent 10 events.
-        - The progress bar shows overall processing status.
-        - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
-        - Use the download button to export your evaluation results as CSV.
-        """)
-
-    return demo
-
-if __name__ == "__main__":
-    demo = create_interface()
-    demo.queue().launch()
+"""
+Modern Image Evaluation Tool with Aesthetic and Quality Prediction Models
+
+This refactored version features:
+- Modern async/await patterns with proper error handling
+- Type hints throughout for better code maintainability
+- Dependency injection and factory patterns
+- Proper resource management with context managers
+- Configuration-driven model loading
+- Improved batch processing with memory optimization
+- Clean separation of concerns with proper abstraction layers
+"""
+
+import asyncio
+import base64
+import csv
+import logging
 import os
 import tempfile
+import shutil
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from enum import Enum
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import Dict, List, Optional, Protocol, Tuple, Union, Any
+from abc import ABC, abstractmethod

 import cv2
+import gradio as gr
 import numpy as np
+import onnxruntime as ort
 import torch
+import torch.nn as nn
 from PIL import Image
 from transformers import pipeline
 from huggingface_hub import hf_hub_download

+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# =============================================================================
+# Configuration and Data Models
+# =============================================================================
+
+class ModelType(Enum):
+    """Enumeration of available model types."""
+    AESTHETIC_SHADOW = "aesthetic_shadow"
+    WAIFU_SCORER = "waifu_scorer"
+    AESTHETIC_PREDICTOR_V2_5 = "aesthetic_predictor_v2_5"
+    ANIME_AESTHETIC = "anime_aesthetic"
+
+
+@dataclass
+class ModelConfig:
+    """Configuration for individual models."""
+    name: str
+    display_name: str
+    enabled: bool = True
+    batch_supported: bool = True
+    model_path: Optional[str] = None
+    cache_dir: Optional[str] = None
+
+
+@dataclass
+class ProcessingConfig:
+    """Configuration for processing parameters."""
+    auto_batch: bool = False
+    manual_batch_size: int = 1
+    max_batch_size: int = 64
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    score_range: Tuple[float, float] = (0.0, 10.0)
+
+
+@dataclass
+class EvaluationResult:
+    """Data class for individual evaluation results."""
+    file_name: str
+    file_path: str
+    thumbnail_b64: str
+    model_scores: Dict[str, Optional[float]] = field(default_factory=dict)
+    final_score: Optional[float] = None
+    processing_time: float = 0.0
+    error: Optional[str] = None
+
+
+@dataclass
+class BatchResult:
+    """Data class for batch processing results."""
+    results: List[EvaluationResult]
+    logs: List[str]
+    processing_time: float
+    batch_size_used: int
+    success_count: int
+    error_count: int
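
A minimal sketch of how these dataclasses are meant to compose; the values below are illustrative, not from the commit:

    # Illustrative only: default device resolution and per-model flags.
    proc_cfg = ProcessingConfig(auto_batch=True, max_batch_size=32)
    anime_cfg = ModelConfig(name="anime_aesthetic", display_name="Anime Score", batch_supported=False)
    print(proc_cfg.device)    # "cuda" when a GPU is available, otherwise "cpu"
    print(anime_cfg.enabled)  # True (default)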
+
+
+# =============================================================================
+# Model Interfaces and Implementations
+# =============================================================================
+
+class BaseModel(Protocol):
+    """Protocol defining the interface for all evaluation models."""
+
+    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Predict scores for a batch of images."""
+        ...
+
+    def is_available(self) -> bool:
+        """Check if the model is available and ready for inference."""
+        ...
+
+    def cleanup(self) -> None:
+        """Clean up model resources."""
+        ...
+
+
+class ModernMLP(nn.Module):
+    """Modern implementation of MLP with improved architecture."""
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_dims: List[int] = None,
+        dropout_rates: List[float] = None,
+        use_batch_norm: bool = True,
+        activation: nn.Module = nn.ReLU
+    ):
         super().__init__()
+
+        if hidden_dims is None:
+            hidden_dims = [2048, 512, 256, 128, 32]
+        if dropout_rates is None:
+            dropout_rates = [0.3, 0.3, 0.2, 0.1, 0.0]
+
+        layers = []
+        prev_dim = input_size
+
+        for i, (hidden_dim, dropout_rate) in enumerate(zip(hidden_dims, dropout_rates)):
+            layers.append(nn.Linear(prev_dim, hidden_dim))
+            layers.append(activation())
+
+            if use_batch_norm and i < len(hidden_dims) - 1:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+
+            if dropout_rate > 0:
+                layers.append(nn.Dropout(dropout_rate))
+
+            prev_dim = hidden_dim
+
+        # Final output layer
+        layers.append(nn.Linear(prev_dim, 1))
+        self.network = nn.Sequential(*layers)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.network(x)
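
With the default hidden_dims, ModernMLP builds a 768 -> 2048 -> 512 -> 256 -> 128 -> 32 -> 1 stack. A quick shape check, assuming only torch and the class above:

    import torch

    mlp = ModernMLP(input_size=768)   # defaults: [2048, 512, 256, 128, 32]
    mlp.eval()                        # eval mode so BatchNorm1d uses running stats
    with torch.no_grad():
        out = mlp(torch.randn(4, 768))   # e.g. four CLIP ViT-L/14 embeddings
    print(out.shape)                  # torch.Size([4, 1]): one scalar score per input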


+class WaifuScorerModel:
+    """Modernized WaifuScorer implementation with better error handling."""
+
+    def __init__(self, config: ModelConfig, device: str):
+        self.config = config
         self.device = device
         self.dtype = torch.float32
+        self._available = False
+        self._model = None
+        self._clip_model = None
+        self._preprocess = None
+
+        self._initialize_model()
+
+    def _initialize_model(self) -> None:
+        """Initialize the model with proper error handling."""
         try:
+            import clip
+
+            # Download model if needed
+            model_path = self._get_model_path()
+
+            # Initialize MLP
+            self._model = ModernMLP(input_size=768)
+
+            # Load weights
             if model_path.endswith(".safetensors"):
                 from safetensors.torch import load_file
                 state_dict = load_file(model_path)
             else:
+                state_dict = torch.load(model_path, map_location=self.device)
+
+            self._model.load_state_dict(state_dict)
+            self._model.to(self.device)
+            self._model.eval()
+
+            # Load CLIP model
+            self._clip_model, self._preprocess = clip.load("ViT-L/14", device=self.device)
+            self._available = True
+
+            logger.info(f"WaifuScorer model loaded successfully on {self.device}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize WaifuScorer: {e}")
+            self._available = False
+
+    def _get_model_path(self) -> str:
+        """Get or download the model path."""
+        if self.config.model_path and os.path.isfile(self.config.model_path):
+            return self.config.model_path
+
+        # Default download path
+        model_path = "Eugeoter/waifu-scorer-v3/model.pth"
+        username, repo_id, model_name = model_path.split("/")[-3:]
+        return hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=self.config.cache_dir)
+
+    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Predict scores for a batch of images."""
+        if not self._available:
+            return [None] * len(images)
+
+        try:
+            # Handle single image case for CLIP compatibility
+            batch_images = images * 2 if len(images) == 1 else images
+
+            # Preprocess images
+            image_tensors = [self._preprocess(img).unsqueeze(0) for img in batch_images]
+            image_batch = torch.cat(image_tensors).to(self.device)
+
+            # Extract features and predict
+            with torch.no_grad():
+                image_features = self._clip_model.encode_image(image_batch)
+                # Normalize features
+                norm = image_features.norm(2, dim=-1, keepdim=True)
+                norm[norm == 0] = 1
+                normalized_features = (image_features / norm).to(device=self.device, dtype=self.dtype)
+
+                predictions = self._model(normalized_features)
+                scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
+
+            return scores[:len(images)]
+
+        except Exception as e:
+            logger.error(f"Error in WaifuScorer prediction: {e}")
+            return [None] * len(images)
+
+    def is_available(self) -> bool:
+        return self._available
+
+    def cleanup(self) -> None:
+        """Clean up model resources."""
+        if self._model is not None:
+            del self._model
+        if self._clip_model is not None:
+            del self._clip_model
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
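
The norm[norm == 0] = 1 guard in predict() avoids dividing a degenerate all-zero embedding by zero. The same idiom in isolation:

    import torch

    features = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
    norm = features.norm(2, dim=-1, keepdim=True)
    norm[norm == 0] = 1        # leave all-zero rows unscaled instead of producing NaN
    print(features / norm)     # [[0.6, 0.8], [0.0, 0.0]]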
+
+
+class AestheticShadowModel:
+    """Wrapper for Aesthetic Shadow model using transformers pipeline."""
+
+    def __init__(self, config: ModelConfig, device: str):
+        self.config = config
+        self.device = device
+        self._available = False
+        self._model = None
+
+        self._initialize_model()
+
+    def _initialize_model(self) -> None:
+        """Initialize the model pipeline."""
+        try:
+            self._model = pipeline(
+                "image-classification",
+                model="NeoChen1024/aesthetic-shadow-v2-backup",
+                device=self.device
+            )
+            self._available = True
+            logger.info("Aesthetic Shadow model loaded successfully")
+
         except Exception as e:
+            logger.error(f"Failed to initialize Aesthetic Shadow: {e}")
+            self._available = False
+
+    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Predict scores for a batch of images."""
+        if not self._available:
+            return [None] * len(images)
+
+        try:
+            results = self._model(images)
+            scores = []
+
+            for result in results:
+                try:
+                    hq_score = next(p for p in result if p['label'] == 'hq')['score']
+                    score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
+                    scores.append(score)
+                except (StopIteration, KeyError, TypeError):
+                    scores.append(None)
+
+            return scores
+
+        except Exception as e:
+            logger.error(f"Error in Aesthetic Shadow prediction: {e}")
+            return [None] * len(images)
+
+    def is_available(self) -> bool:
+        return self._available
+
+    def cleanup(self) -> None:
+        if self._model is not None:
+            del self._model
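
predict() assumes the classifier returns a list of {'label', 'score'} dicts per image and maps the 'hq' probability onto the 0-10 scale. A toy illustration of that mapping with made-up values:

    import numpy as np

    result = [{'label': 'hq', 'score': 0.87}, {'label': 'lq', 'score': 0.13}]
    hq_score = next(p for p in result if p['label'] == 'hq')['score']
    print(float(np.clip(hq_score * 10.0, 0.0, 10.0)))   # 8.7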
+
+
+class AestheticPredictorV25Model:
+    """Wrapper for Aesthetic Predictor V2.5 model."""
+
+    def __init__(self, config: ModelConfig, device: str):
+        self.config = config
+        self.device = device
+        self._available = False
+        self._model = None
+        self._preprocessor = None
+
+        self._initialize_model()
+
+    def _initialize_model(self) -> None:
+        """Initialize the model."""
+        try:
+            from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
+
+            self._model, self._preprocessor = convert_v2_5_from_siglip(
                 low_cpu_mem_usage=True,
                 trust_remote_code=True,
             )
+
             if torch.cuda.is_available():
+                self._model = self._model.to(torch.bfloat16).cuda()
+
+            self._available = True
+            logger.info("Aesthetic Predictor V2.5 loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Aesthetic Predictor V2.5: {e}")
+            self._available = False
+
+    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Predict scores for a batch of images."""
+        if not self._available:
+            return [None] * len(images)
+
+        try:
+            rgb_images = [img.convert("RGB") for img in images]
+            pixel_values = self._preprocessor(images=rgb_images, return_tensors="pt").pixel_values
+
+            if torch.cuda.is_available():
+                pixel_values = pixel_values.to(torch.bfloat16).cuda()
+
+            with torch.inference_mode():
+                scores = self._model(pixel_values).logits.squeeze().float().cpu().numpy()
+
+            if scores.ndim == 0:
+                scores = np.array([scores])
+
+            return [float(np.round(np.clip(s, 0.0, 10.0), 4)) for s in scores]
+
+        except Exception as e:
+            logger.error(f"Error in Aesthetic Predictor V2.5 prediction: {e}")
+            return [None] * len(images)
+
+    def is_available(self) -> bool:
+        return self._available
+
+    def cleanup(self) -> None:
+        if self._model is not None:
+            del self._model
+
+
+class AnimeAestheticModel:
+    """ONNX-based Anime Aesthetic model."""
+
+    def __init__(self, config: ModelConfig, device: str):
+        self.config = config
+        self.device = device
+        self._available = False
+        self._session = None
+
+        self._initialize_model()
+
+    def _initialize_model(self) -> None:
+        """Initialize the ONNX model."""
+        try:
+            model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
+            self._session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
+            self._available = True
+            logger.info("Anime Aesthetic model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Anime Aesthetic: {e}")
+            self._available = False
+
+    async def predict(self, images: List[Image.Image]) -> List[Optional[float]]:
+        """Predict scores for images (single image processing for ONNX)."""
+        if not self._available:
+            return [None] * len(images)
+
+        scores = []
+        for img in images:
+            try:
+                score = self._predict_single(img)
+                scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
+            except Exception as e:
+                logger.error(f"Error predicting anime aesthetic for image: {e}")
+                scores.append(None)
+
+        return scores
+
+    def _predict_single(self, img: Image.Image) -> float:
+        """Predict score for a single image."""
+        img_np = np.array(img).astype(np.float32) / 255.0
+        s = 768
+        h, w = img_np.shape[:2]
+
+        # Resize while maintaining aspect ratio
+        if h > w:
+            new_h, new_w = s, int(s * w / h)
+        else:
+            new_h, new_w = int(s * h / w), s
+
+        resized = cv2.resize(img_np, (new_w, new_h))
+
+        # Center crop/pad to square
+        canvas = np.zeros((s, s, 3), dtype=np.float32)
+        pad_h = (s - new_h) // 2
+        pad_w = (s - new_w) // 2
+        canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
+
+        # Prepare input
+        input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
+        return self._session.run(None, {"img": input_tensor})[0].item()
+
+    def is_available(self) -> bool:
+        return self._available
+
+    def cleanup(self) -> None:
+        if self._session is not None:
+            del self._session
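
_predict_single letterboxes each image onto a 768x768 canvas before ONNX inference. A worked example of the geometry for a hypothetical 900x1200 input (numbers only, no model call):

    s = 768
    h, w = 900, 1200                      # example input height/width
    if h > w:
        new_h, new_w = s, int(s * w / h)
    else:
        new_h, new_w = int(s * h / w), s
    pad_h, pad_w = (s - new_h) // 2, (s - new_w) // 2
    print(new_h, new_w, pad_h, pad_w)     # 576 768 96 0: centered vertically, no horizontal padding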
+
+
+# =============================================================================
+# Model Factory and Manager
+# =============================================================================
+
+class ModelFactory:
+    """Factory for creating model instances."""
+
+    _MODEL_CLASSES = {
+        ModelType.AESTHETIC_SHADOW: AestheticShadowModel,
+        ModelType.WAIFU_SCORER: WaifuScorerModel,
+        ModelType.AESTHETIC_PREDICTOR_V2_5: AestheticPredictorV25Model,
+        ModelType.ANIME_AESTHETIC: AnimeAestheticModel,
+    }
+
+    @classmethod
+    def create_model(cls, model_type: ModelType, config: ModelConfig, device: str) -> BaseModel:
+        """Create a model instance based on type."""
+        model_class = cls._MODEL_CLASSES.get(model_type)
+        if not model_class:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        return model_class(config, device)
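
A sketch of the factory in use; the config values are assumptions for illustration:

    config = ModelConfig(name="anime_aesthetic", display_name="Anime Score", batch_supported=False)
    model = ModelFactory.create_model(ModelType.ANIME_AESTHETIC, config, device="cpu")
    if model.is_available():
        print("model ready")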
+

 class ModelManager:
+    """Advanced model manager with async processing and resource management."""
+
+    def __init__(self, processing_config: ProcessingConfig):
+        self.config = processing_config
+        self.models: Dict[ModelType, BaseModel] = {}
+        self.model_configs = self._create_default_configs()
+        self._processing_queue = asyncio.Queue()
+        self._worker_task: Optional[asyncio.Task] = None
+        self._temp_dir = Path(tempfile.mkdtemp())
+
+        self._initialize_models()
+
+    def _create_default_configs(self) -> Dict[ModelType, ModelConfig]:
+        """Create default model configurations."""
+        return {
+            ModelType.AESTHETIC_SHADOW: ModelConfig(
+                name="aesthetic_shadow",
+                display_name="Aesthetic Shadow"
+            ),
+            ModelType.WAIFU_SCORER: ModelConfig(
+                name="waifu_scorer",
+                display_name="Waifu Scorer"
+            ),
+            ModelType.AESTHETIC_PREDICTOR_V2_5: ModelConfig(
+                name="aesthetic_predictor_v2_5",
+                display_name="Aesthetic V2.5"
+            ),
+            ModelType.ANIME_AESTHETIC: ModelConfig(
+                name="anime_aesthetic",
+                display_name="Anime Score",
+                batch_supported=False
+            ),
         }
+
+    def _initialize_models(self) -> None:
+        """Initialize all models."""
+        logger.info("Initializing models...")
+
+        for model_type, config in self.model_configs.items():
+            if config.enabled:
+                try:
+                    model = ModelFactory.create_model(model_type, config, self.config.device)
+                    if model.is_available():
+                        self.models[model_type] = model
+                        logger.info(f"✓ {config.display_name} loaded successfully")
+                    else:
+                        logger.warning(f"✗ {config.display_name} failed to load")
+                except Exception as e:
+                    logger.error(f"✗ {config.display_name} initialization error: {e}")
+
+        logger.info(f"Initialized {len(self.models)} models successfully")
+
+    async def start_worker(self) -> None:
+        """Start the background processing worker."""
+        if self._worker_task is None:
+            self._worker_task = asyncio.create_task(self._worker_loop())
+            logger.info("Background worker started")
+
+    async def _worker_loop(self) -> None:
+        """Main worker loop for processing requests."""
         while True:
+            request = await self._processing_queue.get()
+
+            if request is None:  # Shutdown signal
                 break
+
             try:
+                result = await self._process_request(request)
+                request['future'].set_result(result)
             except Exception as e:
+                request['future'].set_exception(e)
             finally:
+                self._processing_queue.task_done()
+
+    async def process_images(
+        self,
+        file_paths: List[str],
+        selected_models: List[ModelType],
+        auto_batch: bool = False,
+        manual_batch_size: int = 1
+    ) -> BatchResult:
+        """Process images with selected models."""
+        future = asyncio.Future()
+        request = {
+            'file_paths': file_paths,
+            'selected_models': selected_models,
+            'auto_batch': auto_batch,
+            'manual_batch_size': manual_batch_size,
+            'future': future
+        }
+
+        await self._processing_queue.put(request)
+        return await future
+
+    async def _process_request(self, request: Dict) -> BatchResult:
+        """Process a single batch request."""
+        start_time = asyncio.get_event_loop().time()
+        logs = []
+        results = []
+
         file_paths = request['file_paths']
+        selected_models = request['selected_models']
         auto_batch = request['auto_batch']
         manual_batch_size = request['manual_batch_size']
+
+        # Load images
+        images, valid_paths = await self._load_images(file_paths, logs)
+
+        if not images:
+            return BatchResult([], logs, 0.0, 0, 0, len(file_paths))
+
+        # Determine batch size
+        batch_size = await self._determine_batch_size(images, auto_batch, manual_batch_size, logs)
+
+        # Process in batches
+        for i in range(0, len(images), batch_size):
+            batch_images = images[i:i+batch_size]
+            batch_paths = valid_paths[i:i+batch_size]
+
+            batch_results = await self._process_batch(batch_images, batch_paths, selected_models, logs)
+            results.extend(batch_results)
+
+        processing_time = asyncio.get_event_loop().time() - start_time
+        success_count = sum(1 for r in results if r.error is None)
+        error_count = len(results) - success_count
+
+        return BatchResult(
+            results=results,
+            logs=logs,
+            processing_time=processing_time,
+            batch_size_used=batch_size,
+            success_count=success_count,
+            error_count=error_count
+        )
+
+    async def _load_images(self, file_paths: List[str], logs: List[str]) -> Tuple[List[Image.Image], List[str]]:
+        """Load and validate images."""
         images = []
+        valid_paths = []
+
+        logs.append(f"Loading {len(file_paths)} images...")
+
+        for path in file_paths:
             try:
+                img = Image.open(path).convert("RGB")
                 images.append(img)
+                valid_paths.append(path)
             except Exception as e:
+                logs.append(f"Failed to load {path}: {e}")
+
+        logs.append(f"Successfully loaded {len(images)} images")
+        return images, valid_paths
+
+    async def _determine_batch_size(
+        self,
+        images: List[Image.Image],
+        auto_batch: bool,
+        manual_batch_size: int,
+        logs: List[str]
+    ) -> int:
+        """Determine optimal batch size."""
+        if not auto_batch:
+            return min(manual_batch_size, len(images))
+
+        # Auto-tune batch size
         batch_size = 1
         test_image = images[0:1]
+
+        while batch_size <= min(len(images), self.config.max_batch_size):
             try:
+                # Test with a sample of available models
+                test_batch = test_image * batch_size
+                for model_type, model in list(self.models.items())[:2]:  # Test with first 2 models
+                    await model.predict(test_batch)
+
                 batch_size *= 2
             except Exception:
                 break
+
+        optimal_batch = max(1, batch_size // 2)
+        logs.append(f"Auto-tuned batch size: {optimal_batch}")
+        return optimal_batch
+
+    async def _process_batch(
+        self,
+        images: List[Image.Image],
+        paths: List[str],
+        selected_models: List[ModelType],
+        logs: List[str]
+    ) -> List[EvaluationResult]:
+        """Process a single batch of images."""
+        batch_results = []
+
+        # Get predictions from all models
+        model_predictions = {}
+        for model_type in selected_models:
+            if model_type in self.models:
+                try:
+                    predictions = await self.models[model_type].predict(images)
+                    model_predictions[model_type.value] = predictions
+                    logs.append(f"✓ {self.model_configs[model_type].display_name} processed batch")
+                except Exception as e:
+                    logs.append(f"✗ {self.model_configs[model_type].display_name} error: {e}")
+                    model_predictions[model_type.value] = [None] * len(images)
+
+        # Create results
+        for i, (image, path) in enumerate(zip(images, paths)):
+            # Collect scores for this image
+            scores = {}
+            valid_scores = []
+
+            for model_type in selected_models:
+                score = model_predictions.get(model_type.value, [None] * len(images))[i]
+                scores[model_type.value] = score
+                if score is not None:
+                    valid_scores.append(score)
+
+            # Calculate final score
+            final_score = np.mean(valid_scores) if valid_scores else None
+            if final_score is not None:
+                final_score = float(np.clip(final_score, *self.config.score_range))
+
+            # Create thumbnail
+            thumbnail = image.copy()
+            thumbnail.thumbnail((200, 200), Image.Resampling.LANCZOS)
+            thumbnail_b64 = self._image_to_base64(thumbnail)
+
+            result = EvaluationResult(
+                file_name=Path(path).name,
+                file_path=path,
+                thumbnail_b64=thumbnail_b64,
+                model_scores=scores,
+                final_score=final_score
+            )
+
+            batch_results.append(result)
+
+        return batch_results
+
+    def _image_to_base64(self, image: Image.Image) -> str:
+        """Convert PIL Image to base64 string."""
+        buffer = BytesIO()
+        image.save(buffer, format="JPEG", quality=85, optimize=True)
+        return base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+    def get_available_models(self) -> Dict[ModelType, str]:
+        """Get available models with their display names."""
+        return {
+            model_type: self.model_configs[model_type].display_name
+            for model_type in self.models.keys()
         }
+
+    async def cleanup(self) -> None:
+        """Clean up resources."""
+        # Shutdown worker
+        if self._worker_task:
+            await self._processing_queue.put(None)
+            await self._worker_task
+
+        # Clean up models
+        for model in self.models.values():
+            model.cleanup()
+
+        # Clean up temp directory
+        if self._temp_dir.exists():
+            shutil.rmtree(self._temp_dir)
+
+        logger.info("Model manager cleanup completed")
+
+
+# =============================================================================
+# Results Processing and Export
+# =============================================================================
+
+class ResultsProcessor:
+    """Handle result processing, sorting, and export functionality."""
+
+    @staticmethod
+    def sort_results(results: List[EvaluationResult], sort_by: str, reverse: bool = True) -> List[EvaluationResult]:
+        """Sort results by specified criteria."""
+        sort_key_map = {
+            "Final Score": lambda r: r.final_score if r.final_score is not None else -float('inf'),
+            "File Name": lambda r: r.file_name.lower(),
+            **{f"model_{model_type.value}": lambda r, mt=model_type.value: r.model_scores.get(mt) or -float('inf')
+               for model_type in ModelType}
+        }
+
+        sort_key = sort_key_map.get(sort_by, sort_key_map["Final Score"])
+        return sorted(results, key=sort_key, reverse=reverse and sort_by != "File Name")
+
+    @staticmethod
+    def generate_html_table(results: List[EvaluationResult], selected_models: List[ModelType]) -> str:
+        """Generate HTML table for results display."""
+        if not results:
+            return "<p>No results to display</p>"
+
+        # CSS styles
+        styles = """
         <style>
+        .results-table {
+            width: 100%; border-collapse: collapse; margin: 20px 0;
+            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+        }
+        .results-table th, .results-table td {
+            border: 1px solid #ddd; padding: 12px; text-align: center;
+        }
+        .results-table th {
+            background-color: #f8f9fa; font-weight: 600; color: #495057;
+        }
+        .results-table tr:nth-child(even) { background-color: #f8f9fa; }
+        .results-table tr:hover { background-color: #e9ecef; }
+        .image-preview {
+            max-width: 120px; max-height: 120px; border-radius: 8px;
+            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+        }
+        .score-excellent { color: #28a745; font-weight: bold; }
+        .score-good { color: #ffc107; font-weight: bold; }
+        .score-poor { color: #dc3545; font-weight: bold; }
+        .score-na { color: #6c757d; font-style: italic; }
         </style>
         """
+
+        # Table header
+        html = styles + '<table class="results-table"><thead><tr>'
+        html += '<th>Image</th><th>File Name</th>'
+
+        for model_type in selected_models:
+            model_name = ModelType(model_type).name.replace('_', ' ').title()
+            html += f'<th>{model_name}</th>'
+
+        html += '<th>Final Score</th></tr></thead><tbody>'
+
+        # Table rows
         for result in results:
+            html += '<tr>'
+            html += f'<td><img src="data:image/jpeg;base64,{result.thumbnail_b64}" class="image-preview" alt="{result.file_name}"></td>'
+            html += f'<td>{result.file_name}</td>'
+
+            # Model scores
+            for model_type in selected_models:
+                score = result.model_scores.get(model_type.value)
+                html += ResultsProcessor._format_score_cell(score)
+
+            # Final score
+            html += ResultsProcessor._format_score_cell(result.final_score)
+            html += '</tr>'
+
+        html += '</tbody></table>'
+        return html
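
sort_results treats a missing final score as -inf, so unscored entries sink to the bottom of a descending sort. A small check with hand-built results (thumbnails omitted for brevity):

    demo = [
        EvaluationResult(file_name="a.jpg", file_path="a.jpg", thumbnail_b64="", final_score=7.2),
        EvaluationResult(file_name="b.jpg", file_path="b.jpg", thumbnail_b64="", final_score=None),
        EvaluationResult(file_name="c.jpg", file_path="c.jpg", thumbnail_b64="", final_score=9.1),
    ]
    ranked = ResultsProcessor.sort_results(demo, sort_by="Final Score")
    print([r.file_name for r in ranked])   # ['c.jpg', 'a.jpg', 'b.jpg']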