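# Spaces: Sleeping
# NOTE(review): the line above appears to be page-header text captured together
# with this file; it is not program code — confirm and drop if so.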
import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Module-level logger used by every method of QualityEvaluator below.
logger = logging.getLogger(__name__)
class QualityEvaluator:
    """Score image quality on a 0-10 scale with an ensemble of estimators.

    ``evaluate`` combines three scores:

    * two neural scorers registered under the keys ``'lar_iqa'`` and
      ``'dgiqa'``.  Both are currently backed by the untrained placeholder
      network built by ``create_mock_model``, so their output is not a real
      quality estimate yet;
    * a statistics-based scorer (sharpness, contrast, brightness).

    Whenever a neural scorer is missing or fails, the metadata-only heuristic
    ``fallback_quality_score`` is used in its place.

    Attributes:
        device: torch device the models are moved to (CUDA when available).
        models: mapping of scorer key -> loaded model.
        processors: mapping of scorer key -> preprocessing pipeline.
        fallback_mode: True once ``use_fallback_implementation`` has run.
        traditional_metrics_available: True once
            ``load_traditional_metrics`` has run.
    """

    # Preprocessing parameters shared by both placeholder pipelines.
    _INPUT_SIZE = (224, 224)
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    # RGB -> luminance weights used by the statistics-based scorer.
    _LUMA_WEIGHTS = (0.2989, 0.5870, 0.1140)

    # Ensemble weights in the order (LAR-IQA, DGIQA, traditional).
    _REALISTIC_WEIGHTS = (0.4, 0.4, 0.2)
    _ANIME_WEIGHTS = (0.3, 0.3, 0.4)

    # Neutral score returned when nothing better can be computed.
    _DEFAULT_SCORE = 5.0

    def __init__(self):
        """Pick a device and load every available scorer."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        # Defined up front so both flags exist whichever loading path runs.
        self.fallback_mode = False
        self.traditional_metrics_available = False
        self.load_models()

    def load_models(self):
        """Load quality assessment models.

        Each loader handles its own errors; anything that still escapes
        switches the evaluator to fallback mode instead of raising.
        """
        try:
            # Load LAR-IQA model (primary)
            logger.info("Loading LAR-IQA model...")
            self.load_lar_iqa()
            # Load DGIQA model (secondary)
            logger.info("Loading DGIQA model...")
            self.load_dgiqa()
            # Load traditional metrics as fallback
            logger.info("Loading traditional quality metrics...")
            self.load_traditional_metrics()
        except Exception as e:
            logger.error("Error loading quality models: %s", e)
            # Use fallback implementation
            self.use_fallback_implementation()

    def _register_placeholder(self, key):
        """Build the placeholder model and its preprocessing for *key*."""
        model = self.create_mock_model()
        preprocess = transforms.Compose([
            transforms.Resize(self._INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])
        # Register both only after both were built, so a model is never left
        # behind without its matching processor.
        self.models[key] = model
        self.processors[key] = preprocess

    def load_lar_iqa(self):
        """Load the LAR-IQA scorer (placeholder model for now)."""
        try:
            # For now, use a placeholder implementation.
            # In production, this would load the actual LAR-IQA model.
            self._register_placeholder('lar_iqa')
        except Exception as e:
            logger.warning("Could not load LAR-IQA: %s", e)

    def load_dgiqa(self):
        """Load the DGIQA scorer (placeholder model for now)."""
        try:
            # Placeholder implementation
            self._register_placeholder('dgiqa')
        except Exception as e:
            logger.warning("Could not load DGIQA: %s", e)

    def load_traditional_metrics(self):
        """Mark the statistics-based scorer as available.

        Nothing is loaded here: ``evaluate_traditional_metrics`` imports what
        it needs at call time.
        """
        self.traditional_metrics_available = True

    def create_mock_model(self):
        """Create an untrained stand-in model for demonstration purposes.

        Returns:
            An ``nn.Module`` in eval mode on ``self.device`` that maps a
            ``(N, 3, H, W)`` batch to ``(N, 1)`` scores in the range 0-10.
        """
        class MockQualityModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(64, 1),
                    nn.Sigmoid(),
                )

            def forward(self, x):
                return self.backbone(x) * 10  # Scale sigmoid output to 0-10

        model = MockQualityModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Record that model loading failed and the fallback is in use."""
        logger.info("Using fallback quality assessment implementation")
        self.fallback_mode = True

    def _score_with_model(self, key, label, image):
        """Run the scorer registered under *key*; fall back on any problem.

        Args:
            key: entry name in ``self.models`` / ``self.processors``.
            label: human-readable scorer name used in the error log message.
            image: PIL image to score.

        Returns:
            The model score clamped to 0-10, or ``fallback_quality_score``
            when the scorer is not loaded or raises.
        """
        model = self.models.get(key)
        preprocess = self.processors.get(key)
        if model is None or preprocess is None:
            return self.fallback_quality_score(image)
        try:
            # The pipeline normalises exactly three channels, so grayscale,
            # palette or RGBA input has to be converted first.
            rgb = image if image.mode == 'RGB' else image.convert('RGB')
            batch = preprocess(rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                score = model(batch).item()
            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            # Pass the original image: convert() does not keep ``format``.
            return self.fallback_quality_score(image)

    def evaluate_with_lar_iqa(self, image: "Image.Image") -> float:
        """Evaluate image quality using LAR-IQA (0-10)."""
        return self._score_with_model('lar_iqa', 'LAR-IQA', image)

    def evaluate_with_dgiqa(self, image: "Image.Image") -> float:
        """Evaluate image quality using DGIQA (0-10)."""
        return self._score_with_model('dgiqa', 'DGIQA', image)

    @classmethod
    def _to_grayscale(cls, image):
        """Return *image* as a 2-D float64 luminance array.

        PIL images are converted to RGB first so every mode is handled the
        same way.  Array-like input is accepted as well: 2-D data is taken as
        luminance already, 3-D data uses its first three channels.

        Raises:
            ValueError: if the data is empty or has an unsupported shape.
        """
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        # float64 up front: integer pixel types would wrap in the Laplacian.
        pixels = np.asarray(image, dtype=np.float64)
        if pixels.size == 0:
            raise ValueError("image has no pixels")
        if pixels.ndim == 2:
            return pixels
        if pixels.ndim == 3 and pixels.shape[-1] >= 3:
            return pixels[..., :3] @ np.asarray(cls._LUMA_WEIGHTS)
        raise ValueError(f"unsupported image array shape: {pixels.shape}")

    def evaluate_traditional_metrics(self, image: "Image.Image") -> float:
        """Evaluate using simple image statistics.

        Combines sharpness (variance of the Laplacian), contrast (standard
        deviation of luminance) and brightness (distance of the mean from
        mid-gray) with weights 0.4 / 0.3 / 0.3.
        In production, this would use BRISQUE, NIQE, etc.

        Returns:
            Score clamped to 0-10, or 5.0 if anything goes wrong (including
            SciPy being unavailable).
        """
        try:
            # Imported here so a missing SciPy only disables this scorer.
            from scipy import ndimage

            gray = self._to_grayscale(image)
            # Calculate sharpness (Laplacian variance)
            sharpness_score = min(10.0, ndimage.laplace(gray).var() / 100.0)
            # Calculate contrast
            contrast_score = min(10.0, gray.std() / 25.0)
            # Calculate brightness distribution (10 at mid-gray, 0 at extremes)
            brightness_score = 10.0 - abs(gray.mean() - 127.5) / 12.75
            # Combine scores
            quality_score = (sharpness_score * 0.4 +
                             contrast_score * 0.3 +
                             brightness_score * 0.3)
            return float(max(0.0, min(10.0, quality_score)))
        except Exception as e:
            logger.error("Error in traditional metrics: %s", e)
            return self._DEFAULT_SCORE  # Default score

    def fallback_quality_score(self, image: "Image.Image") -> float:
        """Simple fallback assessment based on image metadata only.

        Uses ``image.size`` and ``image.format``; pixel data is not read.

        Returns:
            Score clamped to 0-10, or 5.0 if the metadata cannot be used
            (for example a zero height).
        """
        try:
            width, height = image.size
            # Resolution score: grows linearly and saturates at 10 for 1 MP.
            total_pixels = width * height
            resolution_score = min(10.0, total_pixels / 100000.0)
            # Aspect ratio score (prefer standard ratios)
            aspect_ratio = width / height
            aspect_score = 8.0 if 0.5 <= aspect_ratio <= 2.0 else 5.0
            # File format score (prefer lossless)
            format_score = 8.0 if image.format == 'PNG' else 6.0
            # Combine scores
            quality_score = (resolution_score * 0.5 +
                             aspect_score * 0.3 +
                             format_score * 0.2)
            return max(0.0, min(10.0, quality_score))
        except Exception:
            return self._DEFAULT_SCORE  # Default neutral score

    def evaluate(self, image: "Image.Image", anime_mode: bool = False) -> float:
        """
        Evaluate image quality using ensemble of models

        Args:
            image: PIL Image to evaluate
            anime_mode: Whether to use anime-specific evaluation

        Returns:
            Quality score from 0-10
        """
        try:
            lar_score = self.evaluate_with_lar_iqa(image)
            dgiqa_score = self.evaluate_with_dgiqa(image)
            traditional_score = self.evaluate_traditional_metrics(image)
            scores = (lar_score, dgiqa_score, traditional_score)
            # Ensemble scoring: anime images weight traditional metrics higher
            # as they may be more reliable for stylized content; realistic
            # images weight the model-based scorers higher.
            weights = self._ANIME_WEIGHTS if anime_mode else self._REALISTIC_WEIGHTS
            final_score = sum(score * weight for score, weight in zip(scores, weights))
            logger.info(
                "Quality scores - LAR: %.2f, DGIQA: %.2f, "
                "Traditional: %.2f, Final: %.2f",
                lar_score, dgiqa_score, traditional_score, final_score,
            )
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error("Error in quality evaluation: %s", e)
            return self.fallback_quality_score(image)