Spaces:
Sleeping
Sleeping
Commit
·
5662f96
1
Parent(s):
91efb25
remove all unused files
Browse files- README.md +1 -1
- app.py +2 -2
- utils/__init__.py +1 -1
- utils/data_utils/__init__.py +0 -5
- utils/data_utils/class_balanced_distributed_sampler.py +0 -100
- utils/data_utils/class_balanced_sampler.py +0 -31
- utils/data_utils/dataset_utils.py +0 -161
- utils/data_utils/reversible_affine_transform.py +0 -82
- utils/{data_utils/transform_utils.py → transform_utils.py} +0 -0
- utils/visualize_att_maps.py +1 -1
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 😻
|
4 |
colorFrom: green
|
5 |
colorTo: pink
|
|
|
1 |
---
|
2 |
+
title: PdiscoFormer
|
3 |
emoji: 😻
|
4 |
colorFrom: green
|
5 |
colorTo: pink
|
app.py
CHANGED
@@ -4,9 +4,9 @@ from PIL import Image
|
|
4 |
|
5 |
from models import IndividualLandmarkViT
|
6 |
from utils import VisualizeAttentionMaps
|
7 |
-
from utils.
|
8 |
|
9 |
-
st.title("
|
10 |
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
|
11 |
"ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
|
12 |
"ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
|
|
|
4 |
|
5 |
from models import IndividualLandmarkViT
|
6 |
from utils import VisualizeAttentionMaps
|
7 |
+
from utils.transform_utils import make_test_transforms
|
8 |
|
9 |
+
st.title("PdiscoFormer Part Discovery Visualizer")
|
10 |
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
|
11 |
"ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
|
12 |
"ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
|
utils/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from .data_utils import *
|
2 |
from .visualize_att_maps import *
|
3 |
from .misc_utils import *
|
4 |
from .get_landmark_coordinates import *
|
|
|
5 |
|
6 |
|
|
|
|
|
1 |
from .visualize_att_maps import *
|
2 |
from .misc_utils import *
|
3 |
from .get_landmark_coordinates import *
|
4 |
+
from .transform_utils import *
|
5 |
|
6 |
|
utils/data_utils/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
from .dataset_utils import *
|
2 |
-
from .reversible_affine_transform import *
|
3 |
-
from .transform_utils import *
|
4 |
-
from .class_balanced_distributed_sampler import *
|
5 |
-
from .class_balanced_sampler import *
|
|
|
|
|
|
|
|
|
|
|
|
utils/data_utils/class_balanced_distributed_sampler.py
DELETED
@@ -1,100 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
from typing import Optional
|
4 |
-
import math
|
5 |
-
import torch.distributed as dist
|
6 |
-
|
7 |
-
|
8 |
-
class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
|
9 |
-
"""
|
10 |
-
A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
|
11 |
-
Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
|
12 |
-
"""
|
13 |
-
|
14 |
-
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
|
15 |
-
shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:
|
16 |
-
|
17 |
-
if not shuffle:
|
18 |
-
raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")
|
19 |
-
|
20 |
-
# Check if the dataset has a generate_class_balanced_indices method
|
21 |
-
if not hasattr(dataset, 'generate_class_balanced_indices'):
|
22 |
-
raise ValueError("Dataset does not have a generate_class_balanced_indices method")
|
23 |
-
|
24 |
-
self.shuffle = shuffle
|
25 |
-
self.seed = seed
|
26 |
-
if num_replicas is None:
|
27 |
-
if not dist.is_available():
|
28 |
-
raise RuntimeError("Requires distributed package to be available")
|
29 |
-
num_replicas = dist.get_world_size()
|
30 |
-
if rank is None:
|
31 |
-
if not dist.is_available():
|
32 |
-
raise RuntimeError("Requires distributed package to be available")
|
33 |
-
rank = dist.get_rank()
|
34 |
-
if rank >= num_replicas or rank < 0:
|
35 |
-
raise ValueError(
|
36 |
-
f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
|
37 |
-
self.dataset = dataset
|
38 |
-
self.num_replicas = num_replicas
|
39 |
-
self.rank = rank
|
40 |
-
self.epoch = 0
|
41 |
-
self.drop_last = drop_last
|
42 |
-
|
43 |
-
# Calculate the number of samples
|
44 |
-
g = torch.Generator()
|
45 |
-
g.manual_seed(self.seed + self.epoch)
|
46 |
-
self.num_samples_per_class = num_samples_per_class
|
47 |
-
indices = dataset.generate_class_balanced_indices(torch.Generator(),
|
48 |
-
num_samples_per_class=num_samples_per_class)
|
49 |
-
dataset_size = len(indices)
|
50 |
-
|
51 |
-
# If the dataset length is evenly divisible by # of replicas, then there
|
52 |
-
# is no need to drop any data, since the dataset will be split equally.
|
53 |
-
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
54 |
-
# Split to nearest available length that is evenly divisible.
|
55 |
-
# This is to ensure each rank receives the same amount of data when
|
56 |
-
# using this Sampler.
|
57 |
-
self.num_samples = math.ceil(
|
58 |
-
(dataset_size - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
59 |
-
)
|
60 |
-
else:
|
61 |
-
self.num_samples = math.ceil(dataset_size / self.num_replicas) # type: ignore[arg-type]
|
62 |
-
self.total_size = self.num_samples * self.num_replicas
|
63 |
-
|
64 |
-
def __iter__(self):
|
65 |
-
# deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
|
66 |
-
g = torch.Generator()
|
67 |
-
g.manual_seed(self.seed + self.epoch)
|
68 |
-
indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)
|
69 |
-
|
70 |
-
if not self.drop_last:
|
71 |
-
# add extra samples to make it evenly divisible
|
72 |
-
padding_size = self.total_size - len(indices)
|
73 |
-
if padding_size <= len(indices):
|
74 |
-
indices += indices[:padding_size]
|
75 |
-
else:
|
76 |
-
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
77 |
-
else:
|
78 |
-
# remove tail of data to make it evenly divisible.
|
79 |
-
indices = indices[:self.total_size]
|
80 |
-
|
81 |
-
# subsample
|
82 |
-
indices = indices[self.rank:self.total_size:self.num_replicas]
|
83 |
-
|
84 |
-
return iter(indices)
|
85 |
-
|
86 |
-
def __len__(self) -> int:
|
87 |
-
return self.num_samples
|
88 |
-
|
89 |
-
def set_epoch(self, epoch: int) -> None:
|
90 |
-
r"""
|
91 |
-
Set the epoch for this sampler.
|
92 |
-
|
93 |
-
When :attr:`shuffle=True`, this ensures all replicas
|
94 |
-
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
95 |
-
sampler will yield the same ordering.
|
96 |
-
|
97 |
-
Args:
|
98 |
-
epoch (int): Epoch number.
|
99 |
-
"""
|
100 |
-
self.epoch = epoch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/data_utils/class_balanced_sampler.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
|
4 |
-
|
5 |
-
class ClassBalancedRandomSampler(torch.utils.data.Sampler):
|
6 |
-
"""
|
7 |
-
A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
|
8 |
-
This is essentially the non-ddp version of ClassBalancedDistributedSampler
|
9 |
-
Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
|
10 |
-
"""
|
11 |
-
|
12 |
-
def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
|
13 |
-
self.dataset = dataset
|
14 |
-
self.seed = seed
|
15 |
-
# Calculate the number of samples
|
16 |
-
self.generator = torch.Generator()
|
17 |
-
self.generator.manual_seed(self.seed)
|
18 |
-
self.num_samples_per_class = num_samples_per_class
|
19 |
-
indices = dataset.generate_class_balanced_indices(self.generator,
|
20 |
-
num_samples_per_class=num_samples_per_class)
|
21 |
-
self.num_samples = len(indices)
|
22 |
-
|
23 |
-
def __iter__(self):
|
24 |
-
# Change seed for every function call
|
25 |
-
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
26 |
-
self.generator.manual_seed(seed)
|
27 |
-
indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
|
28 |
-
return iter(indices)
|
29 |
-
|
30 |
-
def __len__(self) -> int:
|
31 |
-
return self.num_samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/data_utils/dataset_utils.py
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
from PIL import Image
|
2 |
-
from torch import Tensor
|
3 |
-
from typing import List, Optional
|
4 |
-
import numpy as np
|
5 |
-
import torchvision
|
6 |
-
import json
|
7 |
-
|
8 |
-
|
9 |
-
def load_json(path: str):
|
10 |
-
"""
|
11 |
-
Load json file from path and return the data
|
12 |
-
:param path: Path to the json file
|
13 |
-
:return:
|
14 |
-
data: Data in the json file
|
15 |
-
"""
|
16 |
-
with open(path, 'r') as f:
|
17 |
-
data = json.load(f)
|
18 |
-
return data
|
19 |
-
|
20 |
-
|
21 |
-
def save_json(data: dict, path: str):
|
22 |
-
"""
|
23 |
-
Save data to a json file
|
24 |
-
:param data: Data to be saved
|
25 |
-
:param path: Path to save the data
|
26 |
-
:return:
|
27 |
-
"""
|
28 |
-
with open(path, "w") as f:
|
29 |
-
json.dump(data, f)
|
30 |
-
|
31 |
-
|
32 |
-
def pil_loader(path):
|
33 |
-
"""
|
34 |
-
Load image from path using PIL
|
35 |
-
:param path: Path to the image
|
36 |
-
:return:
|
37 |
-
img: PIL Image
|
38 |
-
"""
|
39 |
-
with open(path, 'rb') as f:
|
40 |
-
img = Image.open(f)
|
41 |
-
return img.convert('RGB')
|
42 |
-
|
43 |
-
|
44 |
-
def get_dimensions(image: Tensor):
|
45 |
-
"""
|
46 |
-
Get the dimensions of the image
|
47 |
-
:param image: Tensor or PIL Image or np.ndarray
|
48 |
-
:return:
|
49 |
-
h: Height of the image
|
50 |
-
w: Width of the image
|
51 |
-
"""
|
52 |
-
if isinstance(image, Tensor):
|
53 |
-
_, h, w = image.shape
|
54 |
-
elif isinstance(image, np.ndarray):
|
55 |
-
h, w, _ = image.shape
|
56 |
-
elif isinstance(image, Image.Image):
|
57 |
-
w, h = image.size
|
58 |
-
else:
|
59 |
-
raise ValueError(f"Invalid image type: {type(image)}")
|
60 |
-
return h, w
|
61 |
-
|
62 |
-
|
63 |
-
def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
|
64 |
-
boxes: Optional[Tensor] = None, num_keypoints: int = 15):
|
65 |
-
"""
|
66 |
-
Calculate the center crop parameters for the bounding boxes and landmarks and update them
|
67 |
-
:param img: Image
|
68 |
-
:param output_size: Output size of the cropped image
|
69 |
-
:param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
|
70 |
-
:param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
|
71 |
-
:param num_keypoints: Number of keypoints
|
72 |
-
:return:
|
73 |
-
cropped_img: Center cropped image
|
74 |
-
parts: Updated locations of the landmarks
|
75 |
-
boxes: Updated bounding boxes of the landmarks
|
76 |
-
"""
|
77 |
-
if isinstance(output_size, int):
|
78 |
-
output_size = (output_size, output_size)
|
79 |
-
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
80 |
-
output_size = (output_size[0], output_size[0])
|
81 |
-
elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
|
82 |
-
output_size = output_size
|
83 |
-
else:
|
84 |
-
raise ValueError(f"Invalid output size: {output_size}")
|
85 |
-
|
86 |
-
crop_height, crop_width = output_size
|
87 |
-
image_height, image_width = get_dimensions(img)
|
88 |
-
img = torchvision.transforms.functional.center_crop(img, output_size)
|
89 |
-
|
90 |
-
crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
|
91 |
-
|
92 |
-
if parts is not None:
|
93 |
-
for j in range(num_keypoints):
|
94 |
-
# Skip if part is invisible
|
95 |
-
if parts[j][-1] == 0:
|
96 |
-
continue
|
97 |
-
parts[j][1] -= crop_left
|
98 |
-
parts[j][2] -= crop_top
|
99 |
-
|
100 |
-
# Skip if part is outside the crop
|
101 |
-
if parts[j][1] > crop_width or parts[j][2] > crop_height:
|
102 |
-
parts[j][-1] = 0
|
103 |
-
if parts[j][1] < 0 or parts[j][2] < 0:
|
104 |
-
parts[j][-1] = 0
|
105 |
-
|
106 |
-
parts[j][1] = min(crop_width, parts[j][1])
|
107 |
-
parts[j][2] = min(crop_height, parts[j][2])
|
108 |
-
parts[j][1] = max(0, parts[j][1])
|
109 |
-
parts[j][2] = max(0, parts[j][2])
|
110 |
-
|
111 |
-
if boxes is not None:
|
112 |
-
boxes[1] -= crop_left
|
113 |
-
boxes[2] -= crop_top
|
114 |
-
boxes[1] = max(0, boxes[1])
|
115 |
-
boxes[2] = max(0, boxes[2])
|
116 |
-
boxes[1] = min(crop_width, boxes[1])
|
117 |
-
boxes[2] = min(crop_height, boxes[2])
|
118 |
-
|
119 |
-
return img, parts, boxes
|
120 |
-
|
121 |
-
|
122 |
-
def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
|
123 |
-
"""
|
124 |
-
Get the parameters for center cropping the image
|
125 |
-
:param image_height: Height of the image
|
126 |
-
:param image_width: Width of the image
|
127 |
-
:param output_size: Output size of the cropped image
|
128 |
-
:return:
|
129 |
-
crop_top: Top coordinate of the cropped image
|
130 |
-
crop_left: Left coordinate of the cropped image
|
131 |
-
"""
|
132 |
-
if isinstance(output_size, int):
|
133 |
-
output_size = (output_size, output_size)
|
134 |
-
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
135 |
-
output_size = (output_size[0], output_size[0])
|
136 |
-
elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
|
137 |
-
output_size = output_size
|
138 |
-
else:
|
139 |
-
raise ValueError(f"Invalid output size: {output_size}")
|
140 |
-
|
141 |
-
crop_height, crop_width = output_size
|
142 |
-
|
143 |
-
if crop_width > image_width or crop_height > image_height:
|
144 |
-
padding_ltrb = [
|
145 |
-
(crop_width - image_width) // 2 if crop_width > image_width else 0,
|
146 |
-
(crop_height - image_height) // 2 if crop_height > image_height else 0,
|
147 |
-
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
148 |
-
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
149 |
-
]
|
150 |
-
crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
|
151 |
-
return crop_top, crop_left
|
152 |
-
|
153 |
-
if crop_width == image_width and crop_height == image_height:
|
154 |
-
crop_top = 0
|
155 |
-
crop_left = 0
|
156 |
-
return crop_top, crop_left
|
157 |
-
|
158 |
-
crop_top = int(round((image_height - crop_height) / 2.0))
|
159 |
-
crop_left = int(round((image_width - crop_width) / 2.0))
|
160 |
-
|
161 |
-
return crop_top, crop_left
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/data_utils/reversible_affine_transform.py
DELETED
@@ -1,82 +0,0 @@
|
|
1 |
-
# Description: This file contains the code for the reversible affine transform
|
2 |
-
import torchvision.transforms as transforms
|
3 |
-
import torch
|
4 |
-
from typing import List, Optional, Tuple, Any
|
5 |
-
|
6 |
-
|
7 |
-
def generate_affine_trans_params(
|
8 |
-
degrees: List[float],
|
9 |
-
translate: Optional[List[float]],
|
10 |
-
scale_ranges: Optional[List[float]],
|
11 |
-
shears: Optional[List[float]],
|
12 |
-
img_size: List[int],
|
13 |
-
) -> Tuple[float, Tuple[int, int], float, Any]:
|
14 |
-
"""Get parameters for affine transformation
|
15 |
-
|
16 |
-
Returns:
|
17 |
-
params to be passed to the affine transformation
|
18 |
-
"""
|
19 |
-
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
|
20 |
-
if translate is not None:
|
21 |
-
max_dx = float(translate[0] * img_size[0])
|
22 |
-
max_dy = float(translate[1] * img_size[1])
|
23 |
-
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
|
24 |
-
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
|
25 |
-
translations = (tx, ty)
|
26 |
-
else:
|
27 |
-
translations = (0, 0)
|
28 |
-
|
29 |
-
if scale_ranges is not None:
|
30 |
-
scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
|
31 |
-
else:
|
32 |
-
scale = 1.0
|
33 |
-
|
34 |
-
shear_x = shear_y = 0.0
|
35 |
-
if shears is not None:
|
36 |
-
shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
|
37 |
-
if len(shears) == 4:
|
38 |
-
shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
|
39 |
-
|
40 |
-
shear = (shear_x, shear_y)
|
41 |
-
if shear_x == 0.0 and shear_y == 0.0:
|
42 |
-
shear = 0.0
|
43 |
-
|
44 |
-
return angle, translations, scale, shear
|
45 |
-
|
46 |
-
|
47 |
-
def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
|
48 |
-
interpolation=transforms.InterpolationMode.BILINEAR):
|
49 |
-
"""
|
50 |
-
Affine transforms input image
|
51 |
-
Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
|
52 |
-
Parameters
|
53 |
-
----------
|
54 |
-
img: Tensor
|
55 |
-
Input image
|
56 |
-
angle: int
|
57 |
-
Rotation angle between -180 and 180 degrees
|
58 |
-
translate: [int]
|
59 |
-
Sequence of horizontal/vertical translations
|
60 |
-
scale: float
|
61 |
-
How to scale the image
|
62 |
-
invert: bool
|
63 |
-
Whether to invert the transformation
|
64 |
-
shear: float
|
65 |
-
Shear angle in degrees
|
66 |
-
interpolation: InterpolationMode
|
67 |
-
Interpolation mode to calculate output values
|
68 |
-
Returns
|
69 |
-
----------
|
70 |
-
img: Tensor
|
71 |
-
Transformed image
|
72 |
-
|
73 |
-
"""
|
74 |
-
if not invert:
|
75 |
-
img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
|
76 |
-
interpolation=interpolation)
|
77 |
-
else:
|
78 |
-
translate = [-t for t in translate]
|
79 |
-
img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
|
80 |
-
img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)
|
81 |
-
|
82 |
-
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/{data_utils/transform_utils.py → transform_utils.py}
RENAMED
File without changes
|
utils/visualize_att_maps.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
|
|
3 |
import skimage
|
4 |
import torch
|
5 |
|
6 |
-
from utils.
|
7 |
|
8 |
# Define the colors to use for the attention maps
|
9 |
colors = cc.glasbey_category10
|
|
|
3 |
import skimage
|
4 |
import torch
|
5 |
|
6 |
+
from utils.transform_utils import inverse_normalize_w_resize
|
7 |
|
8 |
# Define the colors to use for the attention maps
|
9 |
colors = cc.glasbey_category10
|