import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import AutoModel, AutoProcessor
import logging

logger = logging.getLogger(__name__)
class AestheticsEvaluator:
    """Image aesthetics assessment using multiple SOTA models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False  # Set by use_fallback_implementation()
        self.load_models()
    def load_models(self):
        """Load aesthetics assessment models."""
        try:
            # Load UNIAA model (primary)
            logger.info("Loading UNIAA model...")
            self.load_uniaa()

            # Load MUSIQ model (secondary)
            logger.info("Loading MUSIQ model...")
            self.load_musiq()

            # Load anime-specific aesthetic model
            logger.info("Loading anime aesthetic model...")
            self.load_anime_aesthetic_model()
        except Exception as e:
            logger.error(f"Error loading aesthetic models: {str(e)}")
            self.use_fallback_implementation()
    def _default_processor(self):
        """Standard ImageNet preprocessing shared by the placeholder models."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def load_uniaa(self):
        """Load UNIAA model (placeholder implementation)."""
        try:
            self.models['uniaa'] = self.create_mock_aesthetic_model()
            self.processors['uniaa'] = self._default_processor()
        except Exception as e:
            logger.warning(f"Could not load UNIAA: {str(e)}")

    def load_musiq(self):
        """Load MUSIQ model (placeholder implementation)."""
        try:
            self.models['musiq'] = self.create_mock_aesthetic_model()
            self.processors['musiq'] = self._default_processor()
        except Exception as e:
            logger.warning(f"Could not load MUSIQ: {str(e)}")

    def load_anime_aesthetic_model(self):
        """Load anime-specific aesthetic model (placeholder implementation)."""
        try:
            self.models['anime_aesthetic'] = self.create_mock_aesthetic_model()
            self.processors['anime_aesthetic'] = self._default_processor()
        except Exception as e:
            logger.warning(f"Could not load anime aesthetic model: {str(e)}")
    def create_mock_aesthetic_model(self):
        """Create a mock aesthetic model for demonstration."""
        class MockAestheticModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                    nn.Sigmoid()
                )

            def forward(self, x):
                return self.backbone(x) * 10  # Scale sigmoid output to 0-10

        model = MockAestheticModel().to(self.device)
        model.eval()
        return model
    def use_fallback_implementation(self):
        """Use simple fallback aesthetic assessment."""
        logger.info("Using fallback aesthetic assessment implementation")
        self.fallback_mode = True
    def _score_with_model(self, name: str, image: Image.Image) -> float:
        """Run one loaded model on an image and return its clamped 0-10 score."""
        if name not in self.models:
            return self.fallback_aesthetic_score(image)

        # Convert to RGB first: Normalize expects exactly three channels
        tensor = self.processors[name](image.convert('RGB')).unsqueeze(0).to(self.device)
        with torch.no_grad():
            score = self.models[name](tensor).item()
        return max(0.0, min(10.0, score))

    def evaluate_with_uniaa(self, image: Image.Image) -> float:
        """Evaluate aesthetics using UNIAA."""
        try:
            return self._score_with_model('uniaa', image)
        except Exception as e:
            logger.error(f"Error in UNIAA evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)

    def evaluate_with_musiq(self, image: Image.Image) -> float:
        """Evaluate aesthetics using MUSIQ."""
        try:
            return self._score_with_model('musiq', image)
        except Exception as e:
            logger.error(f"Error in MUSIQ evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)

    def evaluate_with_anime_model(self, image: Image.Image) -> float:
        """Evaluate aesthetics using the anime-specific model."""
        try:
            return self._score_with_model('anime_aesthetic', image)
        except Exception as e:
            logger.error(f"Error in anime aesthetic evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)
    def evaluate_composition_rules(self, image: Image.Image) -> float:
        """Evaluate based on composition rules (rule of thirds, color harmony)."""
        try:
            img_array = np.array(image)
            height, width = img_array.shape[:2]

            # Convert to grayscale for analysis (ITU-R BT.601 luma weights)
            if img_array.ndim == 3:
                gray = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
            else:
                gray = img_array

            # Rule of thirds: check for interesting content (high local
            # variance) around the four intersection points
            third_h, third_w = height // 3, width // 3
            intersections = [
                (third_h, third_w), (third_h, 2 * third_w),
                (2 * third_h, third_w), (2 * third_h, 2 * third_w)
            ]
            composition_score = 0.0
            for y, x in intersections:
                region = gray[max(0, y - 10):min(height, y + 10),
                              max(0, x - 10):min(width, x + 10)]
                if region.size > 0:
                    composition_score += region.var()
            composition_score = min(10.0, composition_score / 1000.0)

            # Color harmony: spread of the color distribution (drop any
            # alpha channel before reshaping to N x 3)
            if img_array.ndim == 3:
                colors = img_array[..., :3].reshape(-1, 3)
                color_std = np.std(colors, axis=0).mean()
                color_harmony_score = min(10.0, color_std / 25.0)
            else:
                color_harmony_score = 5.0

            final_score = composition_score * 0.6 + color_harmony_score * 0.4
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in composition analysis: {str(e)}")
            return 5.0
    def fallback_aesthetic_score(self, image: Image.Image) -> float:
        """Simple fallback aesthetic assessment based on basic image properties."""
        try:
            width, height = image.size

            # Aspect ratio score (prefer aesthetically pleasing ratios)
            aspect_ratio = width / height
            golden_ratio = 1.618
            if abs(aspect_ratio - golden_ratio) < 0.1 or abs(aspect_ratio - 1 / golden_ratio) < 0.1:
                aspect_score = 9.0
            elif 0.7 <= aspect_ratio <= 1.4:  # Square-ish
                aspect_score = 7.0
            elif 1.4 < aspect_ratio <= 2.0:  # Landscape
                aspect_score = 8.0
            else:
                aspect_score = 5.0

            # Resolution score: reaches the 10-point cap at 2 megapixels
            total_pixels = width * height
            resolution_score = min(10.0, total_pixels / 200000.0)

            # Color analysis (drop any alpha channel before reshaping)
            img_array = np.array(image)
            if img_array.ndim == 3:
                rgb = img_array[..., :3]

                # Color variety score
                unique_colors = len(np.unique(rgb.reshape(-1, 3), axis=0))
                color_variety_score = min(10.0, unique_colors / 1000.0)

                # Brightness distribution: best when the mean is near mid-gray
                brightness = np.mean(rgb, axis=2)
                brightness_score = 10.0 - abs(brightness.mean() - 127.5) / 12.75
            else:
                color_variety_score = 5.0
                brightness_score = 5.0

            # Combine scores
            aesthetic_score = (aspect_score * 0.3 +
                               resolution_score * 0.2 +
                               color_variety_score * 0.3 +
                               brightness_score * 0.2)
            return max(0.0, min(10.0, aesthetic_score))
        except Exception:
            return 5.0  # Default neutral score
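    # Worked example of the fallback arithmetic for a flat mid-gray 1920x1080
    # image (illustrative numbers):
    #   aspect_ratio = 1920/1080 ~= 1.78 -> landscape band -> aspect_score = 8.0
    #   total_pixels = 2,073,600         -> resolution_score = min(10, 10.4) = 10.0
    #   one unique color                 -> color_variety_score ~= 0.0
    #   mean brightness ~= 128           -> brightness_score ~= 10.0
    #   final ~= 8.0*0.3 + 10.0*0.2 + 0.0*0.3 + 10.0*0.2 = 6.4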
    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """
        Evaluate image aesthetics using an ensemble of models.

        Args:
            image: PIL Image to evaluate
            anime_mode: Whether to use anime-specific evaluation

        Returns:
            Aesthetic score from 0-10
        """
        try:
            scores = []
            if anime_mode:
                # For anime images, prioritize the anime-specific model,
                # backed by a general model at lower weight plus composition
                scores.append(self.evaluate_with_anime_model(image))
                scores.append(self.evaluate_with_uniaa(image))
                scores.append(self.evaluate_composition_rules(image))
                weights = [0.5, 0.3, 0.2]
            else:
                # For realistic images, use the general aesthetic models
                # plus composition rules
                scores.append(self.evaluate_with_uniaa(image))
                scores.append(self.evaluate_with_musiq(image))
                scores.append(self.evaluate_composition_rules(image))
                weights = [0.4, 0.4, 0.2]

            # Weighted ensemble of the individual scores
            final_score = sum(score * weight for score, weight in zip(scores, weights))
            logger.info(f"Aesthetic scores: {scores}, final: {final_score:.2f}")
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in aesthetic evaluation: {str(e)}")
            return self.fallback_aesthetic_score(image)
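
# Minimal usage sketch. With only the mock models loaded, the printed numbers
# are illustrative rather than meaningful aesthetic ratings.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    evaluator = AestheticsEvaluator()

    # Flat mid-gray test image; replace with Image.open(path) in practice
    test_image = Image.new("RGB", (640, 480), color=(128, 128, 128))
    print(f"Realistic mode score: {evaluator.evaluate(test_image):.2f}")
    print(f"Anime mode score:     {evaluator.evaluate(test_image, anime_mode=True):.2f}")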