File size: 3,563 Bytes
41b7087 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms
# Load the TorchScript detector once, at import time.
# map_location="cpu" keeps the weights on the CPU even if the archive was saved
# from a GPU; process_image feeds CPU tensors and calls .numpy() on the output,
# so the model must live on the CPU as well.
model = torch.jit.load("SuSy.pt", map_location="cpu")
# Inference only: switch any train/eval-dependent layers to eval behaviour.
model.eval()
def _extract_patches(image, patch_size):
    """Tile *image* into non-overlapping square RGB patches.

    The grid is anchored at the top-left corner; a right or bottom margin
    narrower than ``patch_size`` is dropped.

    Args:
        image: PIL image of any mode; it is converted to RGB so every patch
            has exactly three channels.
        patch_size: Side length of each square patch, in pixels.

    Returns:
        ``uint8`` array of shape
        ``(num_patches_x * num_patches_y, patch_size, patch_size, 3)``.
        The patch whose top-left corner is at ``(i * patch_size, j * patch_size)``
        (grid column ``i``, grid row ``j``) is stored at index
        ``i * num_patches_y + j``.
    """
    # np.array on a PIL RGB image yields a (height, width, 3) uint8 copy.
    pixels = np.array(image.convert("RGB"))
    height, width = pixels.shape[:2]
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size
    # Keep only the region covered by whole patches, split each spatial axis
    # into (patch index, offset within patch), then move the column index to
    # the front so the flattened order is column-major over the grid.
    covered = pixels[: num_patches_y * patch_size, : num_patches_x * patch_size]
    tiles = covered.reshape(num_patches_y, patch_size, num_patches_x, patch_size, 3)
    tiles = tiles.transpose(2, 0, 1, 3, 4)
    return tiles.reshape(num_patches_x * num_patches_y, patch_size, patch_size, 3)


def _select_top_patches(patches, top_k):
    """Return the *top_k* patches with the highest GLCM contrast, highest first.

    Args:
        patches: ``uint8`` array of shape ``(N, H, W, 3)``.
        top_k: Number of patches to keep; all ``N`` are returned if ``N < top_k``.

    Returns:
        ``uint8`` array of shape ``(min(N, top_k), H, W, 3)``.
    """
    # Built once per call instead of once per patch.
    to_grayscale = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
    contrast_scores = []
    for patch in patches:
        # (1, H, W) uint8 tensor -> (H, W)
        gray = to_grayscale(Image.fromarray(patch)).squeeze(0)
        # Grey-level co-occurrence matrix of pixel pairs 5 px apart at angle 0,
        # 256 grey levels, symmetric and normalised.
        glcm = graycomatrix(gray, [5], [0], 256, symmetric=True, normed=True)
        contrast_scores.append(graycoprops(glcm, "contrast")[0, 0])
    # Patch indices ordered by descending contrast.
    ranking = np.argsort(contrast_scores)[::-1]
    return patches[ranking[:top_k]]


def process_image(image):
    """Score an image against the detector's six classes.

    The image is tiled into non-overlapping 224x224 patches. The 5 patches with
    the highest GLCM contrast are kept, the model is run on that batch, and its
    outputs are averaged over the patches.

    Args:
        image: PIL image, as delivered by ``gr.Image(type="pil")``; ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each class name to its patch-averaged score (a plain
        Python ``float``), ordered from highest to lowest score.

    Raises:
        gr.Error: If no image was given, or if the image is smaller than one
            patch in either dimension (there would be nothing to classify).
    """
    # Set parameters
    top_k_patches = 5
    patch_size = 224

    if image is None:
        raise gr.Error("Please upload an image to classify.")
    width, height = image.size
    # Without at least one full patch the model would receive an empty batch
    # and the averaged scores would be NaN.
    if width < patch_size or height < patch_size:
        raise gr.Error(
            f"The image is {width}x{height} pixels; it must be at least "
            f"{patch_size}x{patch_size} pixels to be analyzed."
        )

    patches = _extract_patches(image, patch_size)
    top_patches = _select_top_patches(patches, top_k_patches)
    # (N, H, W, C) uint8 -> (N, C, H, W) float in [0, 1].
    batch = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    model.eval()
    with torch.no_grad():
        preds = model(batch)

    # NOTE(review): this list is assumed to follow the model's output column
    # order, and the outputs are assumed to be per-class probabilities
    # (e.g. softmax). Confirm both against the training code.
    classes = ['Authentic', 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6', 'MJ V1/V2', 'Stable Diffusion XL']
    mean_probs = preds.mean(dim=0).numpy()
    # float() turns numpy scalars into plain Python floats so the result is
    # JSON-serializable.
    class_probs = {cls: float(prob) for cls, prob in zip(classes, mean_probs)}
    # Sort probabilities in descending order
    return dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))
# Define the Gradio interface: one PIL image in, a ranked label distribution out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    # Show every class process_image scores (its class list has six entries).
    outputs=gr.Label(num_top_classes=6),
    title="SuSy: Synthetic Image Detector",
    description="Upload an image or select an example to classify it into different categories.",
    # Relative paths; presumably shipped alongside this script — confirm.
    examples=[
        ["example_authentic.jpg"],
        ["example_dalle3.jpg"],
        ["example_mjv5.jpg"],
        ["example_sdxl.jpg"],
    ],
    article="""
<div style="text-align: center;">
<h3>About SuSy</h3>
<p>SuSy is an advanced synthetic image detector that can distinguish between authentic images and various types of AI-generated images. It analyzes patches of the input image to make its classification.</p>
<h4>Categories:</h4>
<ul style="list-style-type: none; padding: 0;">
<li>Authentic: Real, non-AI-generated images</li>
<li>DALL·E 3: Images generated by DALL-E 3</li>
<li>MJ V1/V2: Images generated by Midjourney versions 1 or 2</li>
<li>MJ V5/V6: Images generated by Midjourney versions 5 or 6</li>
<li>Stable Diffusion 1.x: Images generated by Stable Diffusion 1.x Models</li>
<li>Stable Diffusion XL: Images generated by Stable Diffusion XL</li>
</ul>
</div>
""",
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or another app) does not start a server.
if __name__ == "__main__":
    iface.launch()