import torch import torch.nn as nn import numpy as np from PIL import Image import torchvision.transforms as transforms from transformers import AutoModel, AutoProcessor import cv2 import logging from scipy import ndimage logger = logging.getLogger(__name__) class AIDetectionEvaluator: """AI-generated image detection using multiple approaches""" def __init__(self): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.models = {} self.processors = {} self.load_models() def load_models(self): """Load AI detection models""" try: # Load Sentry-Image model (primary) logger.info("Loading Sentry-Image model...") self.load_sentry_image() # Load custom ensemble model (secondary) logger.info("Loading custom ensemble model...") self.load_custom_ensemble() # Load traditional artifact detection logger.info("Loading traditional artifact detection...") self.load_artifact_detection() except Exception as e: logger.error(f"Error loading AI detection models: {str(e)}") self.use_fallback_implementation() def load_sentry_image(self): """Load Sentry-Image model""" try: # Placeholder implementation for Sentry-Image # In production, this would load the actual Sentry-Image model self.models['sentry'] = self.create_mock_detection_model() self.processors['sentry'] = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) except Exception as e: logger.warning(f"Could not load Sentry-Image: {str(e)}") def load_custom_ensemble(self): """Load custom ensemble detection model""" try: # Placeholder for custom ensemble self.models['ensemble'] = self.create_mock_detection_model() self.processors['ensemble'] = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) except Exception as e: logger.warning(f"Could not load custom ensemble: {str(e)}") def load_artifact_detection(self): """Load traditional artifact detection methods""" try: # These would be implemented using opencv and scipy self.artifact_detection_available = True except Exception as e: logger.warning(f"Could not load artifact detection: {str(e)}") self.artifact_detection_available = False def create_mock_detection_model(self): """Create a mock detection model for demonstration""" class MockDetectionModel(nn.Module): def __init__(self): super().__init__() self.backbone = torch.nn.Sequential( torch.nn.Conv2d(3, 64, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(64, 128, 3, padding=1), torch.nn.ReLU(), torch.nn.AdaptiveAvgPool2d((1, 1)), torch.nn.Flatten(), torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1), torch.nn.Sigmoid() ) def forward(self, x): return self.backbone(x) # Returns probability 0-1 model = MockDetectionModel().to(self.device) model.eval() return model def use_fallback_implementation(self): """Use simple fallback AI detection""" logger.info("Using fallback AI detection implementation") self.fallback_mode = True def evaluate_with_sentry(self, image: Image.Image) -> float: """Evaluate AI generation probability using Sentry-Image""" try: if 'sentry' not in self.models: return self.fallback_detection_score(image) # Preprocess image tensor = self.processors['sentry'](image).unsqueeze(0).to(self.device) # Get prediction with torch.no_grad(): probability = self.models['sentry'](tensor).item() return max(0.0, min(1.0, probability)) except Exception as e: logger.error(f"Error in Sentry evaluation: {str(e)}") return self.fallback_detection_score(image) def evaluate_with_ensemble(self, image: Image.Image) -> float: """Evaluate AI generation probability using custom ensemble""" try: if 'ensemble' not in self.models: return self.fallback_detection_score(image) # Preprocess image tensor = self.processors['ensemble'](image).unsqueeze(0).to(self.device) # Get prediction with torch.no_grad(): probability = self.models['ensemble'](tensor).item() return max(0.0, min(1.0, probability)) except Exception as e: logger.error(f"Error in ensemble evaluation: {str(e)}") return self.fallback_detection_score(image) def detect_compression_artifacts(self, image: Image.Image) -> float: """Detect compression artifacts that might indicate AI generation""" try: # Convert to numpy array img_array = np.array(image) # Convert to grayscale if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array # Detect JPEG compression artifacts using DCT analysis # This is a simplified version - real implementation would be more complex # Calculate local variance to detect blocking artifacts kernel = np.ones((8, 8), np.float32) / 64 local_mean = cv2.filter2D(gray.astype(np.float32), -1, kernel) local_variance = cv2.filter2D((gray.astype(np.float32) - local_mean) ** 2, -1, kernel) # High variance in 8x8 blocks might indicate JPEG artifacts block_variance = np.mean(local_variance) # Normalize to 0-1 probability artifact_probability = min(1.0, block_variance / 1000.0) return artifact_probability except Exception as e: logger.error(f"Error in compression artifact detection: {str(e)}") return 0.5 def detect_frequency_anomalies(self, image: Image.Image) -> float: """Detect frequency domain anomalies common in AI-generated images""" try: # Convert to numpy array and grayscale img_array = np.array(image) if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array # Apply FFT f_transform = np.fft.fft2(gray) f_shift = np.fft.fftshift(f_transform) magnitude_spectrum = np.log(np.abs(f_shift) + 1) # Analyze frequency distribution # AI-generated images often have specific frequency patterns # Calculate radial frequency distribution h, w = magnitude_spectrum.shape center_y, center_x = h // 2, w // 2 # Create radial mask y, x = np.ogrid[:h, :w] mask = (x - center_x) ** 2 + (y - center_y) ** 2 # Calculate mean magnitude at different frequencies low_freq_mask = mask <= (min(h, w) // 8) ** 2 high_freq_mask = mask >= (min(h, w) // 4) ** 2 low_freq_energy = np.mean(magnitude_spectrum[low_freq_mask]) high_freq_energy = np.mean(magnitude_spectrum[high_freq_mask]) # AI images often have unusual low/high frequency ratios if high_freq_energy > 0: freq_ratio = low_freq_energy / high_freq_energy # Normalize to probability anomaly_probability = min(1.0, abs(freq_ratio - 10.0) / 20.0) else: anomaly_probability = 0.5 return anomaly_probability except Exception as e: logger.error(f"Error in frequency analysis: {str(e)}") return 0.5 def detect_pixel_patterns(self, image: Image.Image) -> float: """Detect suspicious pixel patterns common in AI-generated images""" try: img_array = np.array(image) # Check for perfect pixel repetitions (uncommon in natural images) if len(img_array.shape) == 3: # Flatten to check for repeated pixel values pixels = img_array.reshape(-1, 3) unique_pixels = np.unique(pixels, axis=0) # Calculate pixel diversity pixel_diversity = len(unique_pixels) / len(pixels) # Very low diversity might indicate AI generation if pixel_diversity < 0.1: pattern_probability = 0.8 elif pixel_diversity < 0.3: pattern_probability = 0.6 else: pattern_probability = 0.2 else: pattern_probability = 0.5 # Check for unnatural smoothness if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array # Calculate local standard deviation local_std = ndimage.generic_filter(gray.astype(np.float32), np.std, size=3) avg_local_std = np.mean(local_std) # Very smooth images might be AI-generated if avg_local_std < 5.0: smoothness_probability = 0.7 elif avg_local_std < 15.0: smoothness_probability = 0.4 else: smoothness_probability = 0.2 # Combine pattern and smoothness indicators combined_probability = (pattern_probability + smoothness_probability) / 2 return max(0.0, min(1.0, combined_probability)) except Exception as e: logger.error(f"Error in pixel pattern detection: {str(e)}") return 0.5 def analyze_metadata_indicators(self, image: Image.Image) -> float: """Analyze image metadata for AI generation indicators""" try: # Check image format and properties format_probability = 0.0 # PNG format is more common for AI-generated images if image.format == 'PNG': format_probability += 0.3 # Check for specific dimensions common in AI generation width, height = image.size # Common AI generation resolutions ai_resolutions = [ (512, 512), (768, 768), (1024, 1024), # Square formats (512, 768), (768, 512), # 2:3 ratios (1024, 768), (768, 1024) # 4:3 ratios ] if (width, height) in ai_resolutions: format_probability += 0.4 # Check for perfect aspect ratios (less common in natural photos) aspect_ratio = width / height common_ai_ratios = [1.0, 1.5, 0.67, 1.33, 0.75, 1.25] for ratio in common_ai_ratios: if abs(aspect_ratio - ratio) < 0.01: format_probability += 0.2 break return max(0.0, min(1.0, format_probability)) except Exception as e: logger.error(f"Error in metadata analysis: {str(e)}") return 0.5 def fallback_detection_score(self, image: Image.Image) -> float: """Simple fallback AI detection""" try: # Combine multiple simple heuristics scores = [] # Compression artifacts artifact_score = self.detect_compression_artifacts(image) scores.append(artifact_score) # Frequency anomalies freq_score = self.detect_frequency_anomalies(image) scores.append(freq_score) # Pixel patterns pattern_score = self.detect_pixel_patterns(image) scores.append(pattern_score) # Metadata indicators metadata_score = self.analyze_metadata_indicators(image) scores.append(metadata_score) # Average the scores final_score = np.mean(scores) return max(0.0, min(1.0, final_score)) except Exception: return 0.5 # Default neutral probability def evaluate(self, image: Image.Image) -> float: """ Evaluate probability that image is AI-generated Args: image: PIL Image to evaluate Returns: AI generation probability from 0-1 (0 = likely real, 1 = likely AI) """ try: scores = [] # Sentry-Image evaluation (primary) sentry_score = self.evaluate_with_sentry(image) scores.append(sentry_score) # Custom ensemble evaluation (secondary) ensemble_score = self.evaluate_with_ensemble(image) scores.append(ensemble_score) # Traditional artifact detection artifact_score = self.fallback_detection_score(image) scores.append(artifact_score) # Ensemble scoring weights = [0.5, 0.3, 0.2] # Sentry gets highest weight final_score = sum(score * weight for score, weight in zip(scores, weights)) logger.info(f"AI detection scores - Sentry: {sentry_score:.3f}, " f"Ensemble: {ensemble_score:.3f}, Artifacts: {artifact_score:.3f}, " f"Final: {final_score:.3f}") return max(0.0, min(1.0, final_score)) except Exception as e: logger.error(f"Error in AI detection evaluation: {str(e)}") return self.fallback_detection_score(image)