import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
import requests  # For fetching images from a URL
from io import BytesIO  # For wrapping downloaded bytes for PIL
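
# A minimal sketch of the requirements.txt these imports assume
# (package names only; pinned versions are an assumption, not from the source):
#   gradio, torch, transformers, pillow, numpy, captum, requests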

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    """Overlay a Grad-CAM heatmap for the target class on the original image."""
    # Hook the last depthwise convolution of the final ConvNeXt stage.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
    # Captum expects the forward function to return a plain tensor,
    # so wrap the Hugging Face model to return its logits.
    lgc = LayerGradCam(lambda pixel_values: model(pixel_values).logits, target_layer)
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
    # Upsample the coarse layer attributions to the original image size (H, W).
    upsampled = LayerAttribution.interpolate(attributions, original_image.size[::-1], interpolate_mode="bilinear")
    heatmap = np.transpose(upsampled.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
    visualized_image, _ = viz.visualize_image_attr(
        heatmap, np.array(original_image), method="blended_heat_map",
        sign="all", show_colorbar=True, title="Model Attention Heatmap",
    )
    return visualized_image
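
# Quick local check of the Grad-CAM overlay (hypothetical file name):
#   img = Image.open("sample.jpg").convert("RGB")
#   px = processor(images=img, return_tensors="pt")["pixel_values"]
#   overlay = generate_heatmap(px, img, target_class_index=0)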

# --- 3. MODIFIED Main Prediction Function ---
# Accepts two inputs: an uploaded image and a URL string.
def predict(image_upload: Image.Image, image_url: str):
    # Decide which input to use: the uploaded file takes priority.
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # Raise an exception for bad status codes
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        # If no input is provided, raise an error
        raise gr.Error("Please upload an image or provide a URL to analyze.")
    # Normalize to RGB (handles RGBA, palette, and grayscale inputs).
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]
    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Check whether these hotspots fall on unnatural-looking features such as hair, eyes, skin texture, or odd background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )
print("Generating heatmap...") | |
heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx) | |
print("Heatmap generated.") | |
labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))} | |
return labels_dict, explanation, heatmap_image | |
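
# Hypothetical direct call, bypassing the UI (useful for local testing):
#   labels, explanation, overlay = predict(Image.open("test.jpg"), "")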

# --- 4. MODIFIED Gradio Interface ---
# gr.Tabs creates separate input sections for upload vs. URL.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            # Tabs for the two input methods
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")
    # The click event passes both possible inputs to the predict function.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
# Examples are omitted for now: they don't work well with a tabbed interface
# by default, since an example would need to know which tab to populate.
# A commented sketch of one option follows below.
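# One sketch (assumption: example files sit next to app.py), wiring the
# examples only to the upload input, placed inside the Blocks context:
#   gr.Examples(
#       examples=["example_real.jpg", "example_ai.png"],
#       inputs=input_image_upload,
#   )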

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)