# Import the necessary libraries and modules
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from transformers import ViTImageProcessor, ViTModel


def visualize_attention(name):
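    """Visualize the last-layer attention maps of a Vision Transformer.

    `name` is a Hugging Face model id (e.g. "facebook/dino-vits8"),
    optionally followed by ";" and an image URL. Returns one PIL image
    per attention head.
    """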
    parts = name.split(";")
    model_name = parts[0].strip()
    if len(parts) > 1:
        url = parts[1].strip()
    else:
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    feature_extractor = ViTImageProcessor.from_pretrained(model_name, size=480)
    
    pil_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    device = "cpu"
    pixel_values = feature_extractor(images=pil_image, return_tensors="pt").pixel_values.to(device)
    model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
    
    model.to(device)
    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
    attentions = outputs.attentions[-1]  # we are only interested in the attention maps of the last layer
    nh = attentions.shape[1]  # number of attention heads

    # keep only the [CLS] token's attention to the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    threshold = 0.6  # fraction of the attention mass to keep
    h_featmap = pixel_values.shape[-2] // model.config.patch_size
    w_featmap = pixel_values.shape[-1] // model.config.patch_size
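    # The masking below mirrors DINO's attention visualization: per head, sort
    # the patch attentions and keep the smallest set of patches that together
    # account for `threshold` (60%) of the total attention mass.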
    
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - threshold)
    idx2 = torch.argsort(idx)  # undo the sort so the mask matches the original patch order
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, h_featmap, w_featmap).float()
    # upsample the patch-level maps back to pixel resolution
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()  # binary masks (not displayed below)

    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
    
    # save the input image and one heatmap per attention head
    output_dir = "."
    os.makedirs(output_dir, exist_ok=True)
    torchvision.utils.save_image(torchvision.utils.make_grid(pixel_values, normalize=True, scale_each=True), os.path.join(output_dir, "img.png"))
    images = []
    for j in range(nh):
        fname = os.path.join(output_dir, f"attn-head{j}.png")
        plt.imsave(fname=fname, arr=attentions[j], format="png")
        images.append(Image.open(fname))
    return images
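
# A minimal sketch of calling the function directly, outside the Gradio UI
# (assumes network access and a ViT checkpoint available on the Hugging Face Hub):
#
#   heads = visualize_attention("facebook/dino-vits8")
#   heads[0].save("head0.png")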


text_input = gr.Textbox(
    label="Model name, optionally followed by ';' and an image URL. Try: facebook/dino-vits8;https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg",
    placeholder="facebook/dino-vits8;optional-url.jpg",
)
attention_output = gr.Gallery(label="Attention Map")

iface = gr.Interface(
    fn=visualize_attention,
    inputs=text_input,
    outputs=attention_output,
    title="Visualize Attention Maps",
    description="Visualize a Vision Transformer's last-layer attention maps over an input image.",
)

iface.launch()
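
# iface.launch() serves locally by default; pass share=True for a temporary
# public link (a standard Gradio option).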