Commit · 5662f96
Parent(s): 91efb25

remove all unused files

Files changed:
- README.md +1 -1
- app.py +2 -2
- utils/__init__.py +1 -1
- utils/data_utils/__init__.py +0 -5
- utils/data_utils/class_balanced_distributed_sampler.py +0 -100
- utils/data_utils/class_balanced_sampler.py +0 -31
- utils/data_utils/dataset_utils.py +0 -161
- utils/data_utils/reversible_affine_transform.py +0 -82
- utils/{data_utils/transform_utils.py → transform_utils.py} +0 -0
- utils/visualize_att_maps.py +1 -1
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: PdiscoFormer
 emoji: 😻
 colorFrom: green
 colorTo: pink
app.py CHANGED
@@ -4,9 +4,9 @@ from PIL import Image
 
 from models import IndividualLandmarkViT
 from utils import VisualizeAttentionMaps
-from utils.
+from utils.transform_utils import make_test_transforms
 
-st.title("
+st.title("PdiscoFormer Part Discovery Visualizer")
 model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
                  "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
                  "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
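For reference (not part of this commit): a minimal sketch of how the imports wired up above typically fit together. The loading call, the transform argument, and the input file name are assumptions for illustration, not code taken from app.py.

import torch
from PIL import Image

from models import IndividualLandmarkViT               # repo-local model class
from utils.transform_utils import make_test_transforms  # moved out of data_utils in this commit

# Hypothetical usage; the exact signatures of these repo-local helpers may differ.
model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval()  # assumed Hub-style loader
transform = make_test_transforms(518)                   # assumed: takes the evaluation image size
image = Image.open("example_bird.jpg").convert("RGB")   # placeholder input image

with torch.no_grad():
    outputs = model(transform(image).unsqueeze(0))      # add a batch dimension before the forward pass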
utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
-from .data_utils import *
 from .visualize_att_maps import *
 from .misc_utils import *
 from .get_landmark_coordinates import *
+from .transform_utils import *
 
utils/data_utils/__init__.py DELETED
@@ -1,5 +0,0 @@
-from .dataset_utils import *
-from .reversible_affine_transform import *
-from .transform_utils import *
-from .class_balanced_distributed_sampler import *
-from .class_balanced_sampler import *
utils/data_utils/class_balanced_distributed_sampler.py DELETED
@@ -1,100 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-from typing import Optional
-import math
-import torch.distributed as dist
-
-
-class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
-    """
-    A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
-    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
-    """
-
-    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
-                 shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:
-
-        if not shuffle:
-            raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")
-
-        # Check if the dataset has a generate_class_balanced_indices method
-        if not hasattr(dataset, 'generate_class_balanced_indices'):
-            raise ValueError("Dataset does not have a generate_class_balanced_indices method")
-
-        self.shuffle = shuffle
-        self.seed = seed
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        if rank >= num_replicas or rank < 0:
-            raise ValueError(
-                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-
-        # Calculate the number of samples
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-        self.num_samples_per_class = num_samples_per_class
-        indices = dataset.generate_class_balanced_indices(torch.Generator(),
-                                                          num_samples_per_class=num_samples_per_class)
-        dataset_size = len(indices)
-
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil(
-                (dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
-            )
-        else:
-            self.num_samples = math.ceil(dataset_size / self.num_replicas)  # type: ignore[arg-type]
-        self.total_size = self.num_samples * self.num_replicas
-
-    def __iter__(self):
-        # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-        indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)
-
-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            padding_size = self.total_size - len(indices)
-            if padding_size <= len(indices):
-                indices += indices[:padding_size]
-            else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
-        else:
-            # remove tail of data to make it evenly divisible.
-            indices = indices[:self.total_size]
-
-        # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-
-        return iter(indices)
-
-    def __len__(self) -> int:
-        return self.num_samples
-
-    def set_epoch(self, epoch: int) -> None:
-        r"""
-        Set the epoch for this sampler.
-
-        When :attr:`shuffle=True`, this ensures all replicas
-        use a different random ordering for each epoch. Otherwise, the next iteration of this
-        sampler will yield the same ordering.
-
-        Args:
-            epoch (int): Epoch number.
-        """
-        self.epoch = epoch
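For context (not part of this commit): the deleted sampler plugs into a standard DataLoader and relies on the dataset exposing generate_class_balanced_indices. Below is a minimal sketch with a toy dataset written for illustration only; it assumes torch.distributed has already been initialized (e.g. via torchrun).

import torch
from torch.utils.data import Dataset, DataLoader


class ToyBalancedDataset(Dataset):
    """Toy dataset used only for illustration: 10 classes, 50 samples each."""

    def __init__(self):
        self.labels = torch.arange(10).repeat_interleave(50)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.randn(3, 224, 224), int(self.labels[idx])

    def generate_class_balanced_indices(self, generator, num_samples_per_class=100):
        # Return a shuffled list of indices with at most num_samples_per_class per class;
        # this is the contract the sampler checks for via hasattr().
        indices = []
        for c in self.labels.unique():
            class_idx = (self.labels == c).nonzero(as_tuple=True)[0]
            perm = class_idx[torch.randperm(len(class_idx), generator=generator)]
            indices.extend(perm[:num_samples_per_class].tolist())
        shuffle = torch.randperm(len(indices), generator=generator)
        return [indices[i] for i in shuffle]


# Assumes the distributed process group is already initialized so the sampler
# can discover the world size and rank on its own.
dataset = ToyBalancedDataset()
sampler = ClassBalancedDistributedSampler(dataset, shuffle=True, num_samples_per_class=20)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle the class-balanced subset each epoch
    for images, targets in loader:
        pass  # training step would go here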
utils/data_utils/class_balanced_sampler.py DELETED
@@ -1,31 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-
-
-class ClassBalancedRandomSampler(torch.utils.data.Sampler):
-    """
-    A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
-    This is essentially the non-ddp version of ClassBalancedDistributedSampler
-    Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
-    """
-
-    def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
-        self.dataset = dataset
-        self.seed = seed
-        # Calculate the number of samples
-        self.generator = torch.Generator()
-        self.generator.manual_seed(self.seed)
-        self.num_samples_per_class = num_samples_per_class
-        indices = dataset.generate_class_balanced_indices(self.generator,
-                                                          num_samples_per_class=num_samples_per_class)
-        self.num_samples = len(indices)
-
-    def __iter__(self):
-        # Change seed for every function call
-        seed = int(torch.empty((), dtype=torch.int64).random_().item())
-        self.generator.manual_seed(seed)
-        indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
-        return iter(indices)
-
-    def __len__(self) -> int:
-        return self.num_samples
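This single-process counterpart follows the same dataset contract; a short illustrative sketch (reusing a dataset like the ToyBalancedDataset sketched earlier, with no distributed setup required):

from torch.utils.data import DataLoader

# Assumes a dataset implementing generate_class_balanced_indices,
# e.g. the ToyBalancedDataset sketched in the previous example.
dataset = ToyBalancedDataset()
sampler = ClassBalancedRandomSampler(dataset, num_samples_per_class=20, seed=0)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

print(len(sampler))                            # number of indices drawn per epoch
batch_images, batch_labels = next(iter(loader))  # one class-balanced mini-batch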
utils/data_utils/dataset_utils.py DELETED
@@ -1,161 +0,0 @@
-from PIL import Image
-from torch import Tensor
-from typing import List, Optional
-import numpy as np
-import torchvision
-import json
-
-
-def load_json(path: str):
-    """
-    Load json file from path and return the data
-    :param path: Path to the json file
-    :return:
-        data: Data in the json file
-    """
-    with open(path, 'r') as f:
-        data = json.load(f)
-    return data
-
-
-def save_json(data: dict, path: str):
-    """
-    Save data to a json file
-    :param data: Data to be saved
-    :param path: Path to save the data
-    :return:
-    """
-    with open(path, "w") as f:
-        json.dump(data, f)
-
-
-def pil_loader(path):
-    """
-    Load image from path using PIL
-    :param path: Path to the image
-    :return:
-        img: PIL Image
-    """
-    with open(path, 'rb') as f:
-        img = Image.open(f)
-        return img.convert('RGB')
-
-
-def get_dimensions(image: Tensor):
-    """
-    Get the dimensions of the image
-    :param image: Tensor or PIL Image or np.ndarray
-    :return:
-        h: Height of the image
-        w: Width of the image
-    """
-    if isinstance(image, Tensor):
-        _, h, w = image.shape
-    elif isinstance(image, np.ndarray):
-        h, w, _ = image.shape
-    elif isinstance(image, Image.Image):
-        w, h = image.size
-    else:
-        raise ValueError(f"Invalid image type: {type(image)}")
-    return h, w
-
-
-def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
-                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
-    """
-    Calculate the center crop parameters for the bounding boxes and landmarks and update them
-    :param img: Image
-    :param output_size: Output size of the cropped image
-    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
-    :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
-    :param num_keypoints: Number of keypoints
-    :return:
-        cropped_img: Center cropped image
-        parts: Updated locations of the landmarks
-        boxes: Updated bounding boxes of the landmarks
-    """
-    if isinstance(output_size, int):
-        output_size = (output_size, output_size)
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-        output_size = (output_size[0], output_size[0])
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
-        output_size = output_size
-    else:
-        raise ValueError(f"Invalid output size: {output_size}")
-
-    crop_height, crop_width = output_size
-    image_height, image_width = get_dimensions(img)
-    img = torchvision.transforms.functional.center_crop(img, output_size)
-
-    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
-
-    if parts is not None:
-        for j in range(num_keypoints):
-            # Skip if part is invisible
-            if parts[j][-1] == 0:
-                continue
-            parts[j][1] -= crop_left
-            parts[j][2] -= crop_top
-
-            # Skip if part is outside the crop
-            if parts[j][1] > crop_width or parts[j][2] > crop_height:
-                parts[j][-1] = 0
-            if parts[j][1] < 0 or parts[j][2] < 0:
-                parts[j][-1] = 0
-
-            parts[j][1] = min(crop_width, parts[j][1])
-            parts[j][2] = min(crop_height, parts[j][2])
-            parts[j][1] = max(0, parts[j][1])
-            parts[j][2] = max(0, parts[j][2])
-
-    if boxes is not None:
-        boxes[1] -= crop_left
-        boxes[2] -= crop_top
-        boxes[1] = max(0, boxes[1])
-        boxes[2] = max(0, boxes[2])
-        boxes[1] = min(crop_width, boxes[1])
-        boxes[2] = min(crop_height, boxes[2])
-
-    return img, parts, boxes
-
-
-def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
-    """
-    Get the parameters for center cropping the image
-    :param image_height: Height of the image
-    :param image_width: Width of the image
-    :param output_size: Output size of the cropped image
-    :return:
-        crop_top: Top coordinate of the cropped image
-        crop_left: Left coordinate of the cropped image
-    """
-    if isinstance(output_size, int):
-        output_size = (output_size, output_size)
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-        output_size = (output_size[0], output_size[0])
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
-        output_size = output_size
-    else:
-        raise ValueError(f"Invalid output size: {output_size}")
-
-    crop_height, crop_width = output_size
-
-    if crop_width > image_width or crop_height > image_height:
-        padding_ltrb = [
-            (crop_width - image_width) // 2 if crop_width > image_width else 0,
-            (crop_height - image_height) // 2 if crop_height > image_height else 0,
-            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
-            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
-        ]
-        crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
-        return crop_top, crop_left
-
-    if crop_width == image_width and crop_height == image_height:
-        crop_top = 0
-        crop_left = 0
-        return crop_top, crop_left
-
-    crop_top = int(round((image_height - crop_height) / 2.0))
-    crop_left = int(round((image_width - crop_width) / 2.0))
-
-    return crop_top, crop_left
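For context (not part of this commit): a short usage sketch for the cropping helper deleted above. The tensor shapes and coordinate layouts below are illustrative values following the formats described in the docstrings.

import torch

# Illustrative inputs: a 3x500x600 image, 15 keypoints in <part_id> <x> <y> <visible>
# format, and one box in <image_id> <x> <y> <width> <height> format.
img = torch.rand(3, 500, 600)
parts = torch.zeros(15, 4)
parts[:, 0] = torch.arange(15)                       # part ids
parts[:, 1] = torch.randint(0, 600, (15,)).float()   # x coordinates
parts[:, 2] = torch.randint(0, 500, (15,)).float()   # y coordinates
parts[:, 3] = 1                                      # mark all keypoints visible
boxes = torch.tensor([0.0, 50.0, 40.0, 300.0, 250.0])

cropped, parts, boxes = center_crop_boxes_kps(img, output_size=448, parts=parts, boxes=boxes)
print(cropped.shape)  # torch.Size([3, 448, 448]); keypoints and box are shifted into the crop frame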
utils/data_utils/reversible_affine_transform.py DELETED
@@ -1,82 +0,0 @@
-# Description: This file contains the code for the reversible affine transform
-import torchvision.transforms as transforms
-import torch
-from typing import List, Optional, Tuple, Any
-
-
-def generate_affine_trans_params(
-        degrees: List[float],
-        translate: Optional[List[float]],
-        scale_ranges: Optional[List[float]],
-        shears: Optional[List[float]],
-        img_size: List[int],
-) -> Tuple[float, Tuple[int, int], float, Any]:
-    """Get parameters for affine transformation
-
-    Returns:
-        params to be passed to the affine transformation
-    """
-    angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
-    if translate is not None:
-        max_dx = float(translate[0] * img_size[0])
-        max_dy = float(translate[1] * img_size[1])
-        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
-        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
-        translations = (tx, ty)
-    else:
-        translations = (0, 0)
-
-    if scale_ranges is not None:
-        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
-    else:
-        scale = 1.0
-
-    shear_x = shear_y = 0.0
-    if shears is not None:
-        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
-        if len(shears) == 4:
-            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
-
-    shear = (shear_x, shear_y)
-    if shear_x == 0.0 and shear_y == 0.0:
-        shear = 0.0
-
-    return angle, translations, scale, shear
-
-
-def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
-                    interpolation=transforms.InterpolationMode.BILINEAR):
-    """
-    Affine transforms input image
-    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
-    Parameters
-    ----------
-    img: Tensor
-        Input image
-    angle: int
-        Rotation angle between -180 and 180 degrees
-    translate: [int]
-        Sequence of horizontal/vertical translations
-    scale: float
-        How to scale the image
-    invert: bool
-        Whether to invert the transformation
-    shear: float
-        Shear angle in degrees
-    interpolation: InterpolationMode
-        Interpolation mode to calculate output values
-    Returns
-    ----------
-    img: Tensor
-        Transformed image
-
-    """
-    if not invert:
-        img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
-                                           interpolation=interpolation)
-    else:
-        translate = [-t for t in translate]
-        img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
-        img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)
-
-    return img
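For context (not part of this commit): the intended forward/inverse round trip for the deleted transform, shown with an illustrative tensor and parameter ranges chosen only for the example.

import torch

img = torch.rand(3, 224, 224)  # illustrative image tensor

# Sample random affine parameters: rotation in [-30, 30] degrees, up to 10% translation,
# scale in [0.8, 1.2], no shear.
angle, translations, scale, shear = generate_affine_trans_params(
    degrees=[-30.0, 30.0],
    translate=[0.1, 0.1],
    scale_ranges=[0.8, 1.2],
    shears=None,
    img_size=[224, 224],
)

augmented = rigid_transform(img, angle, translations, scale, invert=False, shear=shear)
# Applying the same parameters with invert=True maps the image (approximately) back to
# the original frame, up to interpolation and border effects, which is what makes the
# transform "reversible".
restored = rigid_transform(augmented, angle, translations, scale, invert=True, shear=shear)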
utils/{data_utils/transform_utils.py → transform_utils.py} RENAMED
File without changes
utils/visualize_att_maps.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
 import skimage
 import torch
 
-from utils.
+from utils.transform_utils import inverse_normalize_w_resize
 
 # Define the colors to use for the attention maps
 colors = cc.glasbey_category10