# File size: 15,168 Bytes
# 83b7522
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import AutoModel, AutoProcessor
import cv2
import logging
from scipy import ndimage

logger = logging.getLogger(__name__)

class AIDetectionEvaluator:
    """AI-generated image detection using multiple approaches"""
    
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.load_models()
    
    def load_models(self):
        """Load AI detection models"""
        try:
            # Load Sentry-Image model (primary)
            logger.info("Loading Sentry-Image model...")
            self.load_sentry_image()
            
            # Load custom ensemble model (secondary)
            logger.info("Loading custom ensemble model...")
            self.load_custom_ensemble()
            
            # Load traditional artifact detection
            logger.info("Loading traditional artifact detection...")
            self.load_artifact_detection()
            
        except Exception as e:
            logger.error(f"Error loading AI detection models: {str(e)}")
            self.use_fallback_implementation()
    
    def load_sentry_image(self):
        """Load Sentry-Image model"""
        try:
            # Placeholder implementation for Sentry-Image
            # In production, this would load the actual Sentry-Image model
            self.models['sentry'] = self.create_mock_detection_model()
            self.processors['sentry'] = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
        except Exception as e:
            logger.warning(f"Could not load Sentry-Image: {str(e)}")
    
    def load_custom_ensemble(self):
        """Load custom ensemble detection model"""
        try:
            # Placeholder for custom ensemble
            self.models['ensemble'] = self.create_mock_detection_model()
            self.processors['ensemble'] = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
        except Exception as e:
            logger.warning(f"Could not load custom ensemble: {str(e)}")
    
    def load_artifact_detection(self):
        """Load traditional artifact detection methods"""
        try:
            # These would be implemented using opencv and scipy
            self.artifact_detection_available = True
        except Exception as e:
            logger.warning(f"Could not load artifact detection: {str(e)}")
            self.artifact_detection_available = False
    
    def create_mock_detection_model(self):
        """Create a mock detection model for demonstration"""
        class MockDetectionModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(64, 128, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(128, 64),
                    torch.nn.ReLU(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid()
                )
            
            def forward(self, x):
                return self.backbone(x)  # Returns probability 0-1
        
        model = MockDetectionModel().to(self.device)
        model.eval()
        return model
    
    def use_fallback_implementation(self):
        """Use simple fallback AI detection"""
        logger.info("Using fallback AI detection implementation")
        self.fallback_mode = True
    
    def evaluate_with_sentry(self, image: Image.Image) -> float:
        """Evaluate AI generation probability using Sentry-Image"""
        try:
            if 'sentry' not in self.models:
                return self.fallback_detection_score(image)
            
            # Preprocess image
            tensor = self.processors['sentry'](image).unsqueeze(0).to(self.device)
            
            # Get prediction
            with torch.no_grad():
                probability = self.models['sentry'](tensor).item()
            
            return max(0.0, min(1.0, probability))
            
        except Exception as e:
            logger.error(f"Error in Sentry evaluation: {str(e)}")
            return self.fallback_detection_score(image)
    
    def evaluate_with_ensemble(self, image: Image.Image) -> float:
        """Evaluate AI generation probability using custom ensemble"""
        try:
            if 'ensemble' not in self.models:
                return self.fallback_detection_score(image)
            
            # Preprocess image
            tensor = self.processors['ensemble'](image).unsqueeze(0).to(self.device)
            
            # Get prediction
            with torch.no_grad():
                probability = self.models['ensemble'](tensor).item()
            
            return max(0.0, min(1.0, probability))
            
        except Exception as e:
            logger.error(f"Error in ensemble evaluation: {str(e)}")
            return self.fallback_detection_score(image)
    
    def detect_compression_artifacts(self, image: Image.Image) -> float:
        """Score local intensity variance as a rough compression-artifact cue.

        The image is reduced to grayscale, the variance inside a sliding 8x8
        window is computed, and the mean of that variance map is scaled into
        [0, 1]; a mean variance of 1000 or more maps to 1.0. This is a
        simplified stand-in for a real DCT-based analysis.

        Args:
            image: PIL image to analyze. Modes other than ``RGB`` and ``L``
                are converted to RGB first so the array layout matches the
                RGB-to-gray conversion used below.

        Returns:
            Score in [0, 1] as a Python float; 0.5 when the analysis raises.
        """
        try:
            # Palette, alpha or CMYK layouts do not mean what the RGB-to-gray
            # step assumes; normalize them. Array-likes without a ``mode``
            # attribute pass through untouched.
            mode = getattr(image, 'mode', None)
            if mode is not None and mode not in ('RGB', 'L'):
                image = image.convert('RGB')

            # Convert to numpy array
            img_array = np.array(image)

            # Convert to grayscale
            if img_array.ndim == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            gray = gray.astype(np.float32)

            # Calculate local variance to detect blocking artifacts
            kernel = np.ones((8, 8), np.float32) / 64
            local_mean = cv2.filter2D(gray, -1, kernel)
            local_variance = cv2.filter2D((gray - local_mean) ** 2, -1, kernel)

            # High variance in 8x8 blocks might indicate JPEG artifacts
            block_variance = float(np.mean(local_variance))

            # Normalize to 0-1 probability
            return min(1.0, block_variance / 1000.0)

        except Exception as e:
            logger.error(f"Error in compression artifact detection: {str(e)}")
            return 0.5
    
    def detect_frequency_anomalies(self, image: Image.Image) -> float:
        """Detect frequency domain anomalies common in AI-generated images"""
        try:
            # Convert to numpy array and grayscale
            img_array = np.array(image)
            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            
            # Apply FFT
            f_transform = np.fft.fft2(gray)
            f_shift = np.fft.fftshift(f_transform)
            magnitude_spectrum = np.log(np.abs(f_shift) + 1)
            
            # Analyze frequency distribution
            # AI-generated images often have specific frequency patterns
            
            # Calculate radial frequency distribution
            h, w = magnitude_spectrum.shape
            center_y, center_x = h // 2, w // 2
            
            # Create radial mask
            y, x = np.ogrid[:h, :w]
            mask = (x - center_x) ** 2 + (y - center_y) ** 2
            
            # Calculate mean magnitude at different frequencies
            low_freq_mask = mask <= (min(h, w) // 8) ** 2
            high_freq_mask = mask >= (min(h, w) // 4) ** 2
            
            low_freq_energy = np.mean(magnitude_spectrum[low_freq_mask])
            high_freq_energy = np.mean(magnitude_spectrum[high_freq_mask])
            
            # AI images often have unusual low/high frequency ratios
            if high_freq_energy > 0:
                freq_ratio = low_freq_energy / high_freq_energy
                # Normalize to probability
                anomaly_probability = min(1.0, abs(freq_ratio - 10.0) / 20.0)
            else:
                anomaly_probability = 0.5
            
            return anomaly_probability
            
        except Exception as e:
            logger.error(f"Error in frequency analysis: {str(e)}")
            return 0.5
    
    def detect_pixel_patterns(self, image: Image.Image) -> float:
        """Detect suspicious pixel patterns common in AI-generated images"""
        try:
            img_array = np.array(image)
            
            # Check for perfect pixel repetitions (uncommon in natural images)
            if len(img_array.shape) == 3:
                # Flatten to check for repeated pixel values
                pixels = img_array.reshape(-1, 3)
                unique_pixels = np.unique(pixels, axis=0)
                
                # Calculate pixel diversity
                pixel_diversity = len(unique_pixels) / len(pixels)
                
                # Very low diversity might indicate AI generation
                if pixel_diversity < 0.1:
                    pattern_probability = 0.8
                elif pixel_diversity < 0.3:
                    pattern_probability = 0.6
                else:
                    pattern_probability = 0.2
            else:
                pattern_probability = 0.5
            
            # Check for unnatural smoothness
            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            
            # Calculate local standard deviation
            local_std = ndimage.generic_filter(gray.astype(np.float32), np.std, size=3)
            avg_local_std = np.mean(local_std)
            
            # Very smooth images might be AI-generated
            if avg_local_std < 5.0:
                smoothness_probability = 0.7
            elif avg_local_std < 15.0:
                smoothness_probability = 0.4
            else:
                smoothness_probability = 0.2
            
            # Combine pattern and smoothness indicators
            combined_probability = (pattern_probability + smoothness_probability) / 2
            
            return max(0.0, min(1.0, combined_probability))
            
        except Exception as e:
            logger.error(f"Error in pixel pattern detection: {str(e)}")
            return 0.5
    
    def analyze_metadata_indicators(self, image: Image.Image) -> float:
        """Analyze image metadata for AI generation indicators"""
        try:
            # Check image format and properties
            format_probability = 0.0
            
            # PNG format is more common for AI-generated images
            if image.format == 'PNG':
                format_probability += 0.3
            
            # Check for specific dimensions common in AI generation
            width, height = image.size
            
            # Common AI generation resolutions
            ai_resolutions = [
                (512, 512), (768, 768), (1024, 1024),  # Square formats
                (512, 768), (768, 512),  # 2:3 ratios
                (1024, 768), (768, 1024)  # 4:3 ratios
            ]
            
            if (width, height) in ai_resolutions:
                format_probability += 0.4
            
            # Check for perfect aspect ratios (less common in natural photos)
            aspect_ratio = width / height
            common_ai_ratios = [1.0, 1.5, 0.67, 1.33, 0.75, 1.25]
            
            for ratio in common_ai_ratios:
                if abs(aspect_ratio - ratio) < 0.01:
                    format_probability += 0.2
                    break
            
            return max(0.0, min(1.0, format_probability))
            
        except Exception as e:
            logger.error(f"Error in metadata analysis: {str(e)}")
            return 0.5
    
    def fallback_detection_score(self, image: Image.Image) -> float:
        """Simple fallback AI detection"""
        try:
            # Combine multiple simple heuristics
            scores = []
            
            # Compression artifacts
            artifact_score = self.detect_compression_artifacts(image)
            scores.append(artifact_score)
            
            # Frequency anomalies
            freq_score = self.detect_frequency_anomalies(image)
            scores.append(freq_score)
            
            # Pixel patterns
            pattern_score = self.detect_pixel_patterns(image)
            scores.append(pattern_score)
            
            # Metadata indicators
            metadata_score = self.analyze_metadata_indicators(image)
            scores.append(metadata_score)
            
            # Average the scores
            final_score = np.mean(scores)
            
            return max(0.0, min(1.0, final_score))
            
        except Exception:
            return 0.5  # Default neutral probability
    
    def evaluate(self, image: Image.Image) -> float:
        """
        Evaluate probability that image is AI-generated
        
        Args:
            image: PIL Image to evaluate
            
        Returns:
            AI generation probability from 0-1 (0 = likely real, 1 = likely AI)
        """
        try:
            scores = []
            
            # Sentry-Image evaluation (primary)
            sentry_score = self.evaluate_with_sentry(image)
            scores.append(sentry_score)
            
            # Custom ensemble evaluation (secondary)
            ensemble_score = self.evaluate_with_ensemble(image)
            scores.append(ensemble_score)
            
            # Traditional artifact detection
            artifact_score = self.fallback_detection_score(image)
            scores.append(artifact_score)
            
            # Ensemble scoring
            weights = [0.5, 0.3, 0.2]  # Sentry gets highest weight
            final_score = sum(score * weight for score, weight in zip(scores, weights))
            
            logger.info(f"AI detection scores - Sentry: {sentry_score:.3f}, "
                       f"Ensemble: {ensemble_score:.3f}, Artifacts: {artifact_score:.3f}, "
                       f"Final: {final_score:.3f}")
            
            return max(0.0, min(1.0, final_score))
            
        except Exception as e:
            logger.error(f"Error in AI detection evaluation: {str(e)}")
            return self.fallback_detection_score(image)