"""Advanced image evaluation tool: scores images with multiple aesthetic models behind a Gradio UI."""

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import onnxruntime as rt
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import pipeline

# convert_v2_5_from_siglip ships with the aesthetic-predictor-v2-5 package.
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
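
# Assumed third-party dependencies (not pinned here; verify exact package
# names and versions against your environment): torch, transformers,
# huggingface_hub, onnxruntime, opencv-python, pillow, pandas, gradio, plus
# OpenAI CLIP (github.com/openai/CLIP) for WaifuScorerModel and the
# aesthetic-predictor-v2-5 package for AestheticPredictorV25Model.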


@dataclass
class EvaluationResult:
    """Stores per-image evaluation results."""
    file_name: str
    image_path: str
    scores: Dict[str, Optional[float]] = field(default_factory=dict)
    final_score: Optional[float] = None

    def calculate_final_score(self, selected_models: List[str]) -> None:
        """Average the scores of the selected models, ignoring missing (None) scores."""
        valid_scores = [
            score for model, score in self.scores.items()
            if model in selected_models and score is not None
        ]
        self.final_score = float(np.mean(valid_scores)) if valid_scores else None
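
# Illustrative example (hypothetical values): only selected, non-None scores
# are averaged.
#   r = EvaluationResult("a.png", "/tmp/a.png", scores={"m1": 8.0, "m2": None})
#   r.calculate_final_score(["m1", "m2"])  # r.final_score == 8.0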


class BaseModel:
    """Base class for all evaluation models."""
    def __init__(self, name: str):
        self.name = name
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        """Evaluate a batch of images and return one score (or None) per image."""
        raise NotImplementedError


class AestheticShadowModel(BaseModel):
    """Aesthetic Shadow V2 model implementation."""
    def __init__(self):
        super().__init__("Aesthetic Shadow")
        logger.info(f"Loading {self.name} model...")
        self.model = pipeline(
            "image-classification",
            model="NeoChen1024/aesthetic-shadow-v2-backup",
            device=0 if self.device == 'cuda' else -1
        )

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        try:
            results = self.model(images)
            scores = []
            for result in results:
                # Scale the 'hq' class probability to a 0-10 score.
                hq_score = next((p['score'] for p in result if p['label'] == 'hq'), 0)
                scores.append(float(np.clip(hq_score * 10.0, 0.0, 10.0)))
            return scores
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)


class WaifuScorerModel(BaseModel):
    """Waifu Scorer V3 model implementation."""
    def __init__(self):
        super().__init__("Waifu Scorer")
        logger.info(f"Loading {self.name} model...")
        self._load_model()

    def _load_model(self):
        try:
            import clip  # OpenAI CLIP; optional dependency

            self.mlp = self._create_mlp()
            model_path = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth")
            state_dict = torch.load(model_path, map_location=self.device)
            self.mlp.load_state_dict(state_dict)
            self.mlp.to(self.device).eval()

            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
            self.available = True
        except Exception as e:
            logger.error(f"Failed to load {self.name}: {e}")
            self.available = False

    def _create_mlp(self) -> torch.nn.Module:
        """Create the MLP head; the 768-dim input matches CLIP ViT-L/14 image embeddings."""
        return torch.nn.Sequential(
            torch.nn.Linear(768, 2048),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

    @torch.no_grad()
    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        if not self.available:
            return [None] * len(images)

        try:
            image_tensors = torch.cat([self.preprocess(img).unsqueeze(0) for img in images])
            image_tensors = image_tensors.to(self.device)

            features = self.clip_model.encode_image(image_tensors)
            # Normalize, then cast to float32: CLIP returns float16 features
            # on CUDA, while the MLP weights are float32.
            features = (features / features.norm(dim=-1, keepdim=True)).float()
            predictions = self.mlp(features)

            scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
            return scores
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)


class AestheticPredictorV25Model(BaseModel):
    """Aesthetic Predictor V2.5 model implementation."""
    def __init__(self):
        super().__init__("Aesthetic V2.5")
        logger.info(f"Loading {self.name} model...")
        self.model, self.preprocessor = convert_v2_5_from_siglip(
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        if self.device == 'cuda':
            self.model = self.model.to(torch.bfloat16).cuda()

    @torch.no_grad()
    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        try:
            images_rgb = [img.convert("RGB") for img in images]
            pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values

            if self.device == 'cuda':
                pixel_values = pixel_values.to(torch.bfloat16).cuda()

            scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
            # squeeze() collapses a single-image batch to a 0-d array; restore 1-d.
            if scores.ndim == 0:
                scores = np.array([scores])

            return [float(np.clip(s, 0.0, 10.0)) for s in scores]
        except Exception as e:
            logger.error(f"Error in {self.name}: {e}")
            return [None] * len(images)


class AnimeAestheticModel(BaseModel):
    """Anime Aesthetic model implementation."""
    def __init__(self):
        super().__init__("Anime Score")
        logger.info(f"Loading {self.name} model...")
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
        self.session = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
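        # NOTE: CPUExecutionProvider is chosen for portability. With a GPU
        # build of onnxruntime installed (assumption), passing
        # ['CUDAExecutionProvider', 'CPUExecutionProvider'] would prefer the GPU.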

    async def evaluate_batch(self, images: List[Image.Image]) -> List[Optional[float]]:
        scores = []
        for img in images:
            try:
                score = self._process_single_image(img)
                scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
            except Exception as e:
                logger.error(f"Error in {self.name} for single image: {e}")
                scores.append(None)
        return scores

    def _process_single_image(self, img: Image.Image) -> float:
        """Letterbox the image to 768x768 and run it through the ONNX model."""
        img_np = np.array(img).astype(np.float32) / 255.0
        size = 768
        h, w = img_np.shape[:2]

        # Resize so the longer side becomes `size`, preserving aspect ratio.
        if h > w:
            new_h, new_w = size, int(size * w / h)
        else:
            new_h, new_w = int(size * h / w), size

        # Center the resized image on a black square canvas.
        resized = cv2.resize(img_np, (new_w, new_h))
        canvas = np.zeros((size, size, 3), dtype=np.float32)
        pad_h = (size - new_h) // 2
        pad_w = (size - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

        # HWC -> NCHW for the ONNX session.
        input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
        return self.session.run(None, {"img": input_tensor})[0].item()


class ImageEvaluator:
    """Main class for managing image evaluation."""
    def __init__(self):
        self.models: Dict[str, BaseModel] = {}
        self._initialize_models()
        self.results: List[EvaluationResult] = []

    def _initialize_models(self):
        """Initialize all evaluation models."""
        model_classes = [
            ("aesthetic_shadow", AestheticShadowModel),
            ("waifu_scorer", WaifuScorerModel),
            ("aesthetic_predictor_v2_5", AestheticPredictorV25Model),
            ("anime_aesthetic", AnimeAestheticModel),
        ]

        for key, model_class in model_classes:
            try:
                self.models[key] = model_class()
                logger.info(f"Successfully loaded {key}")
            except Exception as e:
                logger.error(f"Failed to load {key}: {e}")

    async def evaluate_images(
        self,
        file_paths: List[str],
        selected_models: List[str],
        batch_size: int = 8,
        progress_callback=None
    ) -> Tuple[List[EvaluationResult], List[str]]:
        """Evaluate images with the selected models."""
        logs = []
        results = []

        # Load images, skipping any that fail to open.
        images = []
        valid_paths = []
        for path in file_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(path)
            except Exception as e:
                logs.append(f"Failed to load {Path(path).name}: {e}")

        if not images:
            logs.append("No valid images to process")
            return results, logs

        logs.append(f"Loaded {len(images)} images")

        total_batches = (len(images) + batch_size - 1) // batch_size

        for batch_idx in range(0, len(images), batch_size):
            batch_images = images[batch_idx:batch_idx + batch_size]
            batch_paths = valid_paths[batch_idx:batch_idx + batch_size]

            # Score the batch with each selected model.
            batch_results = {}
            for model_key in selected_models:
                if model_key in self.models:
                    scores = await self.models[model_key].evaluate_batch(batch_images)
                    batch_results[model_key] = scores
                    logs.append(f"Processed batch {batch_idx // batch_size + 1}/{total_batches} with {self.models[model_key].name}")

            # Collect per-image results for this batch.
            for i, path in enumerate(batch_paths):
                result = EvaluationResult(
                    file_name=Path(path).name,
                    image_path=path
                )

                for model_key in selected_models:
                    if model_key in batch_results:
                        result.scores[model_key] = batch_results[model_key][i]

                result.calculate_final_score(selected_models)
                results.append(result)

            if progress_callback:
                progress = (batch_idx + batch_size) / len(images) * 100
                progress_callback(min(progress, 100))

        self.results = results
        return results, logs
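
    # Illustrative direct use outside the UI (hypothetical paths; requires
    # `import asyncio`):
    #   evaluator = ImageEvaluator()
    #   results, logs = asyncio.run(evaluator.evaluate_images(
    #       ["/path/a.png"], ["anime_aesthetic"], batch_size=4))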

    def get_results_dataframe(self, selected_models: List[str]) -> pd.DataFrame:
        """Convert results to a pandas DataFrame."""
        if not self.results:
            return pd.DataFrame()

        data = []
        for result in self.results:
            row = {
                'File Name': result.file_name,
                'Image': result.image_path,
            }

            for model_key in selected_models:
                if model_key in self.models:
                    score = result.scores.get(model_key)
                    row[self.models[model_key].name] = f"{score:.4f}" if score is not None else "N/A"

            row['Final Score'] = f"{result.final_score:.4f}" if result.final_score is not None else "N/A"
            data.append(row)

        return pd.DataFrame(data)
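
    # With all four models selected, the resulting columns are: File Name,
    # Image, Aesthetic Shadow, Waifu Scorer, Aesthetic V2.5, Anime Score,
    # Final Score.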


def create_interface():
    """Create the Gradio interface."""
    evaluator = ImageEvaluator()

    # (display label, internal model key) pairs
    model_options = [
        ("Aesthetic Shadow", "aesthetic_shadow"),
        ("Waifu Scorer", "waifu_scorer"),
        ("Aesthetic V2.5", "aesthetic_predictor_v2_5"),
        ("Anime Score", "anime_aesthetic")
    ]

    with gr.Blocks(theme=gr.themes.Soft(), title="Image Evaluation Tool") as demo:
        gr.Markdown("""
        # 🎨 Advanced Image Evaluation Tool

        Evaluate images using state-of-the-art aesthetic and quality prediction models.
        Upload your images and select the models you want to use for evaluation.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_files = gr.File(
                    label="Upload Images",
                    file_count="multiple",
                    file_types=["image"]
                )

                model_checkboxes = gr.CheckboxGroup(
                    choices=[label for label, _ in model_options],
                    value=[label for label, _ in model_options],
                    label="Select Models",
                    info="Choose which models to use for evaluation"
                )

                with gr.Row():
                    batch_size = gr.Slider(
                        minimum=1,
                        maximum=64,
                        value=8,
                        step=1,
                        label="Batch Size",
                        info="Number of images to process at once"
                    )

                with gr.Row():
                    evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", scale=2)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)

            with gr.Column(scale=2):
                # Progress is reported via the gr.Progress() default argument
                # of process_images below; no dedicated component is needed.
                logs = gr.Textbox(
                    label="Processing Logs",
                    lines=10,
                    max_lines=10,
                    autoscroll=True
                )

                results_df = gr.Dataframe(
                    label="Evaluation Results",
                    interactive=False,
                    wrap=True
                )

                download_btn = gr.Button("📥 Download Results (CSV)", variant="secondary")
                download_file = gr.File(label="Results CSV", visible=False)

        results_state = gr.State([])

        async def process_images(files, selected_model_labels, batch_size, progress=gr.Progress()):
            """Process uploaded images."""
            if not files:
                return "Please upload images first", pd.DataFrame(), []

            # Map display labels back to internal model keys.
            selected_models = [key for label, key in model_options if label in selected_model_labels]

            file_paths = [f.name for f in files]

            def update_progress(value):
                progress(value / 100, desc=f"Processing images... {value:.0f}%")

            results, logs = await evaluator.evaluate_images(
                file_paths,
                selected_models,
                batch_size,
                update_progress
            )

            df = evaluator.get_results_dataframe(selected_models)

            # Show only the most recent log lines.
            log_text = "\n".join(logs[-10:])

            return log_text, df, results

        def update_results_on_model_change(selected_model_labels, results):
            """Update results when the model selection changes."""
            if not results:
                return pd.DataFrame()

            selected_models = [key for label, key in model_options if label in selected_model_labels]

            # Recompute final scores against the new selection.
            for result in results:
                result.calculate_final_score(selected_models)

            evaluator.results = results

            return evaluator.get_results_dataframe(selected_models)

        def clear_interface():
            """Clear all results."""
            return "", pd.DataFrame(), [], gr.update(value=None, visible=False)

        def prepare_download(selected_model_labels, results):
            """Prepare a CSV file for download."""
            if not results:
                return None

            selected_models = [key for label, key in model_options if label in selected_model_labels]

            df = evaluator.get_results_dataframe(selected_models)

            import tempfile
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
                df.to_csv(f, index=False)
            # Reveal the hidden File component so the CSV can actually be
            # downloaded.
            return gr.update(value=f.name, visible=True)

        evaluate_btn.click(
            fn=process_images,
            inputs=[input_files, model_checkboxes, batch_size],
            outputs=[logs, results_df, results_state]
        )

        model_checkboxes.change(
            fn=update_results_on_model_change,
            inputs=[model_checkboxes, results_state],
            outputs=[results_df]
        )

        clear_btn.click(
            fn=clear_interface,
            outputs=[logs, results_df, results_state, download_file]
        )

        download_btn.click(
            fn=prepare_download,
            inputs=[model_checkboxes, results_state],
            outputs=[download_file]
        )

        gr.Markdown("""
        ### 📝 Notes
        - **Model Selection**: Choose which models to use for evaluation. The final score is the average of the selected models.
        - **Batch Size**: Adjust based on your GPU memory. Larger batches process faster.
        - **Results Table**: Click column headers to sort. The table updates automatically when models are selected or deselected.
        - **Download**: Export results as CSV for further analysis.

        ### 🎯 Score Interpretation
        - **7-10**: High quality/aesthetic appeal
        - **5-7**: Medium quality
        - **Below 5**: Lower quality
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch()
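    # For a shareable public link (assumes outbound network access), one
    # could use demo.queue().launch(share=True) instead.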