File size: 6,133 Bytes
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import colorcet as cc
import numpy as np
import skimage
from pathlib import Path
import os
import torch

from utils.data_utils.transform_utils import inverse_normalize_w_resize
from utils.misc_utils import factors

# Categorical colour palette (colorcet's `glasbey_category10`) used to colour the
# part labels in the attention-map overlays; VisualizeAttentionMaps keeps only the
# first `num_parts` entries of it.
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    """
    Visualize part attention maps as argmax label overlays on the input images, and
    optionally the raw images and the individual per-part attention maps.

    When ``snapshot_dir`` is non-empty, figures are written to
    ``<snapshot_dir>/results_vis_<sub_path_test>``; otherwise they are displayed with
    ``plt.show()`` and nothing is written to disk.
    """

    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
                 plot_landmark_amaps=False):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results ("" = show instead of save)
        :param save_resolution: Size of the images to save
        :param alpha: The transparency of the attention maps
        :param sub_path_test: The sub-path of the test dataset
        :param dataset_name: The name of the dataset
        :param bg_label: The background label index in the attention maps
        :param batch_size: The batch size
        :param num_parts: The number of parts in the attention maps
        :param plot_ims_separately: Whether to plot the images separately
        :param plot_landmark_amaps: Whether to plot the landmark attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.sub_path_test = sub_path_test
        self.dataset_name = dataset_name
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir

        # Un-normalizes images and resizes them to the save resolution.
        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.batch_size = batch_size
        self.num_parts = num_parts
        # One palette colour per part label.
        self.req_colors = colors[:num_parts]
        self.plot_ims_separately = plot_ims_separately
        self.plot_landmark_amaps = plot_landmark_amaps
        # Sets self.nrows, self.ncols and self.figs_size from self.batch_size.
        self.recalculate_nrows_ncols()

    def recalculate_nrows_ncols(self):
        """
        Recompute the subplot grid (``nrows`` x ``ncols``) and the figure size from
        ``self.batch_size``.

        The grid dimensions are the last two values returned by ``factors(self.batch_size)``.
        NOTE(review): assumes ``factors`` returns at least two values whose last two form a
        grid large enough for ``batch_size`` images — confirm against utils.misc_utils.factors.
        """
        batch_factors = factors(self.batch_size)
        self.nrows = batch_factors[-1]
        self.ncols = batch_factors[-2]
        if self.nrows == 1 and self.ncols == 1:
            # Single image: use a large canvas.
            self.figs_size = (10, 10)
        else:
            # Two inches per subplot in each direction.
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def _new_grid_figure(self):
        """Create a figure with an ``nrows`` x ``ncols`` axes grid; return (figure, flat axes array)."""
        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
        return fig, axs.ravel()

    def _save_or_show(self, fig, save_path):
        """Save ``fig`` to ``save_path`` when a snapshot dir is configured, otherwise show it; then close it."""
        fig.tight_layout()
        if self.snapshot_dir != "":
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
        else:
            plt.show()
        # Close only the figure we created; leave any figures owned by the caller alone.
        plt.close(fig)

    @torch.no_grad()
    def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
        """
        Plot images, attention maps and landmark centroids.
        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        epoch: int
            The epoch number
        curr_iter: int
            The current iteration number
        extra_info: str
            Any extra information to add to the file name

        Side effects: updates ``self.batch_size`` (and the grid layout) when the incoming batch
        size differs, and writes image files when ``snapshot_dir`` is set.

        Raises
        ------
        ValueError
            If ``plot_landmark_amaps`` is enabled and the batch holds more than one image.
        """
        ims = self.resize_unnorm(ims)
        if ims.shape[0] != self.batch_size:
            # The batch size changed (e.g. a smaller final batch): rebuild the grid layout.
            self.batch_size = ims.shape[0]
            self.recalculate_nrows_ncols()
        # Convert to channels-last uint8 for plotting.
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # Upsample the maps to the save resolution and take the per-pixel winning part index.
        map_argmax = torch.nn.functional.interpolate(maps.detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()

        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
        if self.snapshot_dir != "":
            # Only touch the file system when we are actually saving.
            save_dir.mkdir(parents=True, exist_ok=True)
        file_stem = f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}'
        num_ims = len(ims)

        # Overlay of the argmax part labels on every image.
        fig, axes = self._new_grid_figure()
        for i, ax in enumerate(axes):
            if i < num_ims:
                overlay = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
                                                  bg_label=self.bg_label, alpha=self.alpha)
                ax.imshow(overlay)
            ax.axis('off')
        self._save_or_show(fig, os.path.join(save_dir, f'{file_stem}.png'))

        if self.plot_ims_separately:
            # The raw (un-normalized) images on their own.
            fig, axes = self._new_grid_figure()
            for i, ax in enumerate(axes):
                if i < num_ims:
                    ax.imshow(ims[i])
                ax.axis('off')
            self._save_or_show(fig, os.path.join(save_dir, f'image_{file_stem}.jpg'))

        if self.plot_landmark_amaps:
            if self.batch_size > 1:
                raise ValueError('Not implemented for batch size > 1')
            for i in range(self.num_parts):
                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
                # Colorbar axis attached to the right-hand side of the map.
                cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
                fig.colorbar(im, cax=cax, orientation='vertical')
                ax.axis('off')
                self._save_or_show(fig, os.path.join(save_dir, f'landmark_{i}_{file_stem}.png'))