|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from skimage.feature import graycomatrix, graycoprops |
|
from torchvision import transforms |
|
|
|
|
|
# TorchScript SuSy classifier, loaded once at import time.
# map_location="cpu": process_image feeds the model CPU tensors, and without
# an explicit map_location torch.jit.load restores tensors onto the device
# they were saved on -- a checkpoint serialized on a GPU would then fail to
# load on a CPU-only host (or land on the GPU and mismatch the CPU inputs).
model = torch.jit.load("SuSy.pt", map_location="cpu")
|
|
|
# Side length, in pixels, of the square patches fed to the model.
PATCH_SIZE = 224
# Number of highest-contrast patches whose model outputs are averaged.
TOP_K_PATCHES = 5
# Output labels, in the order of the model's output columns (order taken from
# the original code -- TODO confirm against the model card).
CLASSES = [
    "Authentic",
    "DALL·E 3",
    "Stable Diffusion 1.x",
    "MJ V5/V6",
    "MJ V1/V2",
    "Stable Diffusion XL",
]

# Built once and reused; works on batched (N, 3, H, W) uint8 tensors.
_GRAYSCALE = transforms.Grayscale()


def _extract_patches(image, patch_size):
    """Tile ``image`` into non-overlapping ``patch_size`` x ``patch_size`` squares.

    Partial tiles at the right and bottom edges are discarded. Patch ``k`` is
    the tile in column ``k // num_patches_y`` and row ``k % num_patches_y``
    (column-major order, matching the original crop loop).

    Returns:
        uint8 array of shape (num_patches, patch_size, patch_size, 3); the first
        dimension is 0 when the image is smaller than one patch.
    """
    # np.array (not asarray) so the result is a writable copy for torch.
    pixels = np.array(image.convert("RGB"))
    num_patches_y = pixels.shape[0] // patch_size
    num_patches_x = pixels.shape[1] // patch_size
    cropped = pixels[: num_patches_y * patch_size, : num_patches_x * patch_size]
    # (rows, row_px, cols, col_px, C) -> (cols, rows, row_px, col_px, C)
    grid = cropped.reshape(num_patches_y, patch_size, num_patches_x, patch_size, 3)
    return grid.transpose(2, 0, 1, 3, 4).reshape(-1, patch_size, patch_size, 3)


def _contrast_scores(patches):
    """Return the GLCM contrast of each RGB patch.

    The co-occurrence matrix uses a horizontal offset of 5 pixels (distance 5,
    angle 0) over 256 gray levels; larger values mean larger intensity
    differences between pixel pairs at that offset.
    """
    batch = torch.from_numpy(patches).permute(0, 3, 1, 2)  # (N, 3, H, W)
    gray = _GRAYSCALE(batch).squeeze(1).numpy()  # (N, H, W), uint8
    scores = []
    for patch in gray:
        glcm = graycomatrix(patch, [5], [0], 256, symmetric=True, normed=True)
        scores.append(graycoprops(glcm, "contrast")[0, 0])
    return np.array(scores)


def process_image(image):
    """Classify ``image`` as authentic or as produced by one of several generators.

    The image is tiled into PATCH_SIZE squares, the TOP_K_PATCHES patches with
    the highest GLCM contrast are scaled to [0, 1] and run through the
    module-level TorchScript ``model``, and the model outputs are averaged over
    those patches.

    Args:
        image: PIL image from the Gradio input (any mode; converted to RGB).

    Returns:
        dict mapping each name in CLASSES to its averaged model output
        (presumably a probability -- assumes the model ends in a softmax;
        TODO confirm), sorted from highest to lowest.

    Raises:
        gr.Error: if no image was provided or it is smaller than one patch.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    patches = _extract_patches(image, PATCH_SIZE)
    if len(patches) == 0:
        # Without this guard the model would receive an empty batch and the
        # mean over dim 0 would be NaN (or the model would raise).
        raise gr.Error(
            f"Image must be at least {PATCH_SIZE}x{PATCH_SIZE} pixels "
            f"(got {image.width}x{image.height})."
        )

    scores = _contrast_scores(patches)
    top_indices = np.argsort(scores)[::-1][:TOP_K_PATCHES]
    # NHWC uint8 -> NCHW float in [0, 1].
    top_patches = torch.from_numpy(np.transpose(patches[top_indices], (0, 3, 1, 2))) / 255.0

    model.eval()
    with torch.no_grad():
        preds = model(top_patches)

    mean_probs = preds.mean(dim=0).numpy()
    # float(): gr.Label expects plain Python numbers, not numpy float32.
    class_probs = {cls: float(prob) for cls, prob in zip(CLASSES, mean_probs)}
    return dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))
|
|
|
|
|
# Gradio UI: one PIL image in, a label widget with all six class scores out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=6),
    title="SuSy: Synthetic Image Detector",
    description=(
        "Upload an image or select an example to classify it into different "
        "categories. Images must be at least 224x224 pixels."
    ),
    examples=[
        ["example_authentic.jpg"],
        ["example_dalle3.jpg"],
        ["example_mjv5.jpg"],
        ["example_sdxl.jpg"],
    ],
    article="""
<div style="text-align: center;">
    <h3>About SuSy</h3>
    <p>SuSy is an advanced synthetic image detector that can distinguish between authentic images and various types of AI-generated images. It analyzes patches of the input image to make its classification.</p>
    <h4>Categories:</h4>
    <ul style="list-style-type: none; padding: 0;">
        <li>Authentic: Real, non-AI-generated images</li>
        <li>DALL·E 3: Images generated by DALL·E 3</li>
        <li>MJ V1/V2: Images generated by Midjourney versions 1 or 2</li>
        <li>MJ V5/V6: Images generated by Midjourney versions 5 or 6</li>
        <li>Stable Diffusion 1.x: Images generated by Stable Diffusion 1.x models</li>
        <li>Stable Diffusion XL: Images generated by Stable Diffusion XL</li>
    </ul>
</div>
""",
)

# Guarded so importing this module (e.g. from tests or another app) does not
# start a server; `python app.py` still launches it.
if __name__ == "__main__":
    iface.launch()