import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import colorcet as cc
import numpy as np
import skimage.color  # explicit submodule import so skimage.color.label2rgb is available
from pathlib import Path
import os
import torch
from utils.data_utils.transform_utils import inverse_normalize_w_resize
from utils.misc_utils import factors

# Define the colors to use for the attention maps
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
                 plot_landmark_amaps=False):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory in which to save the visualizations (if empty, figures are shown instead)
        :param save_resolution: Resolution at which the visualizations are saved
        :param alpha: Transparency of the attention map overlay
        :param sub_path_test: Sub-path of the test dataset
        :param dataset_name: Name of the dataset
        :param bg_label: Index of the background label in the attention maps
        :param batch_size: The batch size
        :param num_parts: Number of parts in the attention maps
        :param plot_ims_separately: Whether to also plot the images without overlays
        :param plot_landmark_amaps: Whether to plot the per-landmark attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.sub_path_test = sub_path_test
        self.dataset_name = dataset_name
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir
        # Transform that undoes the dataset normalization and resizes images for display
        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.batch_size = batch_size
        # Subplot grid shape is derived from the factors of the batch size
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        self.num_parts = num_parts
        self.req_colors = colors[:num_parts]
        self.plot_ims_separately = plot_ims_separately
        self.plot_landmark_amaps = plot_landmark_amaps
        # Use a larger canvas for a single image, otherwise scale the figure with the grid
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def recalculate_nrows_ncols(self):
        """Recompute the subplot grid and figure size when the batch size changes."""
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
        """
        Plot images, attention maps and landmark centroids.

        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        epoch: int
            The epoch number
        curr_iter: int
            The current iteration number
        extra_info: str
            Any extra information to add to the file name
        """
        # Un-normalize the images and resize them to the save resolution
        ims = self.resize_unnorm(ims)
        if ims.shape[0] != self.batch_size:
            self.batch_size = ims.shape[0]
            self.recalculate_nrows_ncols()
        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # Upsample the attention maps to the save resolution and take the per-pixel argmax
        map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()
        for i, ax in enumerate(axs.ravel()):
            # Overlay the hard part assignment on the image
            curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
                                               bg_label=self.bg_label, alpha=self.alpha)
            ax.imshow(curr_map)
            ax.axis('off')
        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
        fig.tight_layout()
        if self.snapshot_dir != "":
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        else:
            plt.show()
        plt.close('all')
        if self.plot_ims_separately:
            # Also save/show the images without the attention overlay
            fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
            for i, ax in enumerate(axs.ravel()):
                ax.imshow(ims[i])
                ax.axis('off')
            save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
            fig.tight_layout()
            if self.snapshot_dir != "":
                plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            else:
                plt.show()
            plt.close('all')
        if self.plot_landmark_amaps:
            if self.batch_size > 1:
                raise ValueError('Not implemented for batch size > 1')
            # Plot each landmark's attention map with its own colorbar
            for i in range(self.num_parts):
                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes('right', size='5%', pad=0.05)
                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
                fig.colorbar(im, cax=cax, orientation='vertical')
                ax.axis('off')
                save_path = os.path.join(save_dir,
                                         f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
                fig.tight_layout()
                if self.snapshot_dir != "":
                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
                else:
                    plt.show()
                plt.close()
        plt.close('all')
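

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original class).
# It assumes the project utilities imported above behave as the class expects:
# `inverse_normalize_w_resize` returns a callable transform for batched image
# tensors, and `factors(batch_size)` returns values whose last two entries
# multiply to the batch size. Random tensors stand in for real images and
# attention maps; with snapshot_dir="" the figures are shown instead of saved.
if __name__ == "__main__":
    dummy_ims = torch.rand(4, 3, 224, 224)   # hypothetical batch of 4 RGB images
    dummy_maps = torch.rand(4, 16, 28, 28)   # 15 parts + 1 background channel
    visualizer = VisualizeAttentionMaps(snapshot_dir="", batch_size=4, num_parts=15)
    visualizer.show_maps(dummy_ims, dummy_maps, epoch=0, curr_iter=0)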