import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerAttribution, LayerGradCam
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg
import requests  # <-- new: fetch images from a URL
from io import BytesIO  # <-- new: wrap downloaded bytes for PIL

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects a forward function that returns a plain tensor, so we
    # unwrap the Hugging Face ModelOutput and hand back only the logits.
    def forward_fn(pixel_values):
        return model(pixel_values).logits

    # Target the depthwise conv of the last ConvNeXt layer for Grad-CAM.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
    lgc = LayerGradCam(forward_fn, target_layer)
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # The attribution map has the spatial size of the target layer, so
    # upsample it to the original image size before overlaying.
    height, width = np.array(original_image).shape[:2]
    attributions = LayerAttribution.interpolate(attributions, (height, width), interpolate_mode="bilinear")
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    visualized_figure, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )
    # Render the Matplotlib figure to an RGBA array that gr.Image can display.
    canvas = FigureCanvasAgg(visualized_figure)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())

# --- 3. MODIFIED Main Prediction Function ---
# Now it accepts two inputs: an uploaded image and a URL string.
def predict(image_upload: Image.Image, image_url: str):
    # Prefer the uploaded file; fall back to the URL if no file is given.
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # Raise an exception for bad status codes
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        # If no input is provided, raise an error
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # Normalize any mode (RGBA, grayscale, palette) to RGB for the processor.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]

    if predicted_label.lower() == "ai":
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots fall on unnatural-looking features such as hair, "
            "eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate."
        )

    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs["pixel_values"], input_image, predicted_class_idx)
    print("Heatmap generated.")

    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}

    return labels_dict, explanation, heatmap_image
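# Optional: a quick sanity check of predict() without launching the UI.
# This is a minimal sketch, not part of the original app; "test.jpg" is a
# hypothetical local file path. Uncomment to try it:
#
#     labels, explanation, heatmap = predict(Image.open("test.jpg"), "")
#     print(labels)
#     print(explanation)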
# --- 4. MODIFIED Gradio Interface ---
# We use gr.Tabs to create separate input sections.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine whether an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model
        and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            # --- TABS for different input methods ---
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    # The click event passes both possible inputs to the predict function.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
    )

    # Examples are omitted for now: by default they don't work well with a
    # tabbed interface, since gr.Examples needs to know which tab an example
    # should populate. See the commented sketch after the launch block for
    # one way to restore them.

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)
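# One possible way to restore examples (a sketch, not from the original app):
# gr.Examples can target just the upload tab's image input. Place this inside
# the `with gr.Blocks(...)` context, e.g. right after the click handler; the
# example file paths below are hypothetical:
#
#     gr.Examples(
#         examples=["examples/ai_generated.png", "examples/real_photo.jpg"],
#         inputs=input_image_upload,
#     )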