import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg
import requests
from io import BytesIO

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Grad-CAM targets the final depthwise convolution of the ConvNeXt backbone.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
    # Captum expects a forward function that returns a plain tensor, so wrap the
    # Hugging Face model (whose forward returns a ModelOutput) to expose the logits.
    lgc = LayerGradCam(lambda pixel_values: model(pixel_values).logits, target_layer)
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
    # Upsample the layer-resolution attribution map to the original image size
    # so it can be blended with the photo.
    attributions = LayerAttribution.interpolate(attributions, (original_image.height, original_image.width), interpolate_mode="bilinear")
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
    fig, _ = viz.visualize_image_attr(
        heatmap, np.array(original_image), method="blended_heat_map", sign="all",
        show_colorbar=True, title="Model Attention Heatmap", use_pyplot=False,
    )
    # Render the Matplotlib figure to an array that gr.Image can display.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())
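
# A minimal usage sketch (hypothetical; assumes a local file "test.png" exists):
#   img = Image.open("test.png").convert("RGB")
#   batch = processor(images=img, return_tensors="pt")
#   overlay = generate_heatmap(batch["pixel_values"], img, target_class_index=0)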

# --- 3. MODIFIED Main Prediction Function ---
# Now it accepts two inputs: an uploaded image and a URL string.
def predict(image_upload: Image.Image, image_url: str):
    # --- Logic to decide which input to use (an upload takes priority over a URL) ---
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # Raise an exception for bad status codes
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        # If no input is provided, raise an error
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # The model expects RGB input; convert RGBA, palette, or grayscale images.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]
    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots are on unnatural-looking features like hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
    print("Heatmap generated.")

    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
    return labels_dict, explanation, heatmap_image

# --- 4. MODIFIED Gradio Interface ---
# We use gr.Tabs to create separate input sections.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            # --- TABS for different input methods ---
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    # The click event now passes both possible inputs to the predict function
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
    # We remove the examples for now to simplify, as they don't work well with a tabbed interface by default.
    # If you want them back, you would need a more complex setup to handle which tab the example populates.
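    # A minimal sketch of how examples could be wired to the upload tab only
    # (the file paths below are hypothetical placeholders):
    # gr.Examples(
    #     examples=[["examples/ai_generated.png"], ["examples/photo.jpg"]],
    #     inputs=[input_image_upload],
    # )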

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)
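    # For a busier Space, Gradio's request queue can be enabled instead
    # (optional; sketch only):
    # demo.queue().launch(debug=True)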