ramananvr89 commited on
Commit
58adaae
·
verified ·
1 Parent(s): 3b87c7b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -0
app.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
4
+ import os
5
+ import zipfile
6
+ import shutil
7
+ import matplotlib.pyplot as plt
8
+ from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ import uuid
12
+ import tempfile
13
+ import pandas as pd
14
+ from numpy import exp
15
+ import numpy as np
16
+ from sklearn.metrics import ConfusionMatrixDisplay
17
+ import urllib.request
18
+
19
+ # Define models
20
+ models = [
21
+ "umm-maybe/AI-image-detector",
22
+ "Organika/sdxl-detector",
23
+ "cmckinle/sdxl-flux-detector",
24
+ ]
25
+
26
+ pipe0 = pipeline("image-classification", f"{models[0]}")
27
+ pipe1 = pipeline("image-classification", f"{models[1]}")
28
+ pipe2 = pipeline("image-classification", f"{models[2]}")
29
+
30
+ fin_sum = []
31
+ uid = uuid.uuid4()
32
+
33
+ # Softmax function
34
+ def softmax(vector):
35
+ e = exp(vector - vector.max()) # for numerical stability
36
+ return e / e.sum()
37
+
38
+ # Single image classification functions
39
+ def image_classifier0(image):
40
+ labels = ["AI", "Real"]
41
+ outputs = pipe0(image)
42
+ results = {}
43
+ for idx, result in enumerate(outputs):
44
+ results[labels[idx]] = float(outputs[idx]['score']) # Convert to float
45
+ fin_sum.append(results)
46
+ return results
47
+
48
+ def image_classifier1(image):
49
+ labels = ["AI", "Real"]
50
+ outputs = pipe1(image)
51
+ results = {}
52
+ for idx, result in enumerate(outputs):
53
+ results[labels[idx]] = float(outputs[idx]['score']) # Convert to float
54
+ fin_sum.append(results)
55
+ return results
56
+
57
+ def image_classifier2(image):
58
+ labels = ["AI", "Real"]
59
+ outputs = pipe2(image)
60
+ results = {}
61
+ for idx, result in enumerate(outputs):
62
+ results[labels[idx]] = float(outputs[idx]['score']) # Convert to float
63
+ fin_sum.append(results)
64
+ return results
65
+
66
+ def aiornot0(image):
67
+ labels = ["AI", "Real"]
68
+ mod = models[0]
69
+ feature_extractor0 = AutoFeatureExtractor.from_pretrained(mod)
70
+ model0 = AutoModelForImageClassification.from_pretrained(mod)
71
+ input = feature_extractor0(image, return_tensors="pt")
72
+ with torch.no_grad():
73
+ outputs = model0(**input)
74
+ logits = outputs.logits
75
+ probability = softmax(logits) # Apply softmax on logits
76
+ px = pd.DataFrame(probability.numpy())
77
+ prediction = logits.argmax(-1).item()
78
+ label = labels[prediction]
79
+
80
+ html_out = f"""
81
+ <h1>This image is likely: {label}</h1><br><h3>
82
+ Probabilities:<br>
83
+ Real: {float(px[1][0]):.4f}<br>
84
+ AI: {float(px[0][0]):.4f}"""
85
+
86
+ results = {
87
+ "Real": float(px[1][0]),
88
+ "AI": float(px[0][0])
89
+ }
90
+ fin_sum.append(results)
91
+ return gr.HTML.update(html_out), results
92
+
93
+ def aiornot1(image):
94
+ labels = ["AI", "Real"]
95
+ mod = models[1]
96
+ feature_extractor1 = AutoFeatureExtractor.from_pretrained(mod)
97
+ model1 = AutoModelForImageClassification.from_pretrained(mod)
98
+ input = feature_extractor1(image, return_tensors="pt")
99
+ with torch.no_grad():
100
+ outputs = model1(**input)
101
+ logits = outputs.logits
102
+ probability = softmax(logits) # Apply softmax on logits
103
+ px = pd.DataFrame(probability.numpy())
104
+ prediction = logits.argmax(-1).item()
105
+ label = labels[prediction]
106
+
107
+ html_out = f"""
108
+ <h1>This image is likely: {label}</h1><br><h3>
109
+ Probabilities:<br>
110
+ Real: {float(px[1][0]):.4f}<br>
111
+ AI: {float(px[0][0]):.4f}"""
112
+
113
+ results = {
114
+ "Real": float(px[1][0]),
115
+ "AI": float(px[0][0])
116
+ }
117
+ fin_sum.append(results)
118
+ return gr.HTML.update(html_out), results
119
+
120
+ def aiornot2(image):
121
+ labels = ["AI", "Real"]
122
+ mod = models[2]
123
+ feature_extractor2 = AutoFeatureExtractor.from_pretrained(mod)
124
+ model2 = AutoModelForImageClassification.from_pretrained(mod)
125
+ input = feature_extractor2(image, return_tensors="pt")
126
+ with torch.no_grad():
127
+ outputs = model2(**input)
128
+ logits = outputs.logits
129
+ probability = softmax(logits) # Apply softmax on logits
130
+ px = pd.DataFrame(probability.numpy())
131
+ prediction = logits.argmax(-1).item()
132
+ label = labels[prediction]
133
+
134
+ html_out = f"""
135
+ <h1>This image is likely: {label}</h1><br><h3>
136
+ Probabilities:<br>
137
+ Real: {float(px[1][0]):.4f}<br>
138
+ AI: {float(px[0][0]):.4f}"""
139
+
140
+ results = {
141
+ "Real": float(px[1][0]),
142
+ "AI": float(px[0][0])
143
+ }
144
+ fin_sum.append(results)
145
+ return gr.HTML.update(html_out), results
146
+
147
+ # Function to extract images from zip
148
+ def extract_zip(zip_file):
149
+ temp_dir = tempfile.mkdtemp() # Temporary directory
150
+ with zipfile.ZipFile(zip_file, 'r') as z:
151
+ z.extractall(temp_dir)
152
+ return temp_dir
153
+
154
+ # Function to classify images in a folder
155
+ def classify_images(image_dir, model_pipeline, model_idx):
156
+ images = []
157
+ labels = []
158
+ preds = []
159
+ for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
160
+ folder_path = os.path.join(image_dir, folder_name)
161
+ if not os.path.exists(folder_path):
162
+ print(f"Folder not found: {folder_path}")
163
+ continue
164
+ for img_name in os.listdir(folder_path):
165
+ img_path = os.path.join(folder_path, img_name)
166
+ try:
167
+ img = Image.open(img_path).convert("RGB")
168
+
169
+ # Ensure that each image is being processed by the correct model pipeline
170
+ pred = model_pipeline(img)
171
+ pred_label = 0 if pred[0]['label'] == 'AI' else 1 # Assuming 'AI' is label 0 and 'Real' is label 1
172
+
173
+ preds.append(pred_label)
174
+ labels.append(ground_truth_label)
175
+ images.append(img_name)
176
+ except Exception as e:
177
+ print(f"Error processing image {img_name} in model {model_idx}: {e}")
178
+
179
+ print(f"Model {model_idx} processed {len(images)} images")
180
+ return labels, preds, images
181
+
182
+ # Function to generate evaluation metrics
183
+ def evaluate_model(labels, preds):
184
+ cm = confusion_matrix(labels, preds)
185
+ accuracy = accuracy_score(labels, preds)
186
+ roc_score = roc_auc_score(labels, preds)
187
+ report = classification_report(labels, preds)
188
+ fpr, tpr, _ = roc_curve(labels, preds)
189
+ roc_auc = auc(fpr, tpr)
190
+
191
+ fig, ax = plt.subplots()
192
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
193
+ disp.plot(cmap=plt.cm.Blues, ax=ax)
194
+ plt.close(fig)
195
+
196
+ fig_roc, ax_roc = plt.subplots()
197
+ ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
198
+ ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
199
+ ax_roc.set_xlim([0.0, 1.0])
200
+ ax_roc.set_ylim([0.0, 1.05])
201
+ ax_roc.set_xlabel('False Positive Rate')
202
+ ax_roc.set_ylabel('True Positive Rate')
203
+ ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
204
+ ax_roc.legend(loc="lower right")
205
+ plt.close(fig_roc)
206
+
207
+ return accuracy, roc_score, report, fig, fig_roc
208
+
209
+ # Batch processing for all models
210
+ def process_zip(zip_file):
211
+ extracted_dir = extract_zip(zip_file.name)
212
+
213
+ # Run classification for each model
214
+ results = {}
215
+ for idx in range(len(models)):
216
+ print(f"Processing with model {models[idx]}") # Debugging to show which model is being used
217
+
218
+ # Create a new pipeline for each model within the loop
219
+ pipe = pipeline("image-classification", f"{models[idx]}")
220
+ print(f"Initialized pipeline for {models[idx]}") # Confirm pipeline is initialized correctly
221
+
222
+ # Classify images with the correct pipeline per model
223
+ labels, preds, images = classify_images(extracted_dir, pipe, idx)
224
+
225
+ # Debugging: Print the predictions to ensure they're different
226
+ print(f"Predictions for model {models[idx]}: {preds}")
227
+
228
+ accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
229
+
230
+ # Store results for each model
231
+ results[f'Model_{idx}_accuracy'] = accuracy
232
+ results[f'Model_{idx}_roc_score'] = roc_score
233
+ results[f'Model_{idx}_report'] = report
234
+ results[f'Model_{idx}_cm_fig'] = cm_fig
235
+ results[f'Model_{idx}_roc_fig'] = roc_fig
236
+
237
+ shutil.rmtree(extracted_dir) # Clean up extracted files
238
+
239
+ # Return results for all models
240
+ return (results['Model_0_accuracy'], results['Model_0_roc_score'], results['Model_0_report'],
241
+ results['Model_0_cm_fig'], results['Model_0_roc_fig'],
242
+ results['Model_1_accuracy'], results['Model_1_roc_score'], results['Model_1_report'],
243
+ results['Model_1_cm_fig'], results['Model_1_roc_fig'],
244
+ results['Model_2_accuracy'], results['Model_2_roc_score'], results['Model_2_report'],
245
+ results['Model_2_cm_fig'], results['Model_2_roc_fig'])
246
+
247
+
248
+
249
+
250
+ # Single image section
251
+ def load_url(url):
252
+ try:
253
+ urllib.request.urlretrieve(f'{url}', f"{uid}tmp_im.png")
254
+ image = Image.open(f"{uid}tmp_im.png")
255
+ mes = "Image Loaded"
256
+ except Exception as e:
257
+ image = None
258
+ mes = f"Image not Found<br>Error: {e}"
259
+ return image, mes
260
+
261
+ def tot_prob():
262
+ try:
263
+ fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
264
+ fin_sub = 1 - fin_out
265
+ out = {
266
+ "Real": f"{fin_out:.4f}",
267
+ "AI": f"{fin_sub:.4f}"
268
+ }
269
+ return out
270
+ except Exception as e:
271
+ print(e)
272
+ return None
273
+
274
+ def fin_clear():
275
+ fin_sum.clear()
276
+ return None
277
+
278
+ # Set up Gradio app
279
+ with gr.Blocks() as app:
280
+ gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
281
+
282
+ with gr.Tabs():
283
+ # Tab for single image detection
284
+ with gr.Tab("Single Image Detection"):
285
+ with gr.Column():
286
+ inp = gr.Image(type='pil')
287
+ in_url = gr.Textbox(label="Image URL")
288
+ with gr.Row():
289
+ load_btn = gr.Button("Load URL")
290
+ btn = gr.Button("Detect AI")
291
+ mes = gr.HTML("""""")
292
+
293
+ with gr.Group():
294
+ with gr.Row():
295
+ fin = gr.Label(label="Final Probability")
296
+ with gr.Row():
297
+ for i, model in enumerate(models):
298
+ with gr.Box():
299
+ gr.HTML(f"""<b>Testing on Model {i}: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
300
+ globals()[f'outp{i}'] = gr.HTML("""""")
301
+ globals()[f'n_out{i}'] = gr.Label(label="Output")
302
+
303
+ btn.click(fin_clear, None, fin, show_progress=False)
304
+ load_btn.click(load_url, in_url, [inp, mes])
305
+
306
+ btn.click(aiornot0, [inp], [outp0, n_out0]).then(
307
+ aiornot1, [inp], [outp1, n_out1]).then(
308
+ aiornot2, [inp], [outp2, n_out2]).then(
309
+ tot_prob, None, fin, show_progress=False)
310
+
311
+ # Tab for batch processing
312
+ with gr.Tab("Batch Image Processing"):
313
+ zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
314
+ batch_btn = gr.Button("Process Batch")
315
+
316
+ for i, model in enumerate(models):
317
+ with gr.Group():
318
+ gr.Markdown(f"### Results for {model}")
319
+ globals()[f'output_acc{i}'] = gr.Label(label=f"Model {i} Accuracy")
320
+ globals()[f'output_roc{i}'] = gr.Label(label=f"Model {i} ROC Score")
321
+ globals()[f'output_report{i}'] = gr.Textbox(label=f"Model {i} Classification Report", lines=10)
322
+ globals()[f'output_cm{i}'] = gr.Plot(label=f"Model {i} Confusion Matrix")
323
+ globals()[f'output_roc_plot{i}'] = gr.Plot(label=f"Model {i} ROC Curve")
324
+
325
+ # Connect batch processing
326
+ batch_btn.click(process_zip, zip_file,
327
+ [output_acc0, output_roc0, output_report0, output_cm0, output_roc_plot0,
328
+ output_acc1, output_roc1, output_report1, output_cm1, output_roc_plot1,
329
+ output_acc2, output_roc2, output_report2, output_cm2, output_roc_plot2])
330
+
331
+ app.launch(show_api=False, max_threads=24)