Commit a8d9779
Parent(s): 20239f9
upload initial version
Files changed:
- app.py +28 -0
- models/individual_landmark_vit.py +4 -11
- requirements.txt +2 -1
- utils/data_utils/transform_utils.py +4 -83
- utils/visualize_att_maps.py +6 -90
app.py ADDED
@@ -0,0 +1,28 @@
+import streamlit as st
+import torch
+from PIL import Image
+from models import IndividualLandmarkViT
+from utils import VisualizeAttentionMaps
+from utils.data_utils.transform_utils import make_test_transforms
+
+
+st.title("Pdiscoformer Part Discovery Visualizer")
+# Set the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Load the model
+model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval().to(device)
+
+amap_vis = VisualizeAttentionMaps(num_parts=9, bg_label=8)
+
+image_size = 518
+test_transforms = make_test_transforms(image_size)
+image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])  # Upload an image
+if image_name is not None:
+    image = Image.open(image_name).convert("RGB")
+    image_tensor = test_transforms(image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        maps, scores = model(image_tensor)
+
+    coloured_map = amap_vis.show_maps(image_tensor, maps)
+    st.image(coloured_map, caption="Attention Map", use_column_width=True)
+
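For a quick sanity check outside the browser, the same inference path can be run headlessly. The sketch below is only an outline under assumptions: it presumes the repo's models and utils packages are importable and that the ananthu-aniraj/pdiscoformer_cub_k_8 checkpoint can be downloaded. The UI itself is launched with `streamlit run app.py`.

# Headless sketch of the app.py inference path (no Streamlit UI).
import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.data_utils.transform_utils import make_test_transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval().to(device)

transform = make_test_transforms(518)          # same image_size as app.py
image = Image.new("RGB", (640, 480), "white")  # placeholder instead of an uploaded file
tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    maps, scores = model(tensor)               # new forward contract: (maps, scores)

overlay = VisualizeAttentionMaps(num_parts=9, bg_label=8).show_maps(tensor, maps)
print(maps.shape, scores.shape, overlay.shape)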
models/individual_landmark_vit.py CHANGED
@@ -26,11 +26,10 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
                  part_dropout: float = 0.3, return_transformer_qkv: bool = False,
                  modulation_type: str = "original", gumbel_softmax: bool = False,
                  gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
-
+                 classifier_type: str = "linear") -> None:
         super().__init__()
         self.num_landmarks = num_landmarks
         self.num_classes = num_classes
-        self.noise_variance = noise_variance
         self.num_prefix_tokens = init_model.num_prefix_tokens
         self.num_reg_tokens = init_model.num_reg_tokens
         self.has_class_token = init_model.has_class_token
@@ -75,7 +74,6 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
             self.modulation = torch.nn.Identity()
         else:
             raise ValueError("modulation_type not implemented")
-        self.modulation_orth = modulation_orth
         self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
         self.classifier_type = classifier_type
         if classifier_type == "independent_mlp":
@@ -168,10 +166,6 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
 
         # Use maps to get weighted average features per landmark
         all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
-        if self.noise_variance > 0.0:
-            all_features += torch.randn_like(all_features,
-                                             device=all_features.device) * x.std().detach() * self.noise_variance
-
         all_features = all_features.mean(-1).mean(-1).contiguous()  # [B, embed_dim, num_landmarks + 1]
 
         # Modulate the features
@@ -184,10 +178,9 @@ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
         scores = self.fc_class_landmarks(
             self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                             1).contiguous()
-
-
-
-        return all_features, maps, scores, dist
+        scores = scores.mean(dim=-1)  # [B, num_classes]
+
+        return maps, scores
 
     def get_specific_intermediate_layer(
             self,
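The net effect of these hunks is a simpler inference API: the noise-injection branch and the orthogonality bookkeeping are gone, and forward() now returns only (maps, scores), with the per-part scores already averaged into a [B, num_classes] tensor. A minimal sketch of the new contract follows; the spatial size of maps is an assumption, since it depends on the ViT patch grid.

# Sketch of the post-commit forward() contract.
import torch
from models import IndividualLandmarkViT

model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval()

x = torch.randn(1, 3, 518, 518)   # one image at the 518-px resolution used in app.py
with torch.no_grad():
    maps, scores = model(x)       # previously: all_features, maps, scores, dist

print(maps.shape)                 # [1, num_landmarks + 1, H_map, W_map]
print(scores.shape)               # [1, num_classes] after scores.mean(dim=-1)
pred = scores.argmax(dim=-1)      # top-1 class index per image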
requirements.txt CHANGED
@@ -7,4 +7,5 @@ streamlit
 numpy
 pillow
 scikit-image
-huggingface-hub
+huggingface-hub
+opencv-python
utils/data_utils/transform_utils.py CHANGED
@@ -3,29 +3,13 @@ from torchvision import transforms as transforms
 from torchvision.transforms import Compose
 
 from timm.data.constants import \
-    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.data import create_transform
+    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 
-def make_train_transforms(args):
-    train_transforms: Compose = transforms.Compose([
-        transforms.Resize(size=args.image_size, antialias=True),
-        transforms.RandomHorizontalFlip(p=args.hflip),
-        transforms.RandomVerticalFlip(p=args.vflip),
-        transforms.ColorJitter(),
-        transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
-        transforms.RandomCrop(args.image_size),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
-
-    ])
-    return train_transforms
-
-
-def make_test_transforms(args):
+def make_test_transforms(image_size):
     test_transforms: Compose = transforms.Compose([
-        transforms.Resize(size=args.image_size, antialias=True),
-        transforms.CenterCrop(args.image_size),
+        transforms.Resize(size=image_size, antialias=True),
+        transforms.CenterCrop(image_size),
         transforms.ToTensor(),
         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
 
@@ -33,57 +17,6 @@ def make_test_transforms(args):
     return test_transforms
 
 
-def build_transform_timm(args, is_train=True):
-    resize_im = args.image_size > 32
-    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
-    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
-    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
-
-    if is_train:
-        # this should always dispatch to transforms_imagenet_train
-        transform = create_transform(
-            input_size=args.image_size,
-            is_training=True,
-            color_jitter=args.color_jitter,
-            hflip=args.hflip,
-            vflip=args.vflip,
-            auto_augment=args.aa,
-            interpolation=args.train_interpolation,
-            re_prob=args.reprob,
-            re_mode=args.remode,
-            re_count=args.recount,
-            mean=mean,
-            std=std,
-        )
-        if not resize_im:
-            transform.transforms[0] = transforms.RandomCrop(
-                args.image_size, padding=4)
-        return transform
-
-    t = []
-    if resize_im:
-        # warping (no cropping) when evaluated at 384 or larger
-        if args.image_size >= 384:
-            t.append(
-                transforms.Resize((args.image_size, args.image_size),
-                                  interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
-            )
-            print(f"Warping {args.image_size} size input images...")
-        else:
-            if args.crop_pct is None:
-                args.crop_pct = 224 / 256
-            size = int(args.image_size / args.crop_pct)
-            t.append(
-                # to maintain same ratio w.r.t. 224 images
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
-            )
-    t.append(transforms.CenterCrop(args.image_size))
-
-    t.append(transforms.ToTensor())
-    t.append(transforms.Normalize(mean, std))
-    return transforms.Compose(t)
-
-
 def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
     mean = torch.as_tensor(mean)
     std = torch.as_tensor(std)
@@ -104,15 +37,3 @@ def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_
         transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
         transforms.Resize(size=resize_resolution, antialias=True)])
     return resize_unnorm
-
-
-def load_transforms(args):
-    # Get the transforms and load the dataset
-    if args.augmentations_to_use == 'timm':
-        train_transforms = build_transform_timm(args, is_train=True)
-    elif args.augmentations_to_use == 'cub_original':
-        train_transforms = make_train_transforms(args)
-    else:
-        raise ValueError('Augmentations not supported.')
-    test_transforms = make_test_transforms(args)
-    return train_transforms, test_transforms
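make_test_transforms now takes a plain integer image size instead of an argparse namespace, and the training/timm pipelines have moved out of this module. A small usage check (the printed shape follows from Resize matching the shorter edge to image_size and CenterCrop making the result square):

# Quick shape check for the simplified test-time transform.
from PIL import Image
from utils.data_utils.transform_utils import make_test_transforms

transform = make_test_transforms(518)
img = Image.new("RGB", (800, 600))   # dummy 800x600 image
tensor = transform(img)
print(tensor.shape)                  # torch.Size([3, 518, 518])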
utils/visualize_att_maps.py CHANGED
@@ -1,66 +1,36 @@
-import matplotlib.pyplot as plt
-from mpl_toolkits.axes_grid1 import make_axes_locatable
 import colorcet as cc
 import numpy as np
 import skimage
-from pathlib import Path
-import os
 import torch
 
 from utils.data_utils.transform_utils import inverse_normalize_w_resize
-from utils.misc_utils import factors
 
 # Define the colors to use for the attention maps
 colors = cc.glasbey_category10
 
 
 class VisualizeAttentionMaps:
-    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5,
-                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
-                 plot_landmark_amaps=False):
+    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, bg_label=0, num_parts=15):
         """
         Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Size of the images to save
        :param alpha: The transparency of the attention maps
-        :param sub_path_test: The sub-path of the test dataset
-        :param dataset_name: The name of the dataset
        :param bg_label: The background label index in the attention maps
-        :param batch_size: The batch size
        :param num_parts: The number of parts in the attention maps
-        :param plot_ims_separately: Whether to plot the images separately
-        :param plot_landmark_amaps: Whether to plot the landmark attention maps
        """
         self.save_resolution = save_resolution
         self.alpha = alpha
-        self.sub_path_test = sub_path_test
-        self.dataset_name = dataset_name
         self.bg_label = bg_label
         self.snapshot_dir = snapshot_dir
 
         self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
-        self.batch_size = batch_size
-        self.nrows = factors(self.batch_size)[-1]
-        self.ncols = factors(self.batch_size)[-2]
         self.num_parts = num_parts
         self.req_colors = colors[:num_parts]
-        self.plot_ims_separately = plot_ims_separately
-        self.plot_landmark_amaps = plot_landmark_amaps
-        if self.nrows == 1 and self.ncols == 1:
-            self.figs_size = (10, 10)
-        else:
-            self.figs_size = (self.ncols * 2, self.nrows * 2)
-
-    def recalculate_nrows_ncols(self):
-        self.nrows = factors(self.batch_size)[-1]
-        self.ncols = factors(self.batch_size)[-2]
-        if self.nrows == 1 and self.ncols == 1:
-            self.figs_size = (10, 10)
-        else:
-            self.figs_size = (self.ncols * 2, self.nrows * 2)
+        self.figs_size = (10, 10)
 
     @torch.no_grad()
-    def show_maps(self, ims, maps, epoch, curr_iter, extra_info):
+    def show_maps(self, ims, maps):
         """
         Plot images, attention maps and landmark centroids.
         Parameters
@@ -69,67 +39,13 @@ class VisualizeAttentionMaps:
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
-        epoch: int
-            The epoch number
-        curr_iter: int
-            The current iteration number
-        extra_info: str
-            Any extra information to add to the file name
        """
         ims = self.resize_unnorm(ims)
-        if ims.shape[0] != self.batch_size:
-            self.batch_size = ims.shape[0]
-            self.recalculate_nrows_ncols()
-        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
         ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
         map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                      mode='bilinear',
                                                      align_corners=True).argmax(dim=1).cpu().numpy()
-        for i, ax in enumerate(axs.ravel()):
-            curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
-                                               bg_label=self.bg_label, alpha=self.alpha)
-            ax.imshow(curr_map)
-            ax.axis('off')
-        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
-        save_dir.mkdir(parents=True, exist_ok=True)
-        save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
-        fig.tight_layout()
-        if self.snapshot_dir != "":
-            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-        else:
-            plt.show()
-        plt.close('all')
-
-        if self.plot_ims_separately:
-            fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
-            for i, ax in enumerate(axs.ravel()):
-                ax.imshow(ims[i])
-                ax.axis('off')
-            save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
-            fig.tight_layout()
-            if self.snapshot_dir != "":
-                plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-            else:
-                plt.show()
-            plt.close('all')
-
-        if self.plot_landmark_amaps:
-            if self.batch_size > 1:
-                raise ValueError('Not implemented for batch size > 1')
-            for i in range(self.num_parts):
-                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
-                divider = make_axes_locatable(ax)
-                cax = divider.append_axes('right', size='5%', pad=0.05)
-                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
-                fig.colorbar(im, cax=cax, orientation='vertical')
-                ax.axis('off')
-                save_path = os.path.join(save_dir,
-                                         f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
-                fig.tight_layout()
-                if self.snapshot_dir != "":
-                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
-                else:
-                    plt.show()
-                plt.close()
 
-
+        curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=self.req_colors,
+                                           bg_label=self.bg_label, alpha=self.alpha)
+        return curr_map