Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from transformers import AutoModel, AutoProcessor | |
import cv2 | |
import logging | |
from scipy import ndimage | |
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class AIDetectionEvaluator:
    """AI-generated image detection using multiple approaches.

    ``evaluate`` blends three sources of evidence with fixed weights:

    * the ``"sentry"`` and ``"ensemble"`` detectors.  NOTE: both are currently
      placeholders produced by ``create_mock_detection_model`` -- small CNNs
      with randomly initialised, untrained weights -- so their outputs carry
      no real signal until actual models are plugged in;
    * hand-written heuristics averaged by ``fallback_detection_score``
      (box-filtered variance, FFT low/high ratio, pixel diversity and
      smoothness, size/format hints).

    Every scoring method returns a float in ``[0, 1]``
    (0 = likely real, 1 = likely AI-generated).
    """

    # ImageNet statistics used to normalise the detectors' input.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)
    _INPUT_SIZE = (224, 224)

    # Output sizes treated as typical of image generators.
    _AI_RESOLUTIONS = frozenset({
        (512, 512), (768, 768), (1024, 1024),  # square formats
        (512, 768), (768, 512),                # 2:3 / 3:2
        (1024, 768), (768, 1024),              # 4:3 / 3:4
    })
    # Aspect ratios (width / height) matched within 0.01 as "suspiciously exact".
    _AI_ASPECT_RATIOS = (1.0, 1.5, 0.67, 1.33, 0.75, 1.25)

    # Blend weights for (sentry, ensemble, heuristics) in ``evaluate``.
    _SCORE_WEIGHTS = (0.5, 0.3, 0.2)

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}       # detector name -> nn.Module in eval mode
        self.processors = {}   # detector name -> PIL image -> tensor transform
        # Defined up front so the attributes exist even if loading fails part-way.
        self.fallback_mode = False
        self.artifact_detection_available = False
        self.load_models()

    # ------------------------------------------------------------------ loading

    def load_models(self):
        """Load all detectors; enter fallback mode if no neural detector is usable."""
        try:
            # Load Sentry-Image model (primary)
            logger.info("Loading Sentry-Image model...")
            self.load_sentry_image()
            # Load custom ensemble model (secondary)
            logger.info("Loading custom ensemble model...")
            self.load_custom_ensemble()
            # Load traditional artifact detection
            logger.info("Loading traditional artifact detection...")
            self.load_artifact_detection()
        except Exception as e:
            logger.error("Error loading AI detection models: %s", e)
            self.use_fallback_implementation()
            return
        # The individual loaders swallow their own errors, so the handler above
        # rarely fires; detect the "nothing loaded" case explicitly.
        if not self.models:
            self.use_fallback_implementation()

    def _build_preprocessor(self):
        """Return the resize -> tensor -> ImageNet-normalise pipeline shared by the detectors."""
        return transforms.Compose([
            transforms.Resize(self._INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
        ])

    def _register_detector(self, key, description):
        """Create a detector + preprocessor and register both under *key*.

        Both objects are built before either is stored, so a model is never
        registered without its preprocessor.  Failures are logged as warnings
        (using *description*) and leave the detector unregistered.
        """
        try:
            # Placeholder: production code would load the real model weights here.
            model = self.create_mock_detection_model()
            preprocess = self._build_preprocessor()
        except Exception as e:
            logger.warning("Could not load %s: %s", description, e)
            return
        self.models[key] = model
        self.processors[key] = preprocess

    def load_sentry_image(self):
        """Register the Sentry-Image detector under ``"sentry"`` (currently a placeholder)."""
        self._register_detector('sentry', 'Sentry-Image')

    def load_custom_ensemble(self):
        """Register the custom ensemble detector under ``"ensemble"`` (currently a placeholder)."""
        self._register_detector('ensemble', 'custom ensemble')

    def load_artifact_detection(self):
        """Mark the OpenCV/SciPy heuristics as available (there is nothing to load)."""
        self.artifact_detection_available = True

    def create_mock_detection_model(self):
        """Build an untrained placeholder CNN mapping (N, 3, H, W) input to (N, 1) probabilities.

        The weights are randomly initialised, so the outputs are NOT meaningful
        detections; the model only exercises the inference plumbing.
        """
        class MockDetectionModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),  # global pooling: any H, W accepted
                    nn.Flatten(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                    nn.Sigmoid(),  # squash to a 0-1 probability
                )

            def forward(self, x):
                return self.backbone(x)

        model = MockDetectionModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Flag that only the heuristic detectors should be relied upon."""
        logger.info("Using fallback AI detection implementation")
        self.fallback_mode = True

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _clip_probability(value, default=0.5):
        """Clamp *value* into [0, 1] as a Python float; NaN maps to *default*.

        A bare ``max(0.0, min(1.0, nan))`` evaluates to 1.0, which would turn an
        undefined score into "certainly AI".
        """
        value = float(value)
        if np.isnan(value):
            return default
        return max(0.0, min(1.0, value))

    @staticmethod
    def _pixel_array(image):
        """Return ``np.array(image)`` as either a 2-D single-band array or (H, W, 3) RGB.

        Palette (``P``) images are converted to RGB because their raw values are
        colour-table indices, not intensities.  Any other multi-band non-RGB
        mode (RGBA, LA, CMYK, YCbCr, ...) is converted so channel-based code
        sees exactly three RGB bands.  Single-band modes (``L``, ``1``, ``I``,
        ``F``, ...) are returned unchanged.
        """
        if image.mode == 'P' or (image.mode != 'RGB' and len(image.getbands()) > 1):
            image = image.convert('RGB')
        return np.array(image)

    def _to_grayscale(self, image):
        """Return a 2-D grayscale array for *image* (RGB is reduced with OpenCV)."""
        pixels = self._pixel_array(image)
        if pixels.ndim == 3:
            return cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
        return pixels

    @staticmethod
    def _local_std(values, size=3):
        """Population std-dev over a ``size`` x ``size`` window at every pixel.

        Uses Var = E[x^2] - E[x]^2 with two box filters.  This matches
        ``ndimage.generic_filter(values, np.std, size=size)`` (same 'reflect'
        border handling), but is vectorised instead of calling ``np.std`` from
        Python once per pixel.
        """
        data = np.asarray(values, dtype=np.float64)
        mean = ndimage.uniform_filter(data, size=size)
        mean_sq = ndimage.uniform_filter(data * data, size=size)
        # Clamp tiny negative values caused by floating-point cancellation.
        return np.sqrt(np.maximum(mean_sq - mean * mean, 0.0))

    def _score_with_model(self, key, label, image):
        """Run the detector registered under *key* on *image*.

        Falls back to ``fallback_detection_score`` when the detector is missing
        or inference fails.  *label* is used only in the error log.
        """
        if key not in self.models:
            return self.fallback_detection_score(image)
        try:
            # The preprocessing pipeline expects exactly three channels.
            rgb = image if image.mode == 'RGB' else image.convert('RGB')
            tensor = self.processors[key](rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                probability = self.models[key](tensor).item()
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            # Pass the original image: convert() drops ``image.format``,
            # which the metadata heuristic reads.
            return self.fallback_detection_score(image)
        return self._clip_probability(probability)

    # ------------------------------------------------------------------ detectors

    def evaluate_with_sentry(self, image: Image.Image) -> float:
        """Evaluate AI generation probability using Sentry-Image (heuristics if unavailable)."""
        return self._score_with_model('sentry', 'Sentry', image)

    def evaluate_with_ensemble(self, image: Image.Image) -> float:
        """Evaluate AI generation probability using the custom ensemble (heuristics if unavailable)."""
        return self._score_with_model('ensemble', 'ensemble', image)

    def detect_compression_artifacts(self, image: Image.Image) -> float:
        """Score box-filtered local variance as a rough artifact indicator.

        This is not DCT/JPEG-block analysis.  It does the following:

        * box-filters the squared deviation of each pixel from its 8x8 local mean;
        * averages that over the whole image;
        * maps the average linearly, so that a value >= 1000 scores 1.0.

        Returns 0.5 on error.
        """
        try:
            gray = self._to_grayscale(image).astype(np.float32)
            kernel = np.ones((8, 8), np.float32) / 64  # 8x8 box filter
            local_mean = cv2.filter2D(gray, -1, kernel)
            local_variance = cv2.filter2D((gray - local_mean) ** 2, -1, kernel)
            block_variance = float(np.mean(local_variance))
            return self._clip_probability(block_variance / 1000.0)
        except Exception as e:
            logger.error("Error in compression artifact detection: %s", e)
            return 0.5

    def detect_frequency_anomalies(self, image: Image.Image) -> float:
        """Score how far the low/high-frequency energy ratio is from 10.

        The method works on the centred log-magnitude FFT spectrum:

        * "Low" frequencies lie within radius ``min(h, w) // 8`` of the centre.
        * "High" frequencies lie at or beyond radius ``min(h, w) // 4``.
        * The score is ``|ratio - 10| / 20``, clamped to 1.

        Returns 0.5 when the high-band energy is not positive, or on error.
        """
        try:
            gray = self._to_grayscale(image)
            spectrum = np.log(np.abs(np.fft.fftshift(np.fft.fft2(gray))) + 1)

            h, w = spectrum.shape
            cy, cx = h // 2, w // 2
            yy, xx = np.ogrid[:h, :w]
            dist_sq = (xx - cx) ** 2 + (yy - cy) ** 2  # squared distance from centre

            low_radius = min(h, w) // 8
            high_radius = min(h, w) // 4
            low_energy = np.mean(spectrum[dist_sq <= low_radius ** 2])
            high_energy = np.mean(spectrum[dist_sq >= high_radius ** 2])

            if high_energy > 0:
                ratio = low_energy / high_energy
                return self._clip_probability(abs(ratio - 10.0) / 20.0)
            return 0.5
        except Exception as e:
            logger.error("Error in frequency analysis: %s", e)
            return 0.5

    def detect_pixel_patterns(self, image: Image.Image) -> float:
        """Average a colour-diversity score and a smoothness score.

        The two scores are computed as follows:

        * Colour diversity is the number of unique RGB values divided by the pixel count.
          It scores 0.8 if below 0.1, 0.6 if below 0.3, and 0.2 otherwise.
          Single-band images score 0.5.
        * Smoothness is the mean 3x3 local standard deviation of the grayscale image.
          It scores 0.7 if below 5, 0.4 if below 15, and 0.2 otherwise.

        Returns 0.5 on error.
        """
        try:
            pixels = self._pixel_array(image)

            if pixels.ndim == 3:  # guaranteed (H, W, 3) RGB by _pixel_array
                colours = pixels.reshape(-1, 3)
                diversity = len(np.unique(colours, axis=0)) / len(colours)
                if diversity < 0.1:
                    pattern_probability = 0.8
                elif diversity < 0.3:
                    pattern_probability = 0.6
                else:
                    pattern_probability = 0.2
                gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
            else:
                pattern_probability = 0.5
                gray = pixels

            avg_local_std = float(np.mean(self._local_std(gray, size=3)))
            if avg_local_std < 5.0:
                smoothness_probability = 0.7
            elif avg_local_std < 15.0:
                smoothness_probability = 0.4
            else:
                smoothness_probability = 0.2

            return self._clip_probability((pattern_probability + smoothness_probability) / 2)
        except Exception as e:
            logger.error("Error in pixel pattern detection: %s", e)
            return 0.5

    def analyze_metadata_indicators(self, image: Image.Image) -> float:
        """Score format and size hints.

        Points are added as follows:

        * +0.3 if the image format is PNG;
        * +0.4 if the size is a common generator resolution;
        * +0.2 if the aspect ratio is within 0.01 of a listed "exact" ratio.

        The total is clamped to [0, 1].  Returns 0.5 on error, including a
        zero-height image.
        """
        try:
            score = 0.0
            # ``format`` reflects the source file; Pillow leaves it None for
            # images created or converted in memory.
            if image.format == 'PNG':
                score += 0.3

            width, height = image.size
            if (width, height) in self._AI_RESOLUTIONS:
                score += 0.4

            aspect_ratio = width / height
            if any(abs(aspect_ratio - r) < 0.01 for r in self._AI_ASPECT_RATIOS):
                score += 0.2

            return self._clip_probability(score)
        except Exception as e:
            logger.error("Error in metadata analysis: %s", e)
            return 0.5

    def fallback_detection_score(self, image: Image.Image) -> float:
        """Mean of the four heuristic detectors; 0.5 if anything goes wrong."""
        try:
            scores = [
                self.detect_compression_artifacts(image),
                self.detect_frequency_anomalies(image),
                self.detect_pixel_patterns(image),
                self.analyze_metadata_indicators(image),
            ]
            return self._clip_probability(np.mean(scores))
        except Exception:
            return 0.5  # Default neutral probability

    def evaluate(self, image: Image.Image) -> float:
        """
        Evaluate probability that image is AI-generated.

        The Sentry, ensemble and heuristic scores are combined with weights
        0.5 / 0.3 / 0.2.  When a detector is unavailable it contributes the
        heuristic score instead.

        Args:
            image: PIL Image to evaluate
        Returns:
            AI generation probability from 0-1 (0 = likely real, 1 = likely AI)
        """
        try:
            sentry_score = self.evaluate_with_sentry(image)
            ensemble_score = self.evaluate_with_ensemble(image)
            artifact_score = self.fallback_detection_score(image)

            scores = (sentry_score, ensemble_score, artifact_score)
            final_score = sum(s * w for s, w in zip(scores, self._SCORE_WEIGHTS))

            logger.info(
                "AI detection scores - Sentry: %.3f, Ensemble: %.3f, Artifacts: %.3f, Final: %.3f",
                sentry_score, ensemble_score, artifact_score, final_score,
            )
            return self._clip_probability(final_score)
        except Exception as e:
            logger.error("Error in AI detection evaluation: %s", e)
            return self.fallback_detection_score(image)