import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg

# --- 1. Load Model and Processor ---
# Load the pre-trained model and the image processor from Hugging Face.
# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()  # Set the model to evaluation mode
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
# This function generates the heatmap showing which parts of the image the model focused on.
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects a forward function that returns a plain tensor, while Hugging Face
    # models return a ModelOutput object, so we wrap the model to expose only the logits.
    def forward_logits(pixel_values):
        return model(pixel_values).logits

    # LayerGradCam requires a specific layer to hook into. For ConvNeXT models (like this one),
    # a good choice is the final layer of the last stage of the encoder.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv

    # Initialize LayerGradCam
    lgc = LayerGradCam(forward_logits, target_layer)

    # Generate attributions (the "importance" of each location in the target layer's feature map).
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # The attributions have the spatial resolution of the target layer, so we upsample them
    # to the original image size and reshape to (H, W, 1) for visualization.
    upsampled = LayerAttribution.interpolate(
        attributions, (original_image.size[1], original_image.size[0])
    )
    heatmap = np.transpose(upsampled.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    # Use Captum's visualization tool to overlay the heatmap on the original image.
    # use_pyplot=False keeps us off the global pyplot state, which is not thread-safe.
    fig, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )

    # Render the Matplotlib figure to a numpy array so gr.Image can display it.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    visualized_image = np.asarray(canvas.buffer_rgba())
    return visualized_image

# --- 3. Define the Main Prediction Function ---
# This function will be called by Gradio every time a user uploads an image.
def predict(input_image: Image.Image):
    print(f"Received image of size: {input_image.size}")

    # Convert image to RGB if it has an alpha channel
    if input_image.mode == 'RGBA':
        input_image = input_image.convert('RGB')

    # Preprocess the image for the model
    inputs = processor(images=input_image, return_tensors="pt")

    # Make a prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Get the predicted class index and the confidence score
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()

    # Get the label name (e.g., 'ai' or 'human')
    predicted_label = model.config.id2label[predicted_class_idx]

    # --- Generate a Human-Readable Explanation ---
    # The text states which verdict the model reached and tells the user how to read the heatmap.
    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention to whether these hotspots fall on unnatural-looking features "
            "such as hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    # --- Generate the Heatmap ---
    # We call our Grad-CAM function to create the visualization.
    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
    print("Heatmap generated.")

    # Return the classification labels, the text explanation, and the heatmap image.
    # The labels dictionary is for the gr.Label component.
    labels_dict = {
        model.config.id2label[i]: float(probabilities[0][i])
        for i in range(len(model.config.id2label))
    }
    return labels_dict, explanation, heatmap_image

# Create the example files for the demo. They are downloaded before the interface is
# built so that gr.Examples can find (and cache) them at startup.
import os
from urllib.request import urlretrieve

os.makedirs("examples", exist_ok=True)
urlretrieve("https://huggingface.co/Organika/sdxl-detector/resolve/main/ai.png", "examples/ai_example_1.png")
urlretrieve("https://huggingface.co/Organika/sdxl-detector/resolve/main/human.png", "examples/human_example_1.jpg")
urlretrieve("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png", "examples/ai_example_2.png")

# --- 4. Create the Gradio Interface ---
# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine if it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** to show *why* the model made its
        decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap]
    )

    gr.Examples(
        examples=[
            ["examples/ai_example_1.png"],
            ["examples/human_example_1.jpg"],
            ["examples/ai_example_2.png"],
        ],
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=True  # Speeds up demo loading
    )

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)  # debug=True lets you see print statements in the logs
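
# A minimal sketch of a local smoke test, kept commented out so it never runs alongside
# the Gradio app. It assumes an image file named "sample.jpg" exists next to this script
# (a hypothetical path, used only for illustration) and calls predict() directly,
# bypassing the web UI:
#
#     image = Image.open("sample.jpg").convert("RGB")
#     labels, explanation, heatmap = predict(image)
#     print(labels)
#     print(explanation)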