# image-evaluation-tool/models/ai_detection_evaluator.py
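"""AI-generated image detection evaluator.

Combines (currently mocked) learned detectors with classical heuristics
(compression artifacts, frequency-domain statistics, pixel patterns, and
format/dimension cues) to estimate the probability that an image is
AI-generated.
"""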
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
import logging
from scipy import ndimage
logger = logging.getLogger(__name__)
class AIDetectionEvaluator:
"""AI-generated image detection using multiple approaches"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
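        # Caches for loaded detector models and their preprocessing pipelines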
self.models = {}
self.processors = {}
self.load_models()
def load_models(self):
"""Load AI detection models"""
try:
# Load Sentry-Image model (primary)
logger.info("Loading Sentry-Image model...")
self.load_sentry_image()
# Load custom ensemble model (secondary)
logger.info("Loading custom ensemble model...")
self.load_custom_ensemble()
# Load traditional artifact detection
logger.info("Loading traditional artifact detection...")
self.load_artifact_detection()
except Exception as e:
logger.error(f"Error loading AI detection models: {str(e)}")
self.use_fallback_implementation()
def load_sentry_image(self):
"""Load Sentry-Image model"""
try:
            # Placeholder implementation for Sentry-Image; a production build
            # would load real weights (see the loading sketch after this method)
self.models['sentry'] = self.create_mock_detection_model()
self.processors['sentry'] = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                                     std=[0.229, 0.224, 0.225])
])
except Exception as e:
logger.warning(f"Could not load Sentry-Image: {str(e)}")
def load_custom_ensemble(self):
"""Load custom ensemble detection model"""
try:
# Placeholder for custom ensemble
self.models['ensemble'] = self.create_mock_detection_model()
self.processors['ensemble'] = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                                     std=[0.229, 0.224, 0.225])
])
except Exception as e:
logger.warning(f"Could not load custom ensemble: {str(e)}")
def load_artifact_detection(self):
"""Load traditional artifact detection methods"""
try:
# These would be implemented using opencv and scipy
self.artifact_detection_available = True
except Exception as e:
logger.warning(f"Could not load artifact detection: {str(e)}")
self.artifact_detection_available = False
def create_mock_detection_model(self):
"""Create a mock detection model for demonstration"""
class MockDetectionModel(nn.Module):
def __init__(self):
super().__init__()
                self.backbone = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                    nn.Sigmoid()
                )
            def forward(self, x):
                return self.backbone(x)  # sigmoid output: probability in [0, 1]
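        # Instantiate on the configured device; eval() disables training-time
        # behavior (a no-op for this mock, but good inference hygiene)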
model = MockDetectionModel().to(self.device)
model.eval()
return model
    def use_fallback_implementation(self):
        """Fall back to heuristic-only AI detection"""
        logger.info("Using fallback AI detection implementation")
        self.fallback_mode = True  # informational flag; scoring methods fall back per call
def evaluate_with_sentry(self, image: Image.Image) -> float:
"""Evaluate AI generation probability using Sentry-Image"""
try:
if 'sentry' not in self.models:
return self.fallback_detection_score(image)
            # Preprocess (convert to RGB so ToTensor always yields 3 channels)
            tensor = self.processors['sentry'](image.convert('RGB')).unsqueeze(0).to(self.device)
# Get prediction
with torch.no_grad():
probability = self.models['sentry'](tensor).item()
return max(0.0, min(1.0, probability))
except Exception as e:
logger.error(f"Error in Sentry evaluation: {str(e)}")
return self.fallback_detection_score(image)
def evaluate_with_ensemble(self, image: Image.Image) -> float:
"""Evaluate AI generation probability using custom ensemble"""
try:
if 'ensemble' not in self.models:
return self.fallback_detection_score(image)
            # Preprocess (convert to RGB so ToTensor always yields 3 channels)
            tensor = self.processors['ensemble'](image.convert('RGB')).unsqueeze(0).to(self.device)
# Get prediction
with torch.no_grad():
probability = self.models['ensemble'](tensor).item()
return max(0.0, min(1.0, probability))
except Exception as e:
logger.error(f"Error in ensemble evaluation: {str(e)}")
return self.fallback_detection_score(image)
def detect_compression_artifacts(self, image: Image.Image) -> float:
"""Detect compression artifacts that might indicate AI generation"""
try:
# Convert to numpy array
img_array = np.array(image)
# Convert to grayscale
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
            # Approximate JPEG blocking-artifact detection. A full implementation
            # would analyze the 8x8 DCT domain directly; here local variance over
            # 8x8 windows serves as a cheap proxy for blocking.
            kernel = np.ones((8, 8), np.float32) / 64
            local_mean = cv2.filter2D(gray.astype(np.float32), -1, kernel)
            local_variance = cv2.filter2D((gray.astype(np.float32) - local_mean) ** 2, -1, kernel)
            # High variance within 8x8 blocks can indicate blocking artifacts
            block_variance = np.mean(local_variance)
            # Map to a 0-1 probability (the 1000.0 divisor is a heuristic scale)
            artifact_probability = min(1.0, block_variance / 1000.0)
return artifact_probability
except Exception as e:
logger.error(f"Error in compression artifact detection: {str(e)}")
return 0.5
def detect_frequency_anomalies(self, image: Image.Image) -> float:
"""Detect frequency domain anomalies common in AI-generated images"""
try:
# Convert to numpy array and grayscale
img_array = np.array(image)
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
# Apply FFT
f_transform = np.fft.fft2(gray)
f_shift = np.fft.fftshift(f_transform)
magnitude_spectrum = np.log(np.abs(f_shift) + 1)
# Analyze frequency distribution
# AI-generated images often have specific frequency patterns
# Calculate radial frequency distribution
h, w = magnitude_spectrum.shape
center_y, center_x = h // 2, w // 2
# Create radial mask
y, x = np.ogrid[:h, :w]
mask = (x - center_x) ** 2 + (y - center_y) ** 2
# Calculate mean magnitude at different frequencies
low_freq_mask = mask <= (min(h, w) // 8) ** 2
high_freq_mask = mask >= (min(h, w) // 4) ** 2
low_freq_energy = np.mean(magnitude_spectrum[low_freq_mask])
high_freq_energy = np.mean(magnitude_spectrum[high_freq_mask])
            # AI images often show unusual low/high frequency energy ratios;
            # 10.0 is a heuristic baseline and 20.0 a scale factor (both hand-tuned)
            if high_freq_energy > 0:
                freq_ratio = low_freq_energy / high_freq_energy
                anomaly_probability = min(1.0, abs(freq_ratio - 10.0) / 20.0)
else:
anomaly_probability = 0.5
return anomaly_probability
except Exception as e:
logger.error(f"Error in frequency analysis: {str(e)}")
return 0.5
def detect_pixel_patterns(self, image: Image.Image) -> float:
"""Detect suspicious pixel patterns common in AI-generated images"""
try:
img_array = np.array(image)
# Check for perfect pixel repetitions (uncommon in natural images)
if len(img_array.shape) == 3:
                # Flatten to one row per pixel (handles RGB and RGBA alike)
                pixels = img_array.reshape(-1, img_array.shape[2])
                unique_pixels = np.unique(pixels, axis=0)
# Calculate pixel diversity
pixel_diversity = len(unique_pixels) / len(pixels)
# Very low diversity might indicate AI generation
if pixel_diversity < 0.1:
pattern_probability = 0.8
elif pixel_diversity < 0.3:
pattern_probability = 0.6
else:
pattern_probability = 0.2
else:
pattern_probability = 0.5
# Check for unnatural smoothness
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
            # Local standard deviation in 3x3 windows (generic_filter with a
            # Python callable is slow on large images; acceptable for a demo)
            local_std = ndimage.generic_filter(gray.astype(np.float32), np.std, size=3)
avg_local_std = np.mean(local_std)
# Very smooth images might be AI-generated
if avg_local_std < 5.0:
smoothness_probability = 0.7
elif avg_local_std < 15.0:
smoothness_probability = 0.4
else:
smoothness_probability = 0.2
# Combine pattern and smoothness indicators
combined_probability = (pattern_probability + smoothness_probability) / 2
return max(0.0, min(1.0, combined_probability))
except Exception as e:
logger.error(f"Error in pixel pattern detection: {str(e)}")
return 0.5
def analyze_metadata_indicators(self, image: Image.Image) -> float:
"""Analyze image metadata for AI generation indicators"""
try:
# Check image format and properties
format_probability = 0.0
            # PNG is a weak signal: many AI pipelines save PNG by default
            # (note: image.format is None for images created in memory)
            if image.format == 'PNG':
                format_probability += 0.3
# Check for specific dimensions common in AI generation
width, height = image.size
# Common AI generation resolutions
            ai_resolutions = [
                (512, 512), (768, 768), (1024, 1024),  # square formats
                (512, 768), (768, 512),                # 2:3 / 3:2
                (1024, 768), (768, 1024)               # 4:3 / 3:4
            ]
if (width, height) in ai_resolutions:
format_probability += 0.4
# Check for perfect aspect ratios (less common in natural photos)
aspect_ratio = width / height
common_ai_ratios = [1.0, 1.5, 0.67, 1.33, 0.75, 1.25]
for ratio in common_ai_ratios:
if abs(aspect_ratio - ratio) < 0.01:
format_probability += 0.2
break
return max(0.0, min(1.0, format_probability))
except Exception as e:
logger.error(f"Error in metadata analysis: {str(e)}")
return 0.5
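    # A fuller metadata pass could also inspect EXIF, which cameras usually
    # populate and AI pipelines usually leave empty. Sketch only; the +0.2
    # weight is a guess, not a tuned value:
    #
    #     exif = image.getexif()
    #     if len(exif) == 0:
    #         format_probability += 0.2  # no camera metadata at all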
def fallback_detection_score(self, image: Image.Image) -> float:
"""Simple fallback AI detection"""
try:
# Combine multiple simple heuristics
scores = []
# Compression artifacts
artifact_score = self.detect_compression_artifacts(image)
scores.append(artifact_score)
# Frequency anomalies
freq_score = self.detect_frequency_anomalies(image)
scores.append(freq_score)
# Pixel patterns
pattern_score = self.detect_pixel_patterns(image)
scores.append(pattern_score)
# Metadata indicators
metadata_score = self.analyze_metadata_indicators(image)
scores.append(metadata_score)
            # Average the heuristic scores
            final_score = float(np.mean(scores))
return max(0.0, min(1.0, final_score))
except Exception:
return 0.5 # Default neutral probability
def evaluate(self, image: Image.Image) -> float:
"""
Evaluate probability that image is AI-generated
Args:
image: PIL Image to evaluate
Returns:
AI generation probability from 0-1 (0 = likely real, 1 = likely AI)
"""
try:
scores = []
# Sentry-Image evaluation (primary)
sentry_score = self.evaluate_with_sentry(image)
scores.append(sentry_score)
# Custom ensemble evaluation (secondary)
ensemble_score = self.evaluate_with_ensemble(image)
scores.append(ensemble_score)
# Traditional artifact detection
artifact_score = self.fallback_detection_score(image)
scores.append(artifact_score)
# Ensemble scoring
weights = [0.5, 0.3, 0.2] # Sentry gets highest weight
final_score = sum(score * weight for score, weight in zip(scores, weights))
logger.info(f"AI detection scores - Sentry: {sentry_score:.3f}, "
f"Ensemble: {ensemble_score:.3f}, Artifacts: {artifact_score:.3f}, "
f"Final: {final_score:.3f}")
return max(0.0, min(1.0, final_score))
except Exception as e:
logger.error(f"Error in AI detection evaluation: {str(e)}")
return self.fallback_detection_score(image)
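# Minimal usage sketch (illustrative): "sample.jpg" and the 0.5 decision
# threshold are placeholder choices, not values defined by this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    evaluator = AIDetectionEvaluator()
    img = Image.open("sample.jpg")
    prob = evaluator.evaluate(img)
    verdict = "likely AI-generated" if prob > 0.5 else "likely real"
    print(f"AI-generation probability: {prob:.3f} ({verdict})")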