import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import os
from numpy import exp
import pandas as pd
from PIL import Image
import urllib.request
import uuid
# Random per-process identifier (a version-4 UUID).
uid = uuid.uuid4()

# Hugging Face image-classification checkpoints; the detection callbacks
# address them by position (``model_index``).
models = [
    "cmckinle/sdxl-flux-detector",
    "umm-maybe/AI-image-detector",
    "Organika/sdxl-detector",
]

# Per-model ``{"AI": p, "Real": p}`` results: appended by aiornot, averaged by
# tot_prob and emptied by fin_clear.
# NOTE(review): this list lives at module level, so every browser session of
# the app shares (and can clobber) the same results -- consider gr.State.
fin_sum = list()
def softmax(vector):
    """Return the softmax of *vector* along its last axis.

    ``torch.Tensor`` inputs are computed with ``torch.softmax`` and returned as
    tensors, so callers that expect a tensor back keep working. Anything else
    is converted with ``numpy.asarray`` and an ndarray is returned.

    The maximum is subtracted before exponentiating, so large logits do not
    overflow to ``inf`` and turn the result into ``nan``.
    """
    import numpy as np

    if isinstance(vector, torch.Tensor):
        # Integer/bool tensors are promoted to float64, as numpy.exp would do.
        if not vector.is_floating_point():
            vector = vector.double()
        return torch.softmax(vector, dim=-1)

    values = np.asarray(vector)
    if values.size == 0:
        # Nothing to normalise; keep the old "empty in, empty out" behaviour.
        return exp(values)
    # Reduce over the last axis; a 0-d input has no axes, so reduce over all.
    axis = -1 if values.ndim else None
    shifted = exp(values - values.max(axis=axis, keepdims=True))
    return shifted / shifted.sum(axis=axis, keepdims=True)
# Loaded (feature_extractor, model) pairs keyed by checkpoint name, so each
# checkpoint is built with from_pretrained once instead of on every click.
# Gradio runs handlers in worker threads; two simultaneous first calls for the
# same checkpoint may both load it, which only wastes time (last write wins).
_DETECTOR_CACHE = {}


def _load_detector(model_name):
    """Return the cached ``(feature_extractor, model)`` pair for *model_name*."""
    detector = _DETECTOR_CACHE.get(model_name)
    if detector is None:
        detector = (
            AutoFeatureExtractor.from_pretrained(model_name),
            AutoModelForImageClassification.from_pretrained(model_name),
        )
        _DETECTOR_CACHE[model_name] = detector
    return detector


def aiornot(image, model_index):
    """Classify *image* with ``models[model_index]`` and record the result.

    Args:
        image: The image to classify. Presumably a PIL image, since the UI
            feeds it from ``gr.Image(type='pil')``. ``None`` when the user
            pressed "Detect AI" without an image.
        model_index: Position of the checkpoint in the module-level ``models``.

    Returns:
        ``(html, results)``:
        - ``html`` is an HTML summary string suitable for a ``gr.HTML`` output.
        - ``results`` maps ``"AI"``/``"Real"`` to the model's softmax
          probabilities. It is empty when no image was given.

    Side effect:
        Appends ``results`` to the module-level ``fin_sum`` list (only when
        an image was actually classified).
    """
    # NOTE(review): assumes every checkpoint uses index 0 = AI, index 1 = Real;
    # confirm against each model's config.id2label.
    labels = ["AI", "Real"]
    if image is None:
        return "No image to analyze.", {}

    feature_extractor, model = _load_detector(models[model_index])
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is a batch of one: use the first (only) row.
    probabilities = softmax(logits)[0].tolist()
    label = labels[int(logits.argmax(-1)[0])]
    results = dict(zip(labels, probabilities))

    # gr.HTML collapses plain newlines, so lines are separated with <br>.
    html_out = (
        f"This image is likely: {label}<br>"
        "Probabilities:<br>"
        f"Real: {results['Real']:.2%}<br>"
        f"AI: {results['AI']:.2%}"
    )
    fin_sum.append(results)
    return html_out, results
def load_url(url):
    """Download the image at *url* for display in the image input.

    The body is fetched into memory (no shared temp file, so concurrent users
    cannot overwrite each other's downloads).

    Safety limits:
    - ``file://`` and other non-network schemes are rejected, so users cannot
      read files from the server's disk.
    - The request times out after 30 seconds.
    - Bodies over 20 MiB are refused.

    NOTE(review): plain http(s) URLs can still reach internal hosts (SSRF);
    add a host allow/deny list if the app is exposed publicly.

    Returns:
        ``(image, message)``:
        - On success, a fully decoded PIL image and ``"Image Loaded"``.
        - Otherwise ``None`` and an HTML error message. The exception text is
          HTML-escaped because it can echo the user-supplied URL.
    """
    import html
    import io
    from urllib.parse import urlsplit

    allowed_schemes = ("http", "https", "ftp", "data")
    max_bytes = 20 * 1024 * 1024
    try:
        url = (url or "").strip()
        if urlsplit(url).scheme.lower() not in allowed_schemes:
            raise ValueError(f"unsupported URL {url!r}; expected an http(s) link")
        with urllib.request.urlopen(url, timeout=30) as response:
            data = response.read(max_bytes + 1)
        if len(data) > max_bytes:
            raise ValueError(f"image is larger than {max_bytes} bytes")
        image = Image.open(io.BytesIO(data))
        # Image.open is lazy; decode now so corrupt data fails inside the try.
        image.load()
        mes = "Image Loaded"
    except Exception as e:  # UI boundary: report any failure to the user
        image = None
        mes = f"Image not Found<br>Error: {html.escape(str(e))}"
    return image, mes
def tot_prob(results=None):
    """Average the ``"Real"`` probability across per-model results.

    Args:
        results: Optional list of ``{"AI": p, "Real": p}`` dicts. Defaults to
            the module-level ``fin_sum``. Entries without a ``"Real"`` key are
            ignored.

    Returns:
        ``{"Real": mean_real, "AI": 1 - mean_real}`` as plain floats, which is
        the numeric-confidence form ``gr.Label`` expects. Returns ``None`` when
        there is nothing to average.
    """
    if results is None:
        results = fin_sum
    # Take one snapshot so the sum and the count come from the same entries,
    # even if another worker thread appends to or clears the list meanwhile.
    real_scores = [float(entry["Real"]) for entry in list(results) if "Real" in entry]
    if not real_scores:
        return None
    real = sum(real_scores) / len(real_scores)
    return {"Real": real, "AI": 1 - real}
def fin_clear():
    """Empty the shared per-model results list in place.

    Returns ``None`` so the output component it is wired to is blanked.
    """
    del fin_sum[:]
    return None
def _detector_for(model_index):
    """Return a Gradio callback that runs ``aiornot`` with *model_index*.

    The callback keeps only the HTML part of ``aiornot``'s ``(html, results)``
    pair, because each detection event is wired to a single ``gr.HTML``
    output. Passing the whole tuple to one output renders garbage.
    """

    def run(img):
        return aiornot(img, model_index)[0]

    return run


with gr.Blocks() as app:
    gr.Markdown("AI Image Detector\n(Test Demo - accuracy varies by model)\n")
    inp = gr.Image(type='pil')
    in_url = gr.Textbox(label="Image URL")
    load_btn = gr.Button("Load URL")
    btn = gr.Button("Detect AI")
    mes = gr.HTML("")
    fin = gr.Label(label="Final Probability")
    outp0 = gr.HTML("")
    outp1 = gr.HTML("")
    outp2 = gr.HTML("")

    load_btn.click(load_url, in_url, [inp, mes])

    # Clear the shared results first, then run every detector in parallel.
    # Chaining with .then() guarantees the clear cannot land after a detector
    # has already appended its result.
    clear_event = btn.click(fin_clear, None, fin, show_progress=False)
    for index, html_box in enumerate((outp0, outp1, outp2)):
        clear_event.then(_detector_for(index), inp, [html_box]).then(
            tot_prob, None, fin, show_progress=False
        )

app.launch(show_api=False, max_threads=24)