Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoModelForImageClassification, AutoImageProcessor | |
from PIL import Image | |
import numpy as np | |
from captum.attr import LayerGradCam | |
from captum.attr import visualization as viz | |
import requests | |
from io import BytesIO | |
import warnings | |
import os | |
# Silence library warnings so the Space logs only show this app's own messages.
warnings.filterwarnings("ignore")

# The app is meant to run on Hugging Face Spaces without a GPU: keep every
# tensor on the CPU and limit PyTorch to a single intra-op thread.
device = torch.device("cpu")
torch.set_num_threads(1)


# --- 1. Load Model and Processor ---
def _load_detector(repo_id):
    """Load the image processor and classification model published as ``repo_id``.

    The model is loaded in float32, moved to ``device`` and put in eval mode.
    A loading failure is printed and then re-raised, so the app refuses to start
    instead of running without a model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    try:
        image_processor = AutoImageProcessor.from_pretrained(repo_id)
        classifier = AutoModelForImageClassification.from_pretrained(
            repo_id,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        classifier.to(device)
        classifier.eval()
    except Exception as exc:
        print(f"Error loading model: {exc}")
        raise
    return image_processor, classifier


print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor, model = _load_detector(model_id)
print("Model and processor loaded successfully on CPU.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
def _find_cam_layer(net):
    """Return the backbone's final ``layernorm`` module, used as the Grad-CAM target layer."""
    # `base_model` resolves the task model's backbone (e.g. `model.swin`); fall back
    # to the model itself if the attribute is missing.
    backbone = getattr(net, "base_model", net)
    layer = getattr(backbone, "layernorm", None)
    if layer is None:
        raise AttributeError(
            f"{type(backbone).__name__} has no `layernorm` module to use as a Grad-CAM target"
        )
    return layer


def _tokens_to_grid(token_scores):
    """Reshape a 1-D tensor of per-token scores into a square 2-D grid.

    If the token count is not a perfect square, one leading token (e.g. a
    ViT-style [CLS] token) is dropped before trying again.

    Raises:
        ValueError: if the tokens still cannot form a square grid.
    """
    count = token_scores.numel()
    side = int(round(count ** 0.5))
    if side * side != count:
        token_scores = token_scores[1:]
        count -= 1
        side = int(round(count ** 0.5))
        if side * side != count:
            raise ValueError(f"cannot arrange {count + 1} tokens into a square grid")
    return token_scores.reshape(side, side)


def _grad_cam(image_tensor, target_class_index):
    """Compute a Grad-CAM map, min-max scaled to [0, 1], for ``target_class_index``.

    The target layer's activations are captured with a forward hook. The target
    logit is differentiated with respect to them, each channel is weighted by its
    mean gradient, and only the positive part is kept (ReLU), as in Grad-CAM.
    Only the first item of the batch is explained.

    This is computed by hand rather than with Captum's LayerGradCam because that
    class treats dim 1 as channels, while transformer layers typically emit
    (batch, tokens, channels).

    Returns:
        A 2-D float tensor on the CPU at the layer's (coarse) spatial resolution.
    """
    layer = _find_cam_layer(model)
    captured = {}

    def _save_activations(_module, _inputs, output):
        captured["acts"] = output[0] if isinstance(output, tuple) else output

    handle = layer.register_forward_hook(_save_activations)
    try:
        # Use a detached copy so the caller's tensor is not mutated. Force autograd on
        # in case this is called from inside `torch.no_grad()`.
        pixels = image_tensor.detach().to(device).requires_grad_(True)
        with torch.enable_grad():
            logits = model(pixel_values=pixels).logits
            (grads,) = torch.autograd.grad(logits[0, target_class_index], captured["acts"])
    finally:
        handle.remove()  # never leave the hook attached to the shared model

    acts = captured["acts"].detach()[0]
    grads = grads.detach()[0]
    if acts.dim() == 2:
        # Token sequence: (tokens, channels).
        channel_weights = grads.mean(dim=0)
        cam = _tokens_to_grid(torch.relu(acts @ channel_weights))
    elif acts.dim() == 3:
        # Feature map: (channels, height, width).
        channel_weights = grads.mean(dim=(1, 2))
        cam = torch.relu((channel_weights[:, None, None] * acts).sum(dim=0))
    else:
        raise ValueError(f"unsupported activation shape {tuple(captured['acts'].shape)}")

    cam = cam.float().cpu()
    cam_min, cam_max = cam.min(), cam.max()
    if cam_max > cam_min:
        return (cam - cam_min) / (cam_max - cam_min)
    return torch.zeros_like(cam)


def _jet_colormap(values):
    """Map values in [0, 1] to RGB floats in [0, 1] using a piecewise-linear 'jet' ramp.

    The ramp goes from dark blue at 0 through cyan, green and yellow to dark red at 1.
    """
    x = np.clip(np.asarray(values, dtype=np.float64), 0.0, 1.0)[..., None]
    # Triangular ramps centred at 0.75 (red), 0.5 (green) and 0.25 (blue).
    centres = np.array([3.0, 2.0, 1.0])
    return np.clip(1.5 - np.abs(4.0 * x - centres), 0.0, 1.0)


def _overlay_heatmap(base_rgb, cam, alpha=0.6):
    """Upsample ``cam`` to the image size, colour it with jet and alpha-blend it over the image.

    Args:
        base_rgb: uint8 array of shape (H, W, 3).
        cam: 2-D tensor of scores in [0, 1].
        alpha: weight of the heatmap in the blend; higher values make it more visible.

    Returns:
        A uint8 array of shape (H, W, 3).
    """
    height, width = base_rgb.shape[:2]
    # The grid is stretched over the whole image. This assumes the processor
    # resizes the full image (no center crop) to its input size — TODO confirm
    # against the processor config.
    upsampled = torch.nn.functional.interpolate(
        cam[None, None], size=(height, width), mode="bilinear", align_corners=False
    )[0, 0].numpy()
    coloured = _jet_colormap(upsampled)
    blended = (1.0 - alpha) * (base_rgb / 255.0) + alpha * coloured
    return np.clip(blended * 255.0, 0, 255).astype(np.uint8)


def generate_heatmap(image_tensor, original_image, target_class_index):
    """Return a Grad-CAM heatmap for ``target_class_index`` blended over ``original_image``.

    Args:
        image_tensor: the ``pixel_values`` tensor produced by ``processor``; only the
            first item of the batch is explained.
        original_image: the PIL image that was classified; the overlay is drawn at its size.
        target_class_index: index of the class whose logit is explained.

    Returns:
        A uint8 RGB array (H, W, 3) with the heatmap overlay. If anything fails, the
        error is printed and ``np.array(original_image)`` is returned unchanged. A
        fabricated heatmap is never shown.
    """
    try:
        # Any PIL mode (L, P, LA, RGBA, ...) is normalised to 3 channels for the blend.
        base_rgb = np.array(original_image.convert("RGB"))
        cam = _grad_cam(image_tensor, target_class_index)
        print(f"Grad-CAM grid shape: {tuple(cam.shape)}")
        return _overlay_heatmap(base_rgb, cam)
    except Exception as e:
        print(f"Complete heatmap generation failed: {e}")
        return np.array(original_image)
# --- 3. Main Prediction Function ---
def predict(image_upload: Image.Image, image_url: str):
    """Classify an image as AI-generated or human-made and explain the decision.

    An uploaded image takes priority. Otherwise the image is downloaded from
    ``image_url``.

    Args:
        image_upload: image from the upload widget, or ``None``.
        image_url: text from the URL tab; may be empty or whitespace.

    Returns:
        ``(labels_dict, explanation, heatmap_image)``: the probability of every label
        (for ``gr.Label``), an explanation string, and the heatmap from
        ``generate_heatmap``.

    Raises:
        gr.Error: when no input is given, the URL cannot be loaded, or inference
            fails. Messages already meant for the user are passed through unchanged.
    """
    try:
        # Determine input source: an uploaded file wins over a URL.
        if image_upload is not None:
            input_image = image_upload
            print(f"Processing uploaded image of size: {input_image.size}")
        elif image_url and image_url.strip():
            # Request the same stripped URL the emptiness check looked at, so pasted
            # trailing whitespace/newlines do not corrupt the request.
            url = image_url.strip()
            # NOTE(review): the server fetches an arbitrary user-supplied URL and reads
            # the whole body into memory (possible SSRF against internal hosts if this
            # runs outside an isolated sandbox) — consider restricting hosts/size.
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                input_image = Image.open(BytesIO(response.content))
                print(f"Processing image from URL: {url}")
            except Exception as e:
                raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}") from e
        else:
            raise gr.Error("Please upload an image or provide a URL to analyze.")

        # The processor and the heatmap overlay both expect 3-channel RGB. Convert
        # every other mode (RGBA, LA, L, P, CMYK, ...), not just RGBA.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        # Downscale large images (aspect ratio preserved) to save memory.
        max_size = 512
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        inputs = processor(images=input_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Plain inference needs no autograd graph; the heatmap step builds its own.
        with torch.no_grad():
            logits = model(**inputs).logits

        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_idx = logits.argmax(-1).item()
        confidence_score = probabilities[0][predicted_class_idx].item()
        predicted_label = model.config.id2label[predicted_class_idx]

        if predicted_label.lower() == 'artificial':
            explanation = (
                f"🤖 The model is {confidence_score:.2%} confident that this image is **AI-GENERATED**.\n\n"
                "The heatmap highlights areas that most influenced this decision. "
                "Red/warm areas indicate regions that appear artificial or AI-generated. "
                "Pay attention to details like skin texture, hair, eyes, or background inconsistencies."
            )
        else:
            explanation = (
                f"👤 The model is {confidence_score:.2%} confident that this image is **HUMAN-MADE**.\n\n"
                "The heatmap shows areas the model considers natural and realistic. "
                "Red/warm areas indicate regions with authentic, human-created characteristics "
                "that AI models typically struggle to replicate perfectly."
            )

        print("Generating heatmap...")
        heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
        print("Heatmap generated successfully.")

        # Probability for every label, in the format gr.Label expects.
        labels_dict = {
            model.config.id2label[i]: float(probabilities[0][i])
            for i in range(len(model.config.id2label))
        }
        return labels_dict, explanation, heatmap_image
    except gr.Error:
        # Already a user-facing message. The generic handler below would otherwise
        # re-wrap it as "An error occurred during prediction: ...".
        raise
    except Exception as e:
        print(f"Error in prediction: {e}")
        raise gr.Error(f"An error occurred during prediction: {str(e)}") from e
# --- 4. Gradio Interface ---
# Layout: a narrow input column (upload / URL tabs, analyze button, tips) next to
# a wider results column (label, explanation, heatmap).
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="AI Image Detector",
    css="""
.gradio-container {
    max-width: 1200px !important;
}
.tab-nav {
    margin-bottom: 1rem;
}
""",
) as demo:
    # Markdown needs blank lines between blocks. Without them the intro sentence
    # merges with "**Features:**", and "**How to use:**" is absorbed into the last
    # bullet as a lazy continuation line.
    gr.Markdown(
        """
# 🔍 AI Image Detector with Explainability

Determine if an image is AI-generated or human-made using advanced machine learning.

**Features:**

- 🎯 High-accuracy detection using the Organika/sdxl-detector model
- 🔥 **Heatmap visualization** showing which areas influenced the decision
- 📱 Support for both file uploads and URL inputs
- ⚡ Optimized for CPU deployment

**How to use:** Upload an image or paste a URL, then click "Analyze Image" to see the results and heatmap.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input")
            with gr.Tabs():
                with gr.TabItem("📁 Upload File"):
                    input_image_upload = gr.Image(
                        type="pil",
                        label="Upload Your Image",
                        height=300,
                    )
                with gr.TabItem("🔗 Use URL"):
                    input_image_url = gr.Textbox(
                        label="Paste Image URL here",
                        placeholder="https://example.com/image.jpg",
                    )
            submit_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg",
            )
            gr.Markdown(
                """
### ℹ️ Tips

- Supported formats: JPG, PNG, WebP
- Images are automatically resized for optimal processing
- For best results, use clear, high-quality images
"""
            )
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            with gr.Row():
                with gr.Column():
                    output_label = gr.Label(
                        label="Prediction Confidence",
                        num_top_classes=2,
                    )
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Detailed Explanation",
                        lines=6,
                        interactive=False,
                    )
            output_heatmap = gr.Image(
                label="🔥 AI Detection Heatmap - Red areas influenced the decision most",
                height=400,
            )

    # Wire the button to the prediction function; both inputs are always passed
    # and predict() decides which one to use.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
    )

    # Example rows fill the URL field (upload left empty); not pre-computed.
    gr.Examples(
        examples=[
            [None, "https://images.unsplash.com/photo-1494790108755-2616b612b786"],
            [None, "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d"],
        ],
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=False,
    )
# --- 5. Launch the App ---
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 (Gradio's conventional port),
    # without a public share link and without debug mode.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
    )
    demo.launch(**launch_options)