ananthu-aniraj commited on
Commit
ba377fa
·
1 Parent(s): 5daa6d0

improve visualization code

Browse files
Files changed (1) hide show
  1. utils/visualize_att_maps.py +6 -3
utils/visualize_att_maps.py CHANGED
@@ -26,7 +26,6 @@ class VisualizeAttentionMaps:
26
 
27
  self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
28
  self.num_parts = num_parts
29
- self.req_colors = colors[:num_parts]
30
  self.figs_size = (10, 10)
31
 
32
  @torch.no_grad()
@@ -45,7 +44,11 @@ class VisualizeAttentionMaps:
45
  map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
46
  mode='bilinear',
47
  align_corners=True).argmax(dim=1).cpu().numpy()
48
-
49
- curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=self.req_colors,
 
 
 
 
50
  bg_label=self.bg_label, alpha=self.alpha)
51
  return curr_map
 
26
 
27
  self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
28
  self.num_parts = num_parts
 
29
  self.figs_size = (10, 10)
30
 
31
  @torch.no_grad()
 
44
  map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
45
  mode='bilinear',
46
  align_corners=True).argmax(dim=1).cpu().numpy()
47
+ # Select colors for parts which are present
48
+ parts_present = np.unique(map_argmax).tolist()
49
+ if self.bg_label in parts_present:
50
+ parts_present.remove(self.bg_label)
51
+ colors_present = [colors[i] for i in parts_present]
52
+ curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=colors_present,
53
  bg_label=self.bg_label, alpha=self.alpha)
54
  return curr_map