ananthu-aniraj committed
Commit a8d9779
Parent(s): 20239f9

upload initial version

app.py ADDED
@@ -0,0 +1,28 @@
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from models import IndividualLandmarkViT
+ from utils import VisualizeAttentionMaps
+ from utils.data_utils.transform_utils import make_test_transforms
+
+
+ st.title("Pdiscoformer Part Discovery Visualizer")
+ # Set the device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Load the model
+ model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval().to(device)
+
+ amap_vis = VisualizeAttentionMaps(num_parts=9, bg_label=8)
+
+ image_size = 518
+ test_transforms = make_test_transforms(image_size)
+ image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])  # Upload an image
+ if image_name is not None:
+     image = Image.open(image_name).convert("RGB")
+     image_tensor = test_transforms(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         maps, scores = model(image_tensor)
+
+     coloured_map = amap_vis.show_maps(image_tensor, maps)
+     st.image(coloured_map, caption="Attention Map", use_column_width=True)
+
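With the packages from requirements.txt installed, the new demo is expected to start locally with the standard Streamlit entry point:

streamlit run app.py

The app then loads the pretrained PDiscoFormer checkpoint, applies the test transforms to an uploaded image, and displays the coloured part-attention overlay.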
models/individual_landmark_vit.py CHANGED
@@ -26,11 +26,10 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
  part_dropout: float = 0.3, return_transformer_qkv: bool = False,
  modulation_type: str = "original", gumbel_softmax: bool = False,
  gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
- modulation_orth: bool = False, classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
+ classifier_type: str = "linear") -> None:
  super().__init__()
  self.num_landmarks = num_landmarks
  self.num_classes = num_classes
- self.noise_variance = noise_variance
  self.num_prefix_tokens = init_model.num_prefix_tokens
  self.num_reg_tokens = init_model.num_reg_tokens
  self.has_class_token = init_model.has_class_token
@@ -75,7 +74,6 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
  self.modulation = torch.nn.Identity()
  else:
  raise ValueError("modulation_type not implemented")
- self.modulation_orth = modulation_orth
  self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
  self.classifier_type = classifier_type
  if classifier_type == "independent_mlp":
@@ -168,10 +166,6 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,

  # Use maps to get weighted average features per landmark
  all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
- if self.noise_variance > 0.0:
-     all_features += torch.randn_like(all_features,
-                                      device=all_features.device) * x.std().detach() * self.noise_variance
-
  all_features = all_features.mean(-1).mean(-1).contiguous()  # [B, embed_dim, num_landmarks + 1]

  # Modulate the features
@@ -184,10 +178,9 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
  scores = self.fc_class_landmarks(
      self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                      1).contiguous()
- if self.modulation_orth:
-     return all_features_mod, maps, scores, dist
- else:
-     return all_features, maps, scores, dist
+ scores = scores.mean(dim=-1)  # [B, num_classes]
+
+ return maps, scores

  def get_specific_intermediate_layer(
      self,
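The forward pass is now trimmed to what the demo needs: it returns only the part attention maps and the class scores averaged over parts. A minimal sketch of the new calling contract, assuming the repository's models package is importable; the checkpoint id and the 518-pixel input size are the ones used in app.py above, and the exact spatial size of maps depends on the backbone's patch grid:

import torch
from models import IndividualLandmarkViT

model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval()
dummy = torch.randn(1, 3, 518, 518)  # one normalized RGB image at the demo resolution
with torch.no_grad():
    maps, scores = model(dummy)
# maps:   [1, num_parts + 1, H, W] soft part-assignment maps (app.py treats index 8 as background)
# scores: [1, num_classes] classification logits, averaged over the discovered parts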
requirements.txt CHANGED
@@ -7,4 +7,5 @@ streamlit
  numpy
  pillow
  scikit-image
- huggingface-hub
+ huggingface-hub
+ opencv-python
utils/data_utils/transform_utils.py CHANGED
@@ -3,29 +3,13 @@ from torchvision import transforms as transforms
  from torchvision.transforms import Compose

  from timm.data.constants import \
-     IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
- from timm.data import create_transform
+     IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


- def make_train_transforms(args):
-     train_transforms: Compose = transforms.Compose([
-         transforms.Resize(size=args.image_size, antialias=True),
-         transforms.RandomHorizontalFlip(p=args.hflip),
-         transforms.RandomVerticalFlip(p=args.vflip),
-         transforms.ColorJitter(),
-         transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
-         transforms.RandomCrop(args.image_size),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
-
-     ])
-     return train_transforms
-
-
- def make_test_transforms(args):
+ def make_test_transforms(image_size):
      test_transforms: Compose = transforms.Compose([
-         transforms.Resize(size=args.image_size, antialias=True),
-         transforms.CenterCrop(args.image_size),
+         transforms.Resize(size=image_size, antialias=True),
+         transforms.CenterCrop(image_size),
          transforms.ToTensor(),
          transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

@@ -33,57 +17,6 @@ def make_test_transforms(args):
      return test_transforms


- def build_transform_timm(args, is_train=True):
-     resize_im = args.image_size > 32
-     imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
-     mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
-     std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
-
-     if is_train:
-         # this should always dispatch to transforms_imagenet_train
-         transform = create_transform(
-             input_size=args.image_size,
-             is_training=True,
-             color_jitter=args.color_jitter,
-             hflip=args.hflip,
-             vflip=args.vflip,
-             auto_augment=args.aa,
-             interpolation=args.train_interpolation,
-             re_prob=args.reprob,
-             re_mode=args.remode,
-             re_count=args.recount,
-             mean=mean,
-             std=std,
-         )
-         if not resize_im:
-             transform.transforms[0] = transforms.RandomCrop(
-                 args.image_size, padding=4)
-         return transform
-
-     t = []
-     if resize_im:
-         # warping (no cropping) when evaluated at 384 or larger
-         if args.image_size >= 384:
-             t.append(
-                 transforms.Resize((args.image_size, args.image_size),
-                                   interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
-             )
-             print(f"Warping {args.image_size} size input images...")
-         else:
-             if args.crop_pct is None:
-                 args.crop_pct = 224 / 256
-             size = int(args.image_size / args.crop_pct)
-             t.append(
-                 # to maintain same ratio w.r.t. 224 images
-                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
-             )
-         t.append(transforms.CenterCrop(args.image_size))
-
-     t.append(transforms.ToTensor())
-     t.append(transforms.Normalize(mean, std))
-     return transforms.Compose(t)
-
-
  def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
      mean = torch.as_tensor(mean)
      std = torch.as_tensor(std)
@@ -104,15 +37,3 @@ def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_
          transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
          transforms.Resize(size=resize_resolution, antialias=True)])
      return resize_unnorm
-
-
- def load_transforms(args):
-     # Get the transforms and load the dataset
-     if args.augmentations_to_use == 'timm':
-         train_transforms = build_transform_timm(args, is_train=True)
-     elif args.augmentations_to_use == 'cub_original':
-         train_transforms = make_train_transforms(args)
-     else:
-         raise ValueError('Augmentations not supported.')
-     test_transforms = make_test_transforms(args)
-     return train_transforms, test_transforms
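make_test_transforms now takes the image size directly instead of an argparse namespace. A minimal sketch of the new call, mirroring app.py; the 518-pixel size is the demo resolution and "bird.jpg" is a placeholder file name, not part of the repository:

from PIL import Image
from utils.data_utils.transform_utils import make_test_transforms

test_transforms = make_test_transforms(518)         # Resize -> CenterCrop -> ToTensor -> Normalize (ImageNet stats)
image = Image.open("bird.jpg").convert("RGB")       # placeholder input image
image_tensor = test_transforms(image).unsqueeze(0)  # [1, 3, 518, 518] normalized batch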
utils/visualize_att_maps.py CHANGED
@@ -1,66 +1,36 @@
- import matplotlib.pyplot as plt
- from mpl_toolkits.axes_grid1 import make_axes_locatable
  import colorcet as cc
  import numpy as np
  import skimage
- from pathlib import Path
- import os
  import torch

  from utils.data_utils.transform_utils import inverse_normalize_w_resize
- from utils.misc_utils import factors

  # Define the colors to use for the attention maps
  colors = cc.glasbey_category10


  class VisualizeAttentionMaps:
-     def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
-                  dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
-                  plot_landmark_amaps=False):
+     def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, bg_label=0, num_parts=15):
          """
          Plot attention maps and optionally landmark centroids on images.
          :param snapshot_dir: Directory to save the visualization results
          :param save_resolution: Size of the images to save
          :param alpha: The transparency of the attention maps
-         :param sub_path_test: The sub-path of the test dataset
-         :param dataset_name: The name of the dataset
          :param bg_label: The background label index in the attention maps
-         :param batch_size: The batch size
          :param num_parts: The number of parts in the attention maps
-         :param plot_ims_separately: Whether to plot the images separately
-         :param plot_landmark_amaps: Whether to plot the landmark attention maps
          """
          self.save_resolution = save_resolution
          self.alpha = alpha
-         self.sub_path_test = sub_path_test
-         self.dataset_name = dataset_name
          self.bg_label = bg_label
          self.snapshot_dir = snapshot_dir

          self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
-         self.batch_size = batch_size
-         self.nrows = factors(self.batch_size)[-1]
-         self.ncols = factors(self.batch_size)[-2]
          self.num_parts = num_parts
          self.req_colors = colors[:num_parts]
-         self.plot_ims_separately = plot_ims_separately
-         self.plot_landmark_amaps = plot_landmark_amaps
-         if self.nrows == 1 and self.ncols == 1:
-             self.figs_size = (10, 10)
-         else:
-             self.figs_size = (self.ncols * 2, self.nrows * 2)
-
-     def recalculate_nrows_ncols(self):
-         self.nrows = factors(self.batch_size)[-1]
-         self.ncols = factors(self.batch_size)[-2]
-         if self.nrows == 1 and self.ncols == 1:
-             self.figs_size = (10, 10)
-         else:
-             self.figs_size = (self.ncols * 2, self.nrows * 2)
+         self.figs_size = (10, 10)

      @torch.no_grad()
-     def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
+     def show_maps(self, ims, maps):
          """
          Plot images, attention maps and landmark centroids.
          Parameters
@@ -69,67 +39,13 @@ class VisualizeAttentionMaps:
          Input images on which to show the attention maps
          maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
          The attention maps to display
-         epoch: int
-             The epoch number
-         curr_iter: int
-             The current iteration number
-         extra_info: str
-             Any extra information to add to the file name
          """
          ims = self.resize_unnorm(ims)
-         if ims.shape[0] != self.batch_size:
-             self.batch_size = ims.shape[0]
-             self.recalculate_nrows_ncols()
-         fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
          ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
          map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                       mode='bilinear',
                                                       align_corners=True).argmax(dim=1).cpu().numpy()
-         for i, ax in enumerate(axs.ravel()):
-             curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
-                                                bg_label=self.bg_label, alpha=self.alpha)
-             ax.imshow(curr_map)
-             ax.axis('off')
-         save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
-         save_dir.mkdir(parents=True, exist_ok=True)
-         save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
-         fig.tight_layout()
-         if self.snapshot_dir != "":
-             plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-         else:
-             plt.show()
-         plt.close('all')
-
-         if self.plot_ims_separately:
-             fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
-             for i, ax in enumerate(axs.ravel()):
-                 ax.imshow(ims[i])
-                 ax.axis('off')
-             save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
-             fig.tight_layout()
-             if self.snapshot_dir != "":
-                 plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-             else:
-                 plt.show()
-             plt.close('all')
-
-         if self.plot_landmark_amaps:
-             if self.batch_size > 1:
-                 raise ValueError('Not implemented for batch size > 1')
-             for i in range(self.num_parts):
-                 fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
-                 divider = make_axes_locatable(ax)
-                 cax = divider.append_axes('right', size='5%', pad=0.05)
-                 im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
-                 fig.colorbar(im, cax=cax, orientation='vertical')
-                 ax.axis('off')
-                 save_path = os.path.join(save_dir,
-                                          f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
-                 fig.tight_layout()
-                 if self.snapshot_dir != "":
-                     plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-                 else:
-                     plt.show()
-                 plt.close()

-         plt.close('all')
+         curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=self.req_colors,
+                                            bg_label=self.bg_label, alpha=self.alpha)
+         return curr_map
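show_maps no longer builds a matplotlib grid or writes files to disk; it returns a single RGB overlay (a NumPy array at save_resolution) for the first image in the batch, which the Streamlit app displays directly. A minimal sketch of the new usage, assuming image_tensor and maps are produced as in app.py; num_parts=9 and bg_label=8 are the values the demo uses for the CUB k=8 checkpoint:

from utils import VisualizeAttentionMaps

amap_vis = VisualizeAttentionMaps(num_parts=9, bg_label=8)
overlay = amap_vis.show_maps(image_tensor, maps)  # 256 x 256 RGB array with part colours blended over the image
# st.image(overlay, caption="Attention Map", use_column_width=True)  # as done in app.py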