# image-evaluation-tool/models/quality_evaluator.py
# Uploaded by VOIDER ("Upload 14 files"), commit 83b7522 (verified), 9.46 kB.
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import AutoModel, AutoProcessor
import logging
# Module-level logger named after this module (``__name__``), so its output
# can be filtered and configured independently of other modules.
logger = logging.getLogger(__name__)
class QualityEvaluator:
    """Image quality assessment using an ensemble of scorers.

    ``evaluate`` combines three scores on a 0-10 scale:

    * ``'lar_iqa'`` and ``'dgiqa'``: learned scorers. NOTE: both are
      currently placeholders built by ``create_mock_model``. They have randomly
      initialised weights and no checkpoint is loaded, so their outputs are not
      real quality predictions and differ between evaluator instances.
    * Traditional luminance statistics (sharpness, contrast, brightness).

    Any learned scorer that is missing or raises falls back to
    ``fallback_quality_score``, which only looks at image size, aspect ratio
    and file format.
    """

    # Every score returned by this class is clamped to this range.
    MIN_SCORE = 0.0
    MAX_SCORE = 10.0
    # Score used when nothing meaningful can be computed.
    NEUTRAL_SCORE = 5.0

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}      # model key -> scorer module
        self.processors = {}  # model key -> preprocessing transform
        # Define the status flags up front so they always exist. Previously
        # ``fallback_mode`` was only set on the error path.
        self.fallback_mode = False
        self.traditional_metrics_available = False
        self.load_models()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_models(self):
        """Load every scorer; switch to fallback mode on an unexpected error.

        The individual loaders catch and log their own exceptions, so this
        handler is only a last line of defence.
        """
        try:
            # LAR-IQA model (primary)
            logger.info("Loading LAR-IQA model...")
            self.load_lar_iqa()
            # DGIQA model (secondary)
            logger.info("Loading DGIQA model...")
            self.load_dgiqa()
            # Traditional metrics
            logger.info("Loading traditional quality metrics...")
            self.load_traditional_metrics()
        except Exception as e:
            logger.error("Error loading quality models: %s", e)
            self.use_fallback_implementation()

    @staticmethod
    def _build_transform():
        """Return the preprocessing pipeline shared by the learned scorers.

        The pipeline resizes to 224x224, converts to a tensor and normalises
        with the ImageNet per-channel mean/std. It expects an RGB image.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _register_model(self, key, label):
        """Register a scorer and its transform under *key*; log on failure.

        *label* is the human-readable name used in the warning message.
        """
        try:
            # Placeholder: in production this would load the actual model.
            self.models[key] = self.create_mock_model()
            self.processors[key] = self._build_transform()
        except Exception as e:
            logger.warning("Could not load %s: %s", label, e)

    def load_lar_iqa(self):
        """Load the LAR-IQA scorer (currently a placeholder model)."""
        self._register_model('lar_iqa', 'LAR-IQA')

    def load_dgiqa(self):
        """Load the DGIQA scorer (currently a placeholder model)."""
        self._register_model('dgiqa', 'DGIQA')

    def load_traditional_metrics(self):
        """Set ``traditional_metrics_available`` according to whether scipy imports.

        ``evaluate_traditional_metrics`` depends on ``scipy.ndimage``. Probing
        the import here makes the flag reflect reality; previously the flag
        was unconditionally set to True.
        """
        try:
            from scipy import ndimage  # noqa: F401  (availability probe only)
            self.traditional_metrics_available = True
        except ImportError as e:
            logger.warning("Could not load traditional metrics: %s", e)
            self.traditional_metrics_available = False

    def create_mock_model(self):
        """Create a placeholder scorer for demonstration purposes.

        The network is one conv layer, global average pooling, a linear head
        and a sigmoid scaled by 10, so outputs lie in (0, 10). Weights keep
        PyTorch's default random initialisation. The model is moved to
        ``self.device`` and put in eval mode.
        """
        class MockQualityModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid()
                )

            def forward(self, x):
                return self.backbone(x) * 10  # Scale to 0-10

        model = MockQualityModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Mark the evaluator as running in fallback mode."""
        logger.info("Using fallback quality assessment implementation")
        self.fallback_mode = True

    # ------------------------------------------------------------------
    # Scoring helpers
    # ------------------------------------------------------------------
    @classmethod
    def _clamp_score(cls, score):
        """Clamp *score* to [MIN_SCORE, MAX_SCORE]; NaN becomes NEUTRAL_SCORE.

        Without the NaN check, ``max(0.0, min(10.0, nan))`` evaluates to 10.0,
        so an undefined score would be reported as perfect quality.
        """
        score = float(score)
        if np.isnan(score):
            return cls.NEUTRAL_SCORE
        return max(cls.MIN_SCORE, min(cls.MAX_SCORE, score))

    @staticmethod
    def _to_rgb(image):
        """Return *image* in RGB mode, converting only when necessary."""
        return image if image.mode == 'RGB' else image.convert('RGB')

    def _evaluate_with_model(self, key, label, image):
        """Score *image* with the learned scorer registered under *key*.

        Falls back to ``fallback_quality_score`` when the scorer is not
        registered or raises. *label* is only used in the error log message.
        """
        if key not in self.models or key not in self.processors:
            return self.fallback_quality_score(image)
        try:
            # The transform and the 3-input-channel model need RGB input.
            # Before this conversion, RGBA/grayscale/palette images failed
            # during preprocessing and silently took the fallback path.
            tensor = self.processors[key](self._to_rgb(image))
            tensor = tensor.unsqueeze(0).to(self.device)  # add batch dimension
            with torch.no_grad():
                score = self.models[key](tensor).item()
            return self._clamp_score(score)
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            return self.fallback_quality_score(image)

    # ------------------------------------------------------------------
    # Public scorers
    # ------------------------------------------------------------------
    def evaluate_with_lar_iqa(self, image: Image.Image) -> float:
        """Evaluate image quality using LAR-IQA (0-10)."""
        return self._evaluate_with_model('lar_iqa', 'LAR-IQA', image)

    def evaluate_with_dgiqa(self, image: Image.Image) -> float:
        """Evaluate image quality using DGIQA (0-10)."""
        return self._evaluate_with_model('dgiqa', 'DGIQA', image)

    def evaluate_traditional_metrics(self, image: Image.Image) -> float:
        """Score *image* from simple luminance statistics (0-10).

        The sub-scores, combined with weights 0.4 / 0.3 / 0.3, are:

        * sharpness: Laplacian variance / 100, capped at 10.
        * contrast: luminance standard deviation / 25, capped at 10.
          NOTE(review): for 8-bit data the std is at most 127.5, so this
          sub-score never exceeds about 5.1. Confirm that is intended.
        * brightness: 10 at mid-gray (127.5), falling linearly to 0 at pure
          black or white.

        Returns NEUTRAL_SCORE for empty images or if scipy is unavailable or
        any step raises.
        """
        try:
            from scipy import ndimage

            # Convert first so the array is always (H, W, 3). Previously a
            # grayscale (H, W) array was sliced along its width by [..., :3],
            # which produced a meaningless "luminance".
            rgb = np.asarray(self._to_rgb(image), dtype=np.float64)
            gray = rgb @ np.array([0.2989, 0.5870, 0.1140])
            if gray.size == 0:
                return self.NEUTRAL_SCORE

            sharpness_score = min(10.0, ndimage.laplace(gray).var() / 100.0)
            contrast_score = min(10.0, gray.std() / 25.0)
            brightness_score = 10.0 - abs(gray.mean() - 127.5) / 12.75

            quality_score = (sharpness_score * 0.4 +
                             contrast_score * 0.3 +
                             brightness_score * 0.3)
            return self._clamp_score(quality_score)
        except Exception as e:
            logger.error("Error in traditional metrics: %s", e)
            return self.NEUTRAL_SCORE

    def fallback_quality_score(self, image: Image.Image) -> float:
        """Heuristic score from image metadata only (size, aspect ratio, format).

        Returns NEUTRAL_SCORE when the score cannot be computed, e.g. for a
        zero-height image, where the aspect-ratio division raises.
        """
        try:
            width, height = image.size
            # One point per 100,000 pixels; reaches the 10.0 cap at 1 MP.
            resolution_score = min(10.0, (width * height) / 100000.0)
            # Ratios between 1:2 and 2:1 count as "standard".
            aspect_score = 8.0 if 0.5 <= width / height <= 2.0 else 5.0
            # Prefer lossless PNG. PIL only sets ``format`` for images read
            # from a file, so in-memory images get the lower score.
            format_score = 8.0 if image.format == 'PNG' else 6.0
            quality_score = (resolution_score * 0.5 +
                             aspect_score * 0.3 +
                             format_score * 0.2)
            return self._clamp_score(quality_score)
        except Exception:
            return self.NEUTRAL_SCORE

    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """
        Evaluate image quality using an ensemble of scorers.

        Args:
            image: PIL Image to evaluate.
            anime_mode: If True, weight the traditional metrics higher
                (0.3 / 0.3 / 0.4 instead of 0.4 / 0.4 / 0.2).

        Returns:
            Quality score from 0-10. If the ensemble itself raises, returns
            ``fallback_quality_score(image)``.
        """
        try:
            lar_score = self.evaluate_with_lar_iqa(image)
            dgiqa_score = self.evaluate_with_dgiqa(image)
            traditional_score = self.evaluate_traditional_metrics(image)

            if anime_mode:
                # Traditional metrics may be more reliable for stylised content.
                weights = (0.3, 0.3, 0.4)
            else:
                # For realistic images, weight the learned scorers higher.
                weights = (0.4, 0.4, 0.2)
            scores = (lar_score, dgiqa_score, traditional_score)
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info("Quality scores - LAR: %.2f, DGIQA: %.2f, "
                        "Traditional: %.2f, Final: %.2f",
                        lar_score, dgiqa_score, traditional_score, final_score)
            return self._clamp_score(final_score)
        except Exception as e:
            logger.error("Error in quality evaluation: %s", e)
            return self.fallback_quality_score(image)