SuSy / app.py
pabberpe's picture
Add app.py and examples
41b7087
raw
history blame
3.56 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms
# Load the TorchScript SuSy model once, at import time.
# map_location="cpu" forces the weights onto the CPU regardless of the device
# the module was saved from. This matches process_image, which builds CPU
# input tensors and reads predictions back with .numpy() (CPU-only).
model = torch.jit.load("SuSy.pt", map_location="cpu")
# Inference only: disable training-mode behaviour once up front.
model.eval()
# Class labels, in the order of the model's output columns.
# NOTE(review): this order is carried over unchanged from the original app and
# must match the label order SuSy.pt was trained with — confirm against the
# model card.
_CLASSES = (
    "Authentic",
    "DALL·E 3",
    "Stable Diffusion 1.x",
    "MJ V5/V6",
    "MJ V1/V2",
    "Stable Diffusion XL",
)

# RGB PIL image -> (1, H, W) uint8 grayscale tensor. Built once instead of
# once per patch.
_TO_GRAYSCALE = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])


def _extract_patches(image, patch_size):
    """Tile an RGB PIL image into non-overlapping square patches.

    Patches are ordered column by column over the tile grid: tile column ``i``
    and tile row ``j`` land at index ``i * num_patches_y + j``. Pixels at the
    right/bottom edge that do not fill a whole patch are dropped.

    Args:
        image: RGB PIL image.
        patch_size: side length of each square patch, in pixels.

    Returns:
        uint8 array of shape ``(num_patches, patch_size, patch_size, 3)``.
    """
    width, height = image.size
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size
    pixels = np.asarray(image)[: num_patches_y * patch_size, : num_patches_x * patch_size]
    # (H, W, 3) -> (tile_row, y, tile_col, x, 3) -> (tile_col, tile_row, y, x, 3),
    # so flattening the first two axes yields index tile_col * num_patches_y + tile_row.
    tiles = pixels.reshape(num_patches_y, patch_size, num_patches_x, patch_size, 3)
    tiles = tiles.transpose(2, 0, 1, 3, 4)
    return np.ascontiguousarray(tiles.reshape(num_patches_x * num_patches_y, patch_size, patch_size, 3))


def _glcm_contrast(patch):
    """Return the GLCM contrast of one RGB uint8 patch.

    The patch is converted to grayscale. A symmetric, normalised grey-level
    co-occurrence matrix is built over 256 levels for pixel pairs at distance
    5 and angle 0. The "contrast" property is larger when paired pixels differ
    more in intensity.
    """
    gray = _TO_GRAYSCALE(Image.fromarray(patch)).squeeze(0).numpy()
    glcm = graycomatrix(gray, [5], [0], 256, symmetric=True, normed=True)
    return graycoprops(glcm, "contrast")[0, 0]


def process_image(image):
    """Classify an image as authentic or as the output of a specific generator.

    The image is cut into 224x224 tiles. The ``top_k_patches`` tiles with the
    highest GLCM contrast are fed to the SuSy model as one batch, and the
    model outputs are averaged over that batch.

    Args:
        image: PIL image from the Gradio input (any mode; converted to RGB),
            or None when the user submitted nothing.

    Returns:
        dict mapping each class label to its averaged score (Python float),
        ordered from highest to lowest score.

    Raises:
        gr.Error: if no image was given, or if the image is smaller than one
            patch in either dimension (no tile can be extracted).
    """
    top_k_patches = 5
    patch_size = 224

    if image is None:
        raise gr.Error("Please upload an image.")
    # The patch buffer and the model expect 3 channels; RGBA / L / P inputs
    # would otherwise fail when tiled. No-op for images that are already RGB.
    image = image.convert("RGB")
    width, height = image.size
    if width < patch_size or height < patch_size:
        # No full tile fits, so the model would be called on an empty batch.
        raise gr.Error(
            f"Image is {width}x{height} pixels; at least "
            f"{patch_size}x{patch_size} pixels are required."
        )

    patches = _extract_patches(image, patch_size)

    # Rank tiles by texture contrast. The GLCM property used is "contrast",
    # not "dissimilarity", despite what the original comment said.
    scores = [_glcm_contrast(patch) for patch in patches]
    top_indices = np.argsort(scores)[::-1][:top_k_patches]
    # NHWC uint8 -> NCHW float in [0, 1].
    batch = torch.from_numpy(np.transpose(patches[top_indices], (0, 3, 1, 2))) / 255.0

    model.eval()
    with torch.no_grad():
        preds = model(batch)

    # NOTE(review): outputs are averaged as-is; this assumes SuSy.pt already
    # returns per-class probabilities rather than logits — confirm.
    mean_scores = preds.mean(dim=0).numpy()
    # float() turns numpy scalars into plain Python floats for gr.Label.
    class_scores = {label: float(score) for label, score in zip(_CLASSES, mean_scores)}
    return dict(sorted(class_scores.items(), key=lambda item: item[1], reverse=True))
# Define the Gradio interface: one PIL image in, a label widget out that shows
# the class scores returned by process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    # Six slots: one per class returned by process_image.
    outputs=gr.Label(num_top_classes=6),
    title="SuSy: Synthetic Image Detector",
    description="Upload an image or select an example to classify it into different categories.",
    examples=[
        ["example_authentic.jpg"],
        ["example_dalle3.jpg"],
        ["example_mjv5.jpg"],
        ["example_sdxl.jpg"],
    ],
    article="""
<div style="text-align: center;">
<h3>About SuSy</h3>
<p>SuSy is an advanced synthetic image detector that can distinguish between authentic images and various types of AI-generated images. It analyzes patches of the input image to make its classification.</p>
<h4>Categories:</h4>
<ul style="list-style-type: none; padding: 0;">
<li>Authentic: Real, non-AI-generated images</li>
<li>DALL·E 3: Images generated by DALL-E 3</li>
<li>MJ V1/V2: Images generated by Midjourney versions 1 or 2</li>
<li>MJ V5/V6: Images generated by Midjourney versions 5 or 6</li>
<li>Stable Diffusion 1.x: Images generated by Stable Diffusion 1.x Models</li>
<li>Stable Diffusion XL: Images generated by Stable Diffusion XL</li>
</ul>
</div>
""",
)
# Launch the interface (starts the web server when the script runs).
iface.launch()