# Import the necessary libraries and modules
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from transformers import ViTImageProcessor, ViTModel


def visualize_attention(name):
    # The input is either "<model_name>" or "<model_name>;<image_url>"
    model_name = name.split(";")[0].strip()
    if len(name.split(";")) > 1:
        url = name.split(";")[1].strip()
    else:
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    feature_extractor = ViTImageProcessor.from_pretrained(model_name, size=480)
    pil_image = Image.open(requests.get(url, stream=True).raw)

    device = "cpu"
    pixel_values = feature_extractor(images=pil_image, return_tensors="pt").pixel_values.to(device)

    model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
    model.to(device)

    outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)

    # We are only interested in the attention maps of the last layer
    attentions = outputs.attentions[-1]
    nh = attentions.shape[1]  # number of heads

    # Keep only the attention from the [CLS] token to the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    threshold = 0.6
    h_featmap = pixel_values.shape[-2] // model.config.patch_size
    w_featmap = pixel_values.shape[-1] // model.config.patch_size

    # Keep only a certain percentage of the attention mass per head
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - threshold)
    idx2 = torch.argsort(idx)
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, h_featmap, w_featmap).float()

    # Upsample the thresholded attention back to the input resolution
    th_attn = nn.functional.interpolate(
        th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest"
    )[0].cpu().numpy()

    # Upsample the raw attention maps as well
    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest"
    )[0].cpu()
    attentions = attentions.detach().numpy()

    # Save the input image and the attention heatmaps
    output_dir = "."
    os.makedirs(output_dir, exist_ok=True)
    torchvision.utils.save_image(
        torchvision.utils.make_grid(pixel_values, normalize=True, scale_each=True),
        os.path.join(output_dir, "img.png"),
    )
    for j in range(nh):
        fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
        plt.figure()
        plt.imshow(attentions[j])
        plt.imsave(fname=fname, arr=attentions[j], format="png")

    # Collect one heatmap per attention head for the gallery
    images = []
    for j in range(nh):
        images.append(Image.open(os.path.join(output_dir, "attn-head" + str(j) + ".png")))

    return images


text_input = gr.Textbox(
    label=(
        "Enter the name of the model to use, optionally followed by ';' and your own image URL. "
        "For example: facebook/dino-vits8; "
        "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg"
    ),
    placeholder="facebook/dino-vits8; optionalurl.jpg",
)
attention_output = gr.Gallery(label="Attention Maps")

iface = gr.Interface(
    fn=visualize_attention,
    inputs=text_input,
    outputs=attention_output,
    live=True,
    title="Visualize Attention Maps",
    description="This app uses a Vision Transformer to visualize the attention maps of an image.",
)

iface.launch()