File size: 2,374 Bytes
20239f9
 
 
 
 
5662f96
20239f9
 
 
 
 
 
a8d9779
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8d9779
20239f9
 
a8d9779
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
ba377fa
 
 
 
 
 
a8d9779
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import colorcet as cc
import numpy as np
import skimage
import torch

from utils.transform_utils import inverse_normalize_w_resize

# Color palette for the attention-map overlays: one visually distinct color per
# part index (Glasbey categorical palette, category10 variant, from colorcet).
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, bg_label=0, num_parts=15):
        """
        Overlay color-coded attention (part-segmentation) maps on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Resolution (as an interpolate `size`) at which
            images and maps are rendered
        :param alpha: The transparency of the attention maps
        :param bg_label: The background label index in the attention maps
        :param num_parts: The number of parts in the attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir

        # Undoes the dataset normalization and resizes to the save resolution.
        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.num_parts = num_parts
        # NOTE(review): not read anywhere in this class; kept because external
        # code may access it as a public attribute.
        self.figs_size = (10, 10)

    @torch.no_grad()
    def show_maps(self, ims, maps, idx=0):
        """
        Overlay the argmax attention map of one batch element on its image.

        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        idx: int, default 0
            Which batch element to visualize (the previous behavior was the
            hard-coded first element, so the default is backward-compatible)

        Returns
        -------
        ndarray
            The selected image with the color-coded attention overlay, as
            produced by ``skimage.color.label2rgb``.
        """
        # De-normalize, then convert to HWC uint8 as expected by skimage.
        ims = self.resize_unnorm(ims)
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # Upsample the soft maps to the output resolution and take the
        # per-pixel argmax part index. `.detach()` alone suffices here: under
        # @torch.no_grad() there is no autograd graph, and interpolate does
        # not mutate its input, so the original `.clone()` only cost an extra
        # full-tensor copy.
        map_argmax = torch.nn.functional.interpolate(maps.detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()
        # np.unique returns sorted labels; pairing them positionally with
        # colors[i] keeps each part's color stable across calls even when some
        # parts are absent. The background label gets no color of its own.
        parts_present = [p for p in np.unique(map_argmax).tolist() if p != self.bg_label]
        colors_present = [colors[i] for i in parts_present]
        # Blend the labeled regions over the image with the configured alpha.
        curr_map = skimage.color.label2rgb(label=map_argmax[idx], image=ims[idx], colors=colors_present,
                                           bg_label=self.bg_label, alpha=self.alpha)
        return curr_map