# Visualize_ViT / app.py
# Import the necessary libraries and modules
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
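
# Assumed runtime dependencies (a plausible requirements.txt for this Space,
# not confirmed by the source): gradio, transformers, torch, torchvision,
# matplotlib, requests, Pillow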

def visualize_attention(name):
    # The textbox input is "<model_name>" or "<model_name>;<image_url>"
    model_name = name.split(";")[0].strip()
    if len(name.split(";")) > 1:
        url = name.split(";")[1].strip()
    else:
        # default to a COCO validation image
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # resize the input so both sides are 480 pixels
    feature_extractor = ViTImageProcessor.from_pretrained(model_name, size=480)
    pil_image = Image.open(requests.get(url, stream=True).raw)
    device = "cpu"
    pixel_values = feature_extractor(images=pil_image, return_tensors="pt").pixel_values.to(device)
    model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
    model.to(device)
    # interpolate_pos_encoding lets the model accept the 480x480 input even
    # though it was pretrained at a different resolution; no_grad avoids
    # building an autograd graph we never use
    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
    attentions = outputs.attentions[-1]  # we are only interested in the attention maps of the last layer
    nh = attentions.shape[1]  # number of heads
    # keep only the attention of the [CLS] token to the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
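    # Shape bookkeeping (a sketch, assuming the 480x480 input and a patch-8
    # model such as facebook/dino-vits8): 480/8 = 60, so there are 60*60 = 3600
    # patch tokens plus one [CLS] token; outputs.attentions[-1] then has shape
    # (1, nh, 3601, 3601), and the slice above yields (nh, 3600).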
    threshold = 0.6
    # height and width of the patch grid (the input is square, so both agree)
    w_featmap = pixel_values.shape[-2] // model.config.patch_size
    h_featmap = pixel_values.shape[-1] // model.config.patch_size
    # keep only a certain fraction of the attention mass per head
    # (this thresholded mask follows DINO's visualize_attention.py; it is
    # computed here but not returned by this app)
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - threshold)
    idx2 = torch.argsort(idx)
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
    # upsample the per-head maps from the patch grid back to pixel resolution
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
    # save the (normalized) input image and one heatmap per head, then load
    # the heatmaps back as PIL images for the gallery
    output_dir = "."
    os.makedirs(output_dir, exist_ok=True)
    torchvision.utils.save_image(
        torchvision.utils.make_grid(pixel_values, normalize=True, scale_each=True),
        os.path.join(output_dir, "img.png"),
    )
    for j in range(nh):
        fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
        plt.imsave(fname=fname, arr=attentions[j], format="png")
    images = []
    for j in range(nh):
        images.append(Image.open(os.path.join(output_dir, "attn-head" + str(j) + ".png")))
    return images
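
# A minimal sketch of calling the function directly, outside the UI; the model
# name and image URL are the examples from the textbox label below:
#
#   maps = visualize_attention(
#       "facebook/dino-vits8;"
#       "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg"
#   )
#   maps[0].save("head0-preview.png")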

text_input = gr.Textbox(
    label=(
        "Enter the name of the model to use; optionally append your own image URL, "
        "separated by ';'. For example: facebook/dino-vits8; "
        "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg"
    ),
    placeholder="facebook/dino-vits8; optionalurl.jpg",
)
attention_output = gr.Gallery(label="Attention Maps")
# The interface runs on submit rather than live, since each call downloads
# and loads a model
iface = gr.Interface(
    fn=visualize_attention,
    inputs=text_input,
    outputs=attention_output,
    title="Visualize Attention Maps",
    description="This app uses a Vision Transformer to visualize the attention maps of an image.",
)
iface.launch()
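# launch() accepts further options, e.g. share=True for a temporary public
# link; the defaults are kept here.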