"""Image quality assessment built on an ensemble of model and heuristic scores."""

import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# NOTE(review): AutoModel/AutoProcessor are not referenced by the code below;
# presumably reserved for loading the real models -- confirm before removing.
from transformers import AutoModel, AutoProcessor  # noqa: F401

logger = logging.getLogger(__name__)


class QualityEvaluator:
    """Image quality assessment using multiple SOTA models.

    Three scores are combined into a single value in the range 0-10:
    two model-based scores (registered under the keys ``'lar_iqa'`` and
    ``'dgiqa'``; both are currently backed by a randomly initialised mock
    network) and one score derived from simple image statistics.
    """

    # Bounds every returned score is clamped to.
    _MIN_SCORE = 0.0
    _MAX_SCORE = 10.0
    # Score returned when an assessment cannot be computed at all.
    _NEUTRAL_SCORE = 5.0

    # Preprocessing parameters shared by both model pipelines.
    _INPUT_SIZE = (224, 224)
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    # Ensemble weights in the order (LAR-IQA, DGIQA, traditional metrics).
    _ANIME_WEIGHTS = (0.3, 0.3, 0.4)
    _DEFAULT_WEIGHTS = (0.4, 0.4, 0.2)

    def __init__(self):
        """Select the compute device and load every quality model."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        # Initialise the status flags up front so they always exist, even
        # when the loader that would normally set them never runs.
        self.fallback_mode = False
        self.traditional_metrics_available = False
        self.load_models()

    @classmethod
    def _clamp_score(cls, score) -> float:
        """Clamp *score* into the supported 0-10 range as a plain float."""
        return float(max(cls._MIN_SCORE, min(cls._MAX_SCORE, score)))

    def _build_preprocessor(self):
        """Return the resize / to-tensor / normalise pipeline used by the models."""
        return transforms.Compose([
            transforms.Resize(self._INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

    def load_models(self) -> None:
        """Load quality assessment models.

        Switches to the fallback implementation when loading raises or when
        no model could be registered at all.
        """
        try:
            # Load LAR-IQA model (primary)
            logger.info("Loading LAR-IQA model...")
            self.load_lar_iqa()

            # Load DGIQA model (secondary)
            logger.info("Loading DGIQA model...")
            self.load_dgiqa()

            # Load traditional metrics as fallback
            logger.info("Loading traditional quality metrics...")
            self.load_traditional_metrics()
        except Exception as e:
            logger.error("Error loading quality models: %s", e)
            self.use_fallback_implementation()
            return

        # The individual loaders swallow their own errors, so an empty
        # registry is the only sign that every model failed to load.
        if not self.models:
            self.use_fallback_implementation()

    def load_lar_iqa(self) -> None:
        """Load LAR-IQA model.

        Placeholder: registers a mock network instead of the real model.
        Failures are logged as warnings and leave the model unregistered.
        """
        try:
            self.models['lar_iqa'] = self.create_mock_model()
            self.processors['lar_iqa'] = self._build_preprocessor()
        except Exception as e:
            logger.warning("Could not load LAR-IQA: %s", e)

    def load_dgiqa(self) -> None:
        """Load DGIQA model.

        Placeholder: registers a mock network instead of the real model.
        Failures are logged as warnings and leave the model unregistered.
        """
        try:
            self.models['dgiqa'] = self.create_mock_model()
            self.processors['dgiqa'] = self._build_preprocessor()
        except Exception as e:
            logger.warning("Could not load DGIQA: %s", e)

    def load_traditional_metrics(self) -> None:
        """Load traditional quality metrics (BRISQUE, NIQE, etc.).

        Sets ``traditional_metrics_available`` according to whether
        ``scipy.ndimage``, which ``evaluate_traditional_metrics`` relies on,
        can be imported.
        """
        try:
            from scipy import ndimage  # noqa: F401  (availability probe only)
        except ImportError as e:
            logger.warning("Could not load traditional metrics: %s", e)
            self.traditional_metrics_available = False
        else:
            self.traditional_metrics_available = True

    def create_mock_model(self):
        """Create a mock model for demonstration purposes.

        Returns:
            A small, randomly initialised network in eval mode on
            ``self.device`` that maps an image batch to one score per image
            in the range 0-10.
        """
        class MockQualityModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid(),
                )

            def forward(self, x):
                # The sigmoid output lies in (0, 1); scale to 0-10.
                return self.backbone(x) * 10

        model = MockQualityModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self) -> None:
        """Use simple fallback quality assessment."""
        logger.info("Using fallback quality assessment implementation")
        self.fallback_mode = True

    def _evaluate_with_model(self, model_key: str, label: str,
                             image: Image.Image) -> float:
        """Score *image* with the model registered under *model_key*.

        Args:
            model_key: Key into ``self.models`` / ``self.processors``.
            label: Human-readable model name used in the error log message.
            image: PIL image to evaluate.

        Returns:
            The model score clamped to 0-10, or the fallback score when the
            model is not loaded or evaluation raises.
        """
        try:
            model = self.models.get(model_key)
            processor = self.processors.get(model_key)
            if model is None or processor is None:
                return self.fallback_quality_score(image)

            # The pipeline normalises exactly three channels, so grayscale,
            # palette and RGBA inputs must be converted first; otherwise
            # normalisation raises and the model is silently bypassed.
            rgb_image = image.convert('RGB')
            tensor = processor(rgb_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                score = model(tensor).item()

            return self._clamp_score(score)
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            # Pass the original image so its ``format`` is still available.
            return self.fallback_quality_score(image)

    def evaluate_with_lar_iqa(self, image: Image.Image) -> float:
        """Evaluate image quality using LAR-IQA.

        Returns the fallback score if the model is missing or fails.
        """
        return self._evaluate_with_model('lar_iqa', 'LAR-IQA', image)

    def evaluate_with_dgiqa(self, image: Image.Image) -> float:
        """Evaluate image quality using DGIQA.

        Returns the fallback score if the model is missing or fails.
        """
        return self._evaluate_with_model('dgiqa', 'DGIQA', image)

    def evaluate_traditional_metrics(self, image: Image.Image) -> float:
        """Evaluate using traditional quality metrics.

        Combines sharpness (variance of the Laplacian), contrast (standard
        deviation of the luminance) and brightness (distance of the mean
        luminance from mid-gray) with weights 0.4 / 0.3 / 0.3.

        Returns:
            Score clamped to 0-10, or 5.0 when the image has no pixels or
            the computation raises.
        """
        try:
            from scipy import ndimage

            # Convert to RGB first: slicing ``[..., :3]`` on a 2-D grayscale
            # array would pick three columns instead of three channels.
            img_array = np.asarray(image.convert('RGB'), dtype=np.float64)
            gray = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])

            # Statistics of an empty array are NaN, which the clamping would
            # turn into a perfect score; report the neutral score instead.
            if gray.size == 0:
                return self._NEUTRAL_SCORE

            # Calculate sharpness (Laplacian variance)
            laplacian_var = ndimage.laplace(gray).var()
            sharpness_score = min(10.0, laplacian_var / 100.0)

            # Calculate contrast
            contrast_score = min(10.0, gray.std() / 25.0)

            # Calculate brightness distribution: 10 at mid-gray, 0 at the extremes
            brightness_score = 10.0 - abs(gray.mean() - 127.5) / 12.75

            quality_score = (sharpness_score * 0.4
                             + contrast_score * 0.3
                             + brightness_score * 0.3)

            return self._clamp_score(quality_score)
        except Exception as e:
            logger.error("Error in traditional metrics: %s", e)
            return self._NEUTRAL_SCORE

    def fallback_quality_score(self, image: Image.Image) -> float:
        """Simple fallback quality assessment.

        Uses only image properties: pixel count, aspect ratio and file
        format, weighted 0.5 / 0.3 / 0.2.

        Returns:
            Score clamped to 0-10, or 5.0 if the properties cannot be read.
        """
        try:
            width, height = image.size

            # Resolution score: grows linearly and saturates at 10 for 1 MP
            resolution_score = min(10.0, (width * height) / 100000.0)

            # Aspect ratio score (prefer standard ratios)
            aspect_ratio = width / height
            aspect_score = 8.0 if 0.5 <= aspect_ratio <= 2.0 else 5.0

            # File format score (prefer lossless)
            format_score = 8.0 if image.format == 'PNG' else 6.0

            quality_score = (resolution_score * 0.5
                             + aspect_score * 0.3
                             + format_score * 0.2)

            return self._clamp_score(quality_score)
        except Exception:
            # Deliberate best effort: e.g. a zero height raises on division.
            return self._NEUTRAL_SCORE

    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """Evaluate image quality using ensemble of models.

        Args:
            image: PIL Image to evaluate
            anime_mode: Whether to use anime-specific evaluation

        Returns:
            Quality score from 0-10
        """
        try:
            lar_score = self.evaluate_with_lar_iqa(image)
            dgiqa_score = self.evaluate_with_dgiqa(image)
            traditional_score = self.evaluate_traditional_metrics(image)
            scores = (lar_score, dgiqa_score, traditional_score)

            # For anime images, weight traditional metrics higher as they may
            # be more reliable for stylized content; for realistic images,
            # weight the model scores higher.
            weights = self._ANIME_WEIGHTS if anime_mode else self._DEFAULT_WEIGHTS

            final_score = sum(score * weight
                              for score, weight in zip(scores, weights))

            logger.info(
                "Quality scores - LAR: %.2f, DGIQA: %.2f, "
                "Traditional: %.2f, Final: %.2f",
                lar_score, dgiqa_score, traditional_score, final_score,
            )

            return self._clamp_score(final_score)
        except Exception as e:
            logger.error("Error in quality evaluation: %s", e)
            return self.fallback_quality_score(image)