jcoding committed on

Commit 04fe2b6 · 1 Parent(s): ea6c435

Create app.py

Files changed (1)
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ # Import the necessary libraries and modules
+ import os
+ import requests
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ from PIL import Image
+ from transformers import ViTImageProcessor, ViTModel
+
+
+ def visualize_attention(name):
+     # parse the input: "model_name" or "model_name;image_url"
+     parts = [p.strip() for p in name.split(";")]
+     model_name = parts[0]
+     if len(parts) > 1 and parts[1]:
+         url = parts[1]
+     else:
+         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+     # size=480 resizes the input to 480x480 so the attention maps are finer
+     feature_extractor = ViTImageProcessor.from_pretrained(model_name, size=480)
+
+     pil_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+     device = "cpu"
+     pixel_values = feature_extractor(images=pil_image, return_tensors="pt").pixel_values.to(device)
+     model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+     model.to(device)
+     model.eval()
+     # interpolate_pos_encoding lets the model accept the 480x480 input even
+     # though it was pretrained at a lower resolution
+     with torch.no_grad():
+         outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
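+     # outputs.attentions is a tuple with one tensor per layer, each of shape
+     # (batch_size, num_heads, sequence_length, sequence_length); the sequence
+     # is the [CLS] token followed by one token per image patch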
+     attentions = outputs.attentions[-1]  # we only need the last layer's attention maps
+     nh = attentions.shape[1]  # number of attention heads
+
+     # keep only the attention of the [CLS] token to the image patches
+     attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
+     threshold = 0.6
+     w_featmap = pixel_values.shape[-2] // model.config.patch_size
+     h_featmap = pixel_values.shape[-1] // model.config.patch_size
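+     # e.g. for facebook/dino-vits8 (patch size 8), a 480x480 input gives a
+     # 60x60 patch grid, so each head's map above has 3600 values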
+
+     # keep only a certain percentage of the attention mass (as in DINO):
+     # sort each head's attention, normalize it to sum to 1, and threshold
+     # the cumulative sum so the kept patches cover `threshold` of the mass
+     val, idx = torch.sort(attentions)
+     val /= torch.sum(val, dim=1, keepdim=True)
+     cumval = torch.cumsum(val, dim=1)
+     th_attn = cumval > (1 - threshold)
+     idx2 = torch.argsort(idx)  # undo the sort to restore patch order
+     for head in range(nh):
+         th_attn[head] = th_attn[head][idx2[head]]
+     th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
+     # upsample the patch-level mask back to pixel resolution
+     # (the mask is computed as in DINO but only the raw heatmaps are displayed)
+     th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
+
+     # upsample the raw attention maps to pixel resolution as well
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+     attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
+
+     # save the preprocessed image and one attention heatmap per head
+     output_dir = "."
+     os.makedirs(output_dir, exist_ok=True)
+     torchvision.utils.save_image(torchvision.utils.make_grid(pixel_values, normalize=True, scale_each=True), os.path.join(output_dir, "img.png"))
+     images = []
+     for j in range(nh):
+         fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
+         plt.imsave(fname=fname, arr=attentions[j], format="png")
+         images.append(Image.open(fname))
+     return images
+
+
+ text_input = gr.Textbox(
+     label="Enter a model name, optionally followed by ';' and a JPG image URL. Try: facebook/dino-vits8; https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg",
+     placeholder="facebook/dino-vits8; optionalurl.jpg",
+ )
+ attention_output = gr.Gallery(label="Attention Map")
+
+ iface = gr.Interface(
+     fn=visualize_attention,
+     inputs=text_input,
+     outputs=attention_output,
+     live=True,
+     title="Visualize Attention Maps",
+     description="This app uses a Vision Transformer to visualize the attention maps of an image.",
+ )
+
+ iface.launch()
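
A quick way to sanity-check visualize_attention outside of Gradio (a minimal sketch, assuming the dependencies imported above are installed; facebook/dino-vits8 is the example model from the textbox label and downloads its weights from the Hugging Face Hub):

    maps = visualize_attention("facebook/dino-vits8")
    print(len(maps), "attention-head heatmaps saved next to app.py")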