from io import BytesIO
from typing import Tuple

import panel as pn
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
from transformers.image_transforms import to_pil_image
DEVICE = "cpu"

pn.extension("mathjax", design="bootstrap", sizing_mode="stretch_width")
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
    processor = AutoImageProcessor.from_pretrained(processor_name)
    model = ResNetForImageClassification.from_pretrained(model_name).to(DEVICE)
    model.eval()  # inference mode; gradients w.r.t. the input still flow
    return processor, model
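# Note: pn.cache memoizes the loader, so the weights are fetched from the
# Hugging Face Hub once and reused across requests instead of being reloaded
# on every classification.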
def denormalize(image, mean, std):
    mean = torch.tensor(mean).view(1, -1, 1, 1)  # Reshape for channel-wise broadcasting
    std = torch.tensor(std).view(1, -1, 1, 1)
    return image * std + mean
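# Illustrative sketch (the mean/std values below are the standard ImageNet
# stats, shown only as an example): denormalize maps a (1, 3, H, W) normalized
# batch back to [0, 1] pixel space via channel-wise broadcasting:
#   x = denormalize(x_norm, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])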
# FGSM attack: x_adv = clamp(x + epsilon * sign(grad_x(loss)), 0, 1)
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Perturb each pixel one epsilon-step in the loss-increasing direction
    perturbed_image = image + epsilon * sign_data_grad
    # Clip to maintain the [0, 1] pixel range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image, detached from the graph
    return perturbed_image.detach()
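# Minimal self-check sketch (hypothetical helper, not called by the app): the
# L-infinity distance between the clean and perturbed images can never exceed
# epsilon, since clamping only shrinks the perturbation.
def _fgsm_sanity_check(epsilon: float = 0.05) -> bool:
    x = torch.rand(1, 3, 8, 8)  # fake image batch in [0, 1]
    grad = torch.randn_like(x)  # fake loss gradient
    x_adv = fgsm_attack(x, epsilon, grad)
    return bool((x_adv - x).abs().max().item() <= epsilon + 1e-6)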
def run_forward_backward(image: Image.Image, epsilon: float):
    processor, model = load_processor_model(
        "microsoft/resnet-18", "microsoft/resnet-18"
    )
    # Preprocess the input and track gradients w.r.t. the pixels
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    input_tensor.requires_grad_(True)
    # Run inference
    output = model(input_tensor).logits
    # Use the top predicted class as the attack target
    top_pred = output.argmax(dim=1)
    # Cross-entropy loss and backward pass to get the input gradient
    loss = F.cross_entropy(output, top_pred)
    model.zero_grad()
    loss.backward()
    # Denormalize the input back to [0, 1] pixel space
    mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
    std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
    input_tensor_denorm = denormalize(
        input_tensor.detach(), processor.image_mean, processor.image_std
    )
    # FGSM attack in pixel space
    adv_input_tensor_denorm = fgsm_attack(
        image=input_tensor_denorm, epsilon=epsilon, data_grad=input_tensor.grad.data
    )
    # Re-normalize the adversarial image back to the model's input range
    adv_input_tensor = (adv_input_tensor_denorm - mean) / std
    # Inference on the adversarial image
    adv_output = model(adv_input_tensor).logits
    return (
        output,
        adv_output,
        input_tensor_denorm.squeeze(),
        adv_input_tensor_denorm.squeeze(),
    )
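# Usage sketch: given a PIL image `img`,
#   clean_logits, adv_logits, x, x_adv = run_forward_backward(image=img, epsilon=0.05)
# returns the clean and adversarial logits plus both images as CHW tensors in
# [0, 1] pixel space, ready for `to_pil_image`.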
async def process_inputs(image_data: bytes, epsilon: float):
    """
    High-level function that takes the user inputs and yields the
    classification results as Panel objects.
    """
    try:
        main.disabled = True
        if not isinstance(image_data, bytes):
            yield "##### Upload an image to proceed"
            return
        yield "##### Fetching image and running the model..."
        try:
            # Open the image with PIL and force RGB (handles PNGs with alpha)
            pil_img = Image.open(BytesIO(image_data)).convert("RGB")
            # Run forward pass + FGSM attack
            clean_logits, adv_logits, input_tensor, adv_input_tensor = run_forward_backward(
                image=pil_img, epsilon=epsilon
            )
        except Exception as e:
            yield f"##### Something went wrong, please try a different image! \n {e}"
            return
        img = pn.pane.Image(
            to_pil_image(input_tensor, do_rescale=True),
            height=350,
            align="center",
        )
        # Convert the adversarial image for visualization
        adv_img = pn.pane.Image(
            to_pil_image(adv_input_tensor, do_rescale=True),
            height=350,
            align="center",
        )
        # Build the results column
        k_val = 5
        results = pn.Column(
            pn.Row("###### Uploaded", "###### Adversarial"),
            pn.Row(img, adv_img),
            f"###### Top {k_val} class predictions",
        )
        # Convert logits to class likelihoods
        likelihoods = [
            F.softmax(clean_logits, dim=1).squeeze(),
            F.softmax(adv_logits, dim=1).squeeze(),
        ]
        label_bars_rows = pn.Row()
        for likelihood_tensor in likelihoods:
            # Get the top-k probabilities and their class indices
            vals_topk, idx_topk = torch.topk(likelihood_tensor, k=k_val)
            label_bars = pn.Column()
            for idx, val in zip(idx_topk, vals_topk):
                prob = val.item()
                row_label = pn.widgets.StaticText(
                    name=f"{classes[idx]}",
                    value=f"{prob:.2%}",
                    align="center",
                )
                row_bar = pn.indicators.Progress(
                    value=int(prob * 100),
                    sizing_mode="stretch_width",
                    bar_color="success" if prob > 0.7 else "warning",  # Color by confidence
                    margin=(0, 10),
                    design=pn.theme.Material,
                )
                label_bars.append(pn.Column(row_label, row_bar))
            label_bars_rows.append(label_bars)
        results.append(label_bars_rows)
        yield results
    except Exception as e:
        yield f"##### Something went wrong! \n {e}"
        return
    finally:
        main.disabled = False
####################################################################################################################################

# Load the ImageNet class names (one label per line)
with open("classes.txt", "r") as file:
    classes = file.read().splitlines()
# Create widgets
############################################
# File upload widget
file_input = pn.widgets.FileInput(name="Upload a PNG or JPG image", accept=".png,.jpg")

# Epsilon slider
epsilon_slider = pn.widgets.FloatSlider(
    name=r"$$\epsilon$$", start=0, end=0.1, step=0.005, value=0.05, format="0[.]000"
)
############################################
# Organize widgets in a column
input_widgets = pn.Column(
    """
    ###### Classify an image with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
    Please be patient with the application; it is running on a low-resource device.
    """,
    file_input,
    epsilon_slider,
)

# Add interactivity: rerun whenever a new image is uploaded or epsilon changes
interactive_result = pn.panel(
    pn.bind(process_inputs, file_input.param.value, epsilon_slider.param.value),
    height=600,
)
footer = pn.pane.Markdown(
    """
    <br><br><br><br>
    Wondering where the class names come from? Find the full list [here](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/).
    """
)

# Create the dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
    footer,
)
title = "Adversarial Sample Generation"
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#101820",
).servable(title=title)
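# To run locally (assuming this file is saved as app.py):
#   panel serve app.py --show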