Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from PIL import Image | |
import torch.nn.functional as F | |
# Inference is pinned to the CPU; the checkpoint is remapped onto this device
# when it is loaded further down.
device = torch.device("cpu")
class VGGBlock(nn.Module):
    """Two 3x3 convolutions followed by a 2x2 max-pool.

    Each convolution uses stride 1 and padding 1, so it keeps the spatial
    size; the closing max-pool (kernel 2, stride 2) halves width and height.
    Every convolution is followed by an optional batch normalisation and a
    ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        batch_norm: When true, a ``BatchNorm2d`` layer is inserted after
            each convolution; otherwise a pass-through layer is used.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        conv2_params = {
            'kernel_size': (3, 3),
            'stride': (1, 1),
            'padding': 1,
        }
        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        # nn.Identity is a parameter-free module, so the state_dict keys are
        # the same as with a plain pass-through function, but the block stays
        # a proper module tree and remains picklable (a local lambda is not).
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def batch_norm(self):
        """Return the ``batch_norm`` flag this block was constructed with."""
        return self._batch_norm

    def forward(self, x):
        """Apply conv -> (bn) -> ReLU twice, then max-pool, and return the result."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)
        return x
class VGG16(nn.Module):
    """VGG-style classifier: four ``VGGBlock`` stages and a 3-layer MLP head.

    The stages widen the feature map to 64, 128, 256 and 512 channels; each
    stage halves the spatial size. The flattened result is fed to a
    classifier of three linear layers with ReLU and dropout (p=0.65) in
    between.

    Args:
        input_size: ``(channels, width, height)`` of a single input sample.
        num_classes: Size of the final output layer.
        batch_norm: Forwarded to every ``VGGBlock``.
    """

    def __init__(self, input_size, num_classes=10, batch_norm=False):
        super().__init__()
        self.in_channels, self.in_width, self.in_height = input_size

        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)

        # Four stride-2 poolings (each flooring odd sizes) shrink every
        # spatial dimension by a factor of 16. Deriving the flattened size
        # from input_size instead of hard-coding it keeps the head consistent
        # with the constructor argument; for 32x32 inputs this is still
        # 512 * 2 * 2 = 2048, so existing checkpoints load unchanged.
        flat_features = 512 * (self.in_width // 16) * (self.in_height // 16)

        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, num_classes),
        )

    def input_size(self):
        """Return ``(channels, width, height)`` as passed to the constructor."""
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        """Run the four blocks, flatten per sample, and return the class scores."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # Collapse everything except the batch dimension.
        x = x.flatten(start_dim=1)
        x = self.classifier(x)
        return x
# Single-channel 32x32 network with batch normalisation, placed on `device`.
model = VGG16((1, 32, 32), batch_norm=True)
model.to(device)

# Restore the trained weights, remapping stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so 'model.pth' must come from
# a trusted source.
model.load_state_dict(
    torch.load('model.pth', map_location=device)
)

# Class index -> display name. Index 8 carries the flag string; the interface
# description refers to this class as BAG.
label_map = dict(enumerate((
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'FLAG{3883}',
    'Ankle boot',
)))
def predict_from_local_image(image: str):
    """Classify the image stored at path ``image`` with the module-level model.

    The file is opened, converted to 8-bit grayscale, resized to 32x32,
    turned into a tensor and passed through ``model`` in evaluation mode.

    Args:
        image: Filesystem path of the image to classify.

    Returns:
        A ``(predicted_class, predicted_confidence)`` tuple: the
        ``label_map`` entry of the highest-scoring class and its softmax
        probability expressed as a percentage.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image was provided.")

    # Define the transformation to match the model's input requirements
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize to the input size of the model
        transforms.ToTensor(),  # Convert the image to a tensor
    ])

    # Open the file and convert it to grayscale. convert() returns a new,
    # fully loaded image, so the file handle can be released right away
    # instead of staying open until garbage collection.
    with Image.open(image) as source:
        grayscale = source.convert('L')

    # Add the batch dimension and move the tensor to the inference device
    batch = transform(grayscale).unsqueeze(0).to(device)

    # Set the model to evaluation mode
    model.eval()

    # Make a prediction
    with torch.no_grad():
        output = model(batch)
        _, predicted_label = torch.max(output, 1)
        confidence = torch.nn.functional.softmax(output, dim=1)[0] * 100

    # Get the predicted class label and confidence
    predicted_index = predicted_label.item()
    predicted_class = label_map[predicted_index]
    predicted_confidence = confidence[predicted_index].item()
    return predicted_class, predicted_confidence
# Gradio interface
# NOTE(review): the description asks for a .pt upload, but the input component
# below accepts an image file path -- confirm which one is intended.
iface = gr.Interface(
    fn=predict_from_local_image,  # Function to call for prediction
    inputs=gr.Image(type='filepath', label="Upload an image"),  # Input: path of the uploaded image file
    # The prediction function returns (class, confidence); one component per
    # returned value, so the confidence is no longer shown as part of a
    # stringified tuple in the class textbox.
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Number(label="Confidence (%)"),
    ],
    title="Vault Challenge 4 - DeepFool",  # Title of the interface
    description="Upload an image, and the model will predict the class. Try to fool the model into predicting the FLAG using DeepFool! Tips: apply DeepFool attack on the image to make the model predict it as a BAG. Note that you should save the adverserial image as a .pt file and upload it to the model to get the FLAG."
)

# Launch the Gradio interface
iface.launch()