Spaces:
Sleeping
Sleeping
Commit
·
ba377fa
1
Parent(s):
5daa6d0
improve visualization code
Browse files
utils/visualize_att_maps.py
CHANGED
@@ -26,7 +26,6 @@ class VisualizeAttentionMaps:
|
|
26 |
|
27 |
self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
|
28 |
self.num_parts = num_parts
|
29 |
-
self.req_colors = colors[:num_parts]
|
30 |
self.figs_size = (10, 10)
|
31 |
|
32 |
@torch.no_grad()
|
@@ -45,7 +44,11 @@ class VisualizeAttentionMaps:
|
|
45 |
map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
|
46 |
mode='bilinear',
|
47 |
align_corners=True).argmax(dim=1).cpu().numpy()
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
bg_label=self.bg_label, alpha=self.alpha)
|
51 |
return curr_map
|
|
|
26 |
|
27 |
self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
|
28 |
self.num_parts = num_parts
|
|
|
29 |
self.figs_size = (10, 10)
|
30 |
|
31 |
@torch.no_grad()
|
|
|
44 |
map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
|
45 |
mode='bilinear',
|
46 |
align_corners=True).argmax(dim=1).cpu().numpy()
|
47 |
+
# Select colors for parts which are present
|
48 |
+
parts_present = np.unique(map_argmax).tolist()
|
49 |
+
if self.bg_label in parts_present:
|
50 |
+
parts_present.remove(self.bg_label)
|
51 |
+
colors_present = [colors[i] for i in parts_present]
|
52 |
+
curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=colors_present,
|
53 |
bg_label=self.bg_label, alpha=self.alpha)
|
54 |
return curr_map
|