from io import BytesIO
from typing import Tuple

import panel as pn
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
from transformers.image_transforms import to_pil_image

DEVICE = "cpu"

pn.extension("mathjax", design="bootstrap", sizing_mode="stretch_width")


# Cache the processor and model so they are loaded once, not on every callback
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
    processor = AutoImageProcessor.from_pretrained(processor_name)
    model = ResNetForImageClassification.from_pretrained(model_name)
    return processor, model


def denormalize(image, mean, std):
    mean = torch.tensor(mean).view(1, -1, 1, 1)  # Reshape for broadcasting
    std = torch.tensor(std).view(1, -1, 1, 1)
    return image * std + mean
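
# Preprocessing normalizes per channel as x_norm = (x - mean) / std, so
# multiplying by std and adding the mean recovers pixel values in [0, 1].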


# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon * sign_data_grad
    # Clip to maintain the [0, 1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image.detach()
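
# FGSM (Goodfellow et al., 2015, "Explaining and Harnessing Adversarial Examples"):
#   x_adv = clip(x + epsilon * sign(grad_x L(theta, x, y)), 0, 1)
# A single signed gradient step that maximizes a linear approximation of the
# loss under an L-infinity perturbation budget of epsilon.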


def run_forward_backward(image: Image.Image, epsilon: float):
    processor, model = load_processor_model(
        "microsoft/resnet-18", "microsoft/resnet-18"
    )
    # Preprocess the input and track gradients with respect to it
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    input_tensor.requires_grad_(True)
    # Run inference
    output = model(input_tensor).logits
    # Use the top prediction as the attack target
    top_pred = output.argmax(dim=1)
    # Compute the cross-entropy loss and backpropagate to the input
    loss = F.cross_entropy(output, top_pred)
    model.zero_grad()
    loss.backward()
    # Denormalize the input back to [0, 1] pixel space
    input_tensor_denorm = denormalize(
        input_tensor.detach(), processor.image_mean, processor.image_std
    )
    # FGSM attack
    adv_input_tensor_denorm = fgsm_attack(
        image=input_tensor_denorm, epsilon=epsilon, data_grad=input_tensor.grad.data
    )
    # Normalize the adversarial image back to the model's input range
    mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
    std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
    adv_input_tensor = (adv_input_tensor_denorm - mean) / std
    # Inference on the adversarial image
    adv_output = model(adv_input_tensor).logits
    return (
        output,
        adv_output,
        input_tensor_denorm.squeeze(),
        adv_input_tensor_denorm.squeeze(),
    )
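
# Minimal standalone usage sketch (assumption: a local file "sample.png" exists):
#   img = Image.open("sample.png").convert("RGB")
#   clean_logits, adv_logits, x, x_adv = run_forward_backward(image=img, epsilon=0.05)
#   print(clean_logits.argmax(dim=1), adv_logits.argmax(dim=1))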


async def process_inputs(button_event, image_data: bytes, epsilon: float):
    """
    High-level function that takes the user inputs and yields the
    classification results as Panel objects.
    """
    try:
        main.disabled = True
        if not isinstance(image_data, bytes):
            yield "##### Upload an image to proceed"
            return
        yield "##### Fetching image and running the model..."
        try:
            # Open the image using PIL, forcing three channels
            pil_img = Image.open(BytesIO(image_data)).convert("RGB")
            # Run the forward pass + FGSM
            clean_logits, adv_logits, input_tensor, adv_input_tensor = run_forward_backward(
                image=pil_img, epsilon=epsilon
            )
        except Exception as e:
            yield f"##### Something went wrong, please try a different image! \n {e}"
            return
        img = pn.pane.Image(
            to_pil_image(input_tensor, do_rescale=True),
            height=350,
            align="center",
        )
        # Convert the adversarial image for visualization
        adv_img = pn.pane.Image(
            to_pil_image(adv_input_tensor, do_rescale=True),
            height=350,
            align="center",
        )
        # Build the results column
        k_val = 5
        results = pn.Column(
            pn.Row("###### Uploaded", "###### Adversarial"),
            pn.Row(img, adv_img),
            f"###### Top {k_val} class predictions",
        )
        # Convert logits to class likelihoods (clean first, adversarial second)
        likelihoods = [
            F.softmax(clean_logits, dim=1).squeeze(),
            F.softmax(adv_logits, dim=1).squeeze(),
        ]
        label_bars_rows = pn.Row()
        for likelihood_tensor in likelihoods:
            # Get the top-k probabilities and class indices
            vals_topk, idx_topk = torch.topk(likelihood_tensor, k=k_val)
            label_bars = pn.Column()
            for idx, val in zip(idx_topk, vals_topk):
                prob = val.item()
                row_label = pn.widgets.StaticText(
                    name=f"{classes[idx]}",
                    value=f"{prob:.2%}",
                    align="center",
                )
                row_bar = pn.indicators.Progress(
                    value=int(prob * 100),
                    sizing_mode="stretch_width",
                    bar_color="success" if prob > 0.7 else "warning",  # Color by confidence
                    margin=(0, 10),
                    design=pn.theme.Material,
                )
                label_bars.append(pn.Column(row_label, row_bar))
            label_bars_rows.append(label_bars)
        results.append(label_bars_rows)
        yield results
    except Exception as e:
        yield f"##### Something went wrong! \n {e}"
        return
    finally:
        main.disabled = False


####################################################################################################
# Get classes
with open("classes.txt", "r") as file:
    classes = file.read().splitlines()
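
# Assumption: classes.txt ships alongside this script and holds the ImageNet-1k
# labels, one per line, index-aligned with the model's output logits.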

# Create widgets
############################################
# File upload widget
file_input = pn.widgets.FileInput(name="Upload a PNG or JPG image", accept=".png,.jpg")
# Epsilon slider
epsilon_slider = pn.widgets.FloatSlider(
    name=r"$$\epsilon$$", start=0, end=0.1, step=0.005, value=0.05, format="1[.]000"
)
# Upload button widget (bound as the first argument of process_inputs but not
# displayed; the callback is triggered directly by the file input and slider)
upload_image = pn.widgets.Button(name="Upload image", align="end")

############################################
# Organize widgets in a column
input_widgets = pn.Column(
    """
    ###### Classify an image with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
    Please be patient with the application; it is running on a low-resource device.
    """,
    file_input,
    epsilon_slider,
)

# Add interactivity
interactive_result = pn.panel(
    pn.bind(
        process_inputs, upload_image, file_input.param.value, epsilon_slider.param.value
    ),
    height=600,
)

footer = pn.pane.Markdown(
    """
    <br><br><br><br>
    Wondering where the class names come from? Find the full list [here](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/).
    """
)

# Create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
    footer,
)

title = "Adversarial Sample Generation"
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#101820",
).servable(title=title)
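
# To run the dashboard locally (assumption: this file is saved as app.py):
#   panel serve app.py --show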