# Hugging Face Space app (status: Running)
import html
import io
import urllib.parse
import urllib.request
import uuid
from functools import lru_cache

import gradio as gr
import pandas as pd
import torch
from numpy import exp
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Random identifier generated once, when this module is imported.
uid = uuid.uuid4()

# Hugging Face Hub ids of the detectors, one UI column each. aiornot() selects a
# model by its position in this list and treats index 2 (Organika) specially,
# so the order is significant.
models = [
    "cmckinle/sdxl-flux-detector",
    "umm-maybe/AI-image-detector",
    "Organika/sdxl-detector",
]

# Module-level accumulator of {"Real": p, "AI": p} dicts: aiornot() appends to
# it and clear_results() empties it. Being a module global, it is shared by
# every caller in this process.
results_store = []
def softmax(vector):
    """Return the softmax of ``vector``, normalised over *all* of its elements.

    ``vector`` is a non-empty array-like of logits (the caller passes a torch
    tensor; numpy arrays work too). The result has the same type and shape.
    Because every element is normalised together, pass the logits of a single
    image (a batch of one), not a multi-image batch.

    The maximum is subtracted before exponentiating. Softmax is invariant to
    adding a constant, and the shift keeps ``exp`` from overflowing to ``inf``,
    which would otherwise give ``inf / inf == NaN`` probabilities for large
    logits.
    """
    e = exp(vector - vector.max())
    return e / e.sum()
@lru_cache(maxsize=None)
def _load_detector(model_name):
    """Return ``(feature_extractor, model)`` for a Hub model id, loading it once.

    ``from_pretrained`` reads (and on first use downloads) the weights, which is
    far too slow to repeat on every click. The keys are limited to the entries
    of ``models``, so the unbounded cache cannot grow without limit.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    return feature_extractor, model


def aiornot(image, model_index):
    """Classify ``image`` as real or AI-generated with ``models[model_index]``.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or None when
            the input is empty.
        model_index: Position in ``models``. It may arrive as a float from
            ``gr.Number``.

    Returns:
        ``(html, results)``: an HTML summary string and a
        ``{"Real": p, "AI": p}`` dict of Python floats for ``gr.Label``. When
        no image is given, a short HTML notice and None are returned and
        nothing is stored.

    Side effects:
        Appends ``results`` to the module-level ``results_store``.

    Raises:
        ValueError: If ``model_index`` is outside ``range(len(models))``.
            Negative values would otherwise wrap around and select a model
            with the wrong label order.
    """
    model_index = int(model_index)  # gr.Number delivers floats
    if not 0 <= model_index < len(models):
        raise ValueError(f"model_index out of range: {model_index}")
    if image is None:
        return "<h3>No image provided</h3>", None

    feature_extractor, model = _load_detector(models[model_index])
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is classified, so the first row holds its class scores.
    probabilities = softmax(logits)[0]
    # float() gives plain Python floats, which serialise cleanly for gr.Label.
    first, second = float(probabilities[0]), float(probabilities[1])

    # NOTE(review): the class order is hard-coded per model. Confirm it against
    # each model's config.id2label.
    if model_index == 2:  # Organika model: class 0 is read as "Real"
        real_prob, ai_prob = first, second
        label = "Real" if real_prob > ai_prob else "AI"
    else:  # other detectors: class 0 is read as "AI"
        ai_prob, real_prob = first, second
        label = "AI" if ai_prob > real_prob else "Real"

    html_out = f"""
<h1>This image is likely: {label}</h1><br><h3>
Probabilities:<br>
Real: {real_prob:.4f}<br>
AI: {ai_prob:.4f}</h3>"""
    results = {"Real": real_prob, "AI": ai_prob}
    results_store.append(results)
    # Plain values work in every Gradio version; gr.HTML.update was removed in 4.x.
    return html_out, results
# Network limits for user-supplied image URLs.
_URL_TIMEOUT_SECONDS = 15
_MAX_IMAGE_BYTES = 20 * 1024 * 1024


def load_url(url):
    """Download the image at ``url`` and return it with an HTML status message.

    Only http(s) URLs are accepted. urllib would also open ``file://`` and
    similar schemes, which would let a visitor read files from the server. The
    body is read into memory with a timeout and a size cap, so concurrent
    sessions never share a temporary file on disk.

    NOTE(review): URLs pointing at internal http hosts are still reachable
    (SSRF). Add host filtering if this runs somewhere that matters.

    Returns:
        ``(image, message)``: a fully decoded PIL image and "Image Loaded", or
        ``None`` and an HTML error message with the exception text escaped.
    """
    try:
        url = (url or "").strip()
        scheme = urllib.parse.urlparse(url).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(
                f"unsupported URL scheme {scheme!r} (only http and https are allowed)"
            )
        with urllib.request.urlopen(url, timeout=_URL_TIMEOUT_SECONDS) as response:
            # Read one byte past the cap so an oversized body can be detected.
            data = response.read(_MAX_IMAGE_BYTES + 1)
        if len(data) > _MAX_IMAGE_BYTES:
            raise ValueError(f"image is larger than {_MAX_IMAGE_BYTES} bytes")
        image = Image.open(io.BytesIO(data))
        # Image.open is lazy. Decode now so corrupt data fails inside this try.
        image.load()
        mes = "Image Loaded"
    except Exception as e:  # UI boundary: report any failure to the user
        image = None
        mes = f"Image not Found<br>Error: {html.escape(str(e))}"
    return image, mes
def calculate_final_prob():
    """Summarise ``results_store`` as one Real/AI verdict for the final gr.Label.

    The "Real" score is the arithmetic mean of every stored "Real" probability.
    The "AI" score is its complement, which assumes each stored pair sums to 1.
    Both are returned as strings with four decimals. When nothing is stored
    yet, both are "N/A".
    """
    real_scores = [entry["Real"] for entry in results_store]
    if not real_scores:
        return {"Real": "N/A", "AI": "N/A"}
    mean_real = sum(real_scores) / len(real_scores)
    mean_ai = 1 - mean_real
    return {"Real": format(mean_real, ".4f"), "AI": format(mean_ai, ".4f")}
def clear_results():
    """Empty ``results_store`` and return blank values for one HTML and one Label output.

    Returns:
        ``("", None)``: an empty HTML string and a cleared Label. Plain values
        are returned instead of ``gr.HTML.update`` / ``gr.Label.update``,
        because Gradio accepts raw values in every version while those
        per-component ``update`` methods were removed in Gradio 4.
    """
    results_store.clear()
    return "", None
def _detect_all(image):
    """Run every detector in ``models`` on ``image`` and return all UI outputs.

    Clearing the store, running the models and averaging their results must
    happen in that order. Separate click listeners on one button have no
    guaranteed order, so the store could be emptied after results were added,
    or averaged before the models finished. One handler makes the sequence
    explicit.

    NOTE(review): ``results_store`` is module-global, so simultaneous users can
    still interleave their clears and appends. A per-session ``gr.State``
    would remove that.

    Returns:
        A tuple: the final-probability value for ``fin``, then ``(html, label)``
        for each model, in ``models`` order.
    """
    results_store.clear()
    per_model_outputs = []
    for index in range(len(models)):
        per_model_outputs.extend(aiornot(image, index))
    return (calculate_final_prob(), *per_model_outputs)


with gr.Blocks() as app:
    gr.Markdown("""<center><h1>AI Image Detector</h1><h4>(Test Demo - accuracy varies by model)</h4></center>""")
    with gr.Column():
        inp = gr.Image(type='pil')
        in_url = gr.Textbox(label="Image URL")
        with gr.Row():
            load_btn = gr.Button("Load URL")
            btn = gr.Button("Detect AI")
        mes = gr.HTML()
    with gr.Group():
        with gr.Row():
            fin = gr.Label(label="Final Probability")
        with gr.Row():
            # Flat list [html_0, label_0, html_1, label_1, ...] in models order,
            # matching the tuple that _detect_all returns after the final value.
            model_outputs = []
            for model_name in models:
                with gr.Column():
                    gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{model_name}'>{model_name}</a></b>""")
                    model_outputs.append(gr.HTML())
                    model_outputs.append(gr.Label(label="Output"))
    btn.click(_detect_all, inputs=inp, outputs=[fin, *model_outputs])
    load_btn.click(load_url, in_url, [inp, mes])

app.launch(show_api=False, max_threads=24)