# Source: Hugging Face Space file by hugolb
# Commit 71012db ("Add files"), 2.21 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.keras import datasets, layers, models
# Rebuild the CNN architecture and load the trained weights into it.
# Presumably the layer stack must match the model that produced the weights
# file, so only the input declaration differs from the original code.
model = models.Sequential([
    # Explicit Input instead of passing the deprecated `input_shape=` kwarg to
    # the first Conv2D (Keras warns about that form). 32x32 RGB = CIFAR-10.
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),  # 10 classes in CIFAR-10; no activation -> raw logits
])
model.load_weights("cifar10_modified_flag.weights.h5")
# Label names indexed by CIFAR-10 class id (0-9). Class 3 is normally "cat";
# here its label is replaced by the flag string, so an image the model
# classifies as a cat is reported as the flag.
# NOTE(review): the flag is stored in plain text in this file, so anyone who
# can read the source can read it without solving the challenge.
class_mapping = dict(enumerate([
    "airplane",
    "automobile",
    "bird",
    "FLAG{3883}",  # class 3 ("cat" in standard CIFAR-10)
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]))
# Function to preprocess the input image
def preprocess_image(image):
image = image.resize((32, 32)) # Resize to CIFAR-10 size
image = np.array(image) / 255.0 # Normalize pixel values
image = np.expand_dims(image, axis=0) # Add batch dimension
return image
def predict(image):
    """Classify an uploaded image and return its label from ``class_mapping``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The label string mapped to the model's highest-scoring class.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of an AttributeError from preprocessing.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    batch = preprocess_image(image)
    # verbose=0: otherwise Keras prints a progress bar on every request.
    logits = model.predict(batch, verbose=0)
    # Softmax is monotonic, so the arg-max of the raw logits is already the
    # predicted class; int() turns NumPy's integer into a plain dict key.
    predicted_class = int(np.argmax(logits))
    return class_mapping[predicted_class]
# Gradio UI: one uploaded PIL image in, the predicted class label out.
iface = gr.Interface(
    fn=predict,  # Function to call for prediction
    inputs=gr.Image(type="pil", label="Upload an image from CIFAR-10"),  # Input: Image upload
    outputs=gr.Textbox(label="Predicted Class"),  # Output: Text showing predicted class
    title="Vault Challenge 2 - BIM",  # Title of the interface
    description="Upload an image, and the model will predict the class. Try to fool the model into predicting the FLAG using BIM!. Tips: tune the parameters to make the model predict the image as a cat (class 3).",
)

# Start the (blocking) server only when run as a script, so importing this
# module (e.g. from tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()