ananthu-aniraj committed
Commit 5662f96 · 1 Parent(s): 91efb25

remove all unused files

README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Pdiscoformer
+title: PdiscoFormer
 emoji: 😻
 colorFrom: green
 colorTo: pink
app.py CHANGED
@@ -4,9 +4,9 @@ from PIL import Image
 
 from models import IndividualLandmarkViT
 from utils import VisualizeAttentionMaps
-from utils.data_utils.transform_utils import make_test_transforms
+from utils.transform_utils import make_test_transforms
 
-st.title("Pdiscoformer Part Discovery Visualizer")
+st.title("PdiscoFormer Part Discovery Visualizer")
 model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
                  "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
                  "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
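
To illustrate how the relocated helper is consumed, here is a minimal sketch of the app's inference path. It assumes IndividualLandmarkViT can be loaded from the Hub via a from_pretrained() classmethod and that make_test_transforms takes a target image size; the 518-pixel resolution, the sample image path, and the output handling are illustrative assumptions, not part of this commit.

import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils.transform_utils import make_test_transforms  # new flat import path

# Hypothetical sketch: load one of the Hub checkpoints listed above and run a single image.
model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval()
transform = make_test_transforms(518)  # assumed signature: target input resolution
image = Image.open("example_bird.jpg").convert("RGB")  # placeholder image path
with torch.no_grad():
    outputs = model(transform(image).unsqueeze(0))  # exact output structure depends on the model
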
utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
-from .data_utils import *
 from .visualize_att_maps import *
 from .misc_utils import *
 from .get_landmark_coordinates import *
+from .transform_utils import *
 
 
utils/data_utils/__init__.py DELETED
@@ -1,5 +0,0 @@
-from .dataset_utils import *
-from .reversible_affine_transform import *
-from .transform_utils import *
-from .class_balanced_distributed_sampler import *
-from .class_balanced_sampler import *
utils/data_utils/class_balanced_distributed_sampler.py DELETED
@@ -1,100 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-from typing import Optional
-import math
-import torch.distributed as dist
-
-
-class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
-    """
-    A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
-    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
-    """
-
-    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
-                 shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:
-
-        if not shuffle:
-            raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")
-
-        # Check if the dataset has a generate_class_balanced_indices method
-        if not hasattr(dataset, 'generate_class_balanced_indices'):
-            raise ValueError("Dataset does not have a generate_class_balanced_indices method")
-
-        self.shuffle = shuffle
-        self.seed = seed
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        if rank >= num_replicas or rank < 0:
-            raise ValueError(
-                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-
-        # Calculate the number of samples
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-        self.num_samples_per_class = num_samples_per_class
-        indices = dataset.generate_class_balanced_indices(torch.Generator(),
-                                                          num_samples_per_class=num_samples_per_class)
-        dataset_size = len(indices)
-
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil(
-                (dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
-            )
-        else:
-            self.num_samples = math.ceil(dataset_size / self.num_replicas)  # type: ignore[arg-type]
-        self.total_size = self.num_samples * self.num_replicas
-
-    def __iter__(self):
-        # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-        indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)
-
-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            padding_size = self.total_size - len(indices)
-            if padding_size <= len(indices):
-                indices += indices[:padding_size]
-            else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
-        else:
-            # remove tail of data to make it evenly divisible.
-            indices = indices[:self.total_size]
-
-        # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-
-        return iter(indices)
-
-    def __len__(self) -> int:
-        return self.num_samples
-
-    def set_epoch(self, epoch: int) -> None:
-        r"""
-        Set the epoch for this sampler.
-
-        When :attr:`shuffle=True`, this ensures all replicas
-        use a different random ordering for each epoch. Otherwise, the next iteration of this
-        sampler will yield the same ordering.
-
-        Args:
-            epoch (int): Epoch number.
-        """
-        self.epoch = epoch
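
For context, this deleted sampler was designed as a drop-in replacement for DistributedSampler in a DDP training loop. The sketch below is illustrative only: the toy dataset, the explicit num_replicas/rank arguments (used here to avoid initializing torch.distributed), and all sizes are assumptions, and it presumes the removed class is still importable from an older checkout.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyBalancedDataset(Dataset):
    """Hypothetical stand-in exposing the hook the sampler expects."""

    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.zeros(3, 8, 8), self.labels[idx]

    def generate_class_balanced_indices(self, generator, num_samples_per_class=100):
        # Draw at most num_samples_per_class shuffled indices per class.
        indices = []
        for cls in sorted(set(self.labels)):
            cls_idx = [i for i, y in enumerate(self.labels) if y == cls]
            perm = torch.randperm(len(cls_idx), generator=generator)[:num_samples_per_class]
            indices.extend(cls_idx[i] for i in perm.tolist())
        return indices


dataset = ToyBalancedDataset(labels=[0] * 50 + [1] * 10)
# num_replicas/rank passed explicitly so the sketch runs without dist.init_process_group()
sampler = ClassBalancedDistributedSampler(dataset, num_replicas=2, rank=0,
                                          shuffle=True, num_samples_per_class=5)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # re-draws the class-balanced subset each epoch
    for images, labels in loader:
        pass
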
utils/data_utils/class_balanced_sampler.py DELETED
@@ -1,31 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-
-
-class ClassBalancedRandomSampler(torch.utils.data.Sampler):
-    """
-    A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
-    This is essentially the non-ddp version of ClassBalancedDistributedSampler
-    Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
-    """
-
-    def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
-        self.dataset = dataset
-        self.seed = seed
-        # Calculate the number of samples
-        self.generator = torch.Generator()
-        self.generator.manual_seed(self.seed)
-        self.num_samples_per_class = num_samples_per_class
-        indices = dataset.generate_class_balanced_indices(self.generator,
-                                                          num_samples_per_class=num_samples_per_class)
-        self.num_samples = len(indices)
-
-    def __iter__(self):
-        # Change seed for every function call
-        seed = int(torch.empty((), dtype=torch.int64).random_().item())
-        self.generator.manual_seed(seed)
-        indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
-        return iter(indices)
-
-    def __len__(self) -> int:
-        return self.num_samples
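
The non-distributed variant was the equivalent for single-GPU runs. Reusing the hypothetical ToyBalancedDataset sketched above, usage would look roughly like this:

from torch.utils.data import DataLoader

sampler = ClassBalancedRandomSampler(dataset, num_samples_per_class=5, seed=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for images, labels in loader:
    pass
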
utils/data_utils/dataset_utils.py DELETED
@@ -1,161 +0,0 @@
-from PIL import Image
-from torch import Tensor
-from typing import List, Optional
-import numpy as np
-import torchvision
-import json
-
-
-def load_json(path: str):
-    """
-    Load json file from path and return the data
-    :param path: Path to the json file
-    :return:
-    data: Data in the json file
-    """
-    with open(path, 'r') as f:
-        data = json.load(f)
-    return data
-
-
-def save_json(data: dict, path: str):
-    """
-    Save data to a json file
-    :param data: Data to be saved
-    :param path: Path to save the data
-    :return:
-    """
-    with open(path, "w") as f:
-        json.dump(data, f)
-
-
-def pil_loader(path):
-    """
-    Load image from path using PIL
-    :param path: Path to the image
-    :return:
-    img: PIL Image
-    """
-    with open(path, 'rb') as f:
-        img = Image.open(f)
-        return img.convert('RGB')
-
-
-def get_dimensions(image: Tensor):
-    """
-    Get the dimensions of the image
-    :param image: Tensor or PIL Image or np.ndarray
-    :return:
-    h: Height of the image
-    w: Width of the image
-    """
-    if isinstance(image, Tensor):
-        _, h, w = image.shape
-    elif isinstance(image, np.ndarray):
-        h, w, _ = image.shape
-    elif isinstance(image, Image.Image):
-        w, h = image.size
-    else:
-        raise ValueError(f"Invalid image type: {type(image)}")
-    return h, w
-
-
-def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
-                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
-    """
-    Calculate the center crop parameters for the bounding boxes and landmarks and update them
-    :param img: Image
-    :param output_size: Output size of the cropped image
-    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
-    :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
-    :param num_keypoints: Number of keypoints
-    :return:
-    cropped_img: Center cropped image
-    parts: Updated locations of the landmarks
-    boxes: Updated bounding boxes of the landmarks
-    """
-    if isinstance(output_size, int):
-        output_size = (output_size, output_size)
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-        output_size = (output_size[0], output_size[0])
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
-        output_size = output_size
-    else:
-        raise ValueError(f"Invalid output size: {output_size}")
-
-    crop_height, crop_width = output_size
-    image_height, image_width = get_dimensions(img)
-    img = torchvision.transforms.functional.center_crop(img, output_size)
-
-    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
-
-    if parts is not None:
-        for j in range(num_keypoints):
-            # Skip if part is invisible
-            if parts[j][-1] == 0:
-                continue
-            parts[j][1] -= crop_left
-            parts[j][2] -= crop_top
-
-            # Skip if part is outside the crop
-            if parts[j][1] > crop_width or parts[j][2] > crop_height:
-                parts[j][-1] = 0
-            if parts[j][1] < 0 or parts[j][2] < 0:
-                parts[j][-1] = 0
-
-            parts[j][1] = min(crop_width, parts[j][1])
-            parts[j][2] = min(crop_height, parts[j][2])
-            parts[j][1] = max(0, parts[j][1])
-            parts[j][2] = max(0, parts[j][2])
-
-    if boxes is not None:
-        boxes[1] -= crop_left
-        boxes[2] -= crop_top
-        boxes[1] = max(0, boxes[1])
-        boxes[2] = max(0, boxes[2])
-        boxes[1] = min(crop_width, boxes[1])
-        boxes[2] = min(crop_height, boxes[2])
-
-    return img, parts, boxes
-
-
-def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
-    """
-    Get the parameters for center cropping the image
-    :param image_height: Height of the image
-    :param image_width: Width of the image
-    :param output_size: Output size of the cropped image
-    :return:
-    crop_top: Top coordinate of the cropped image
-    crop_left: Left coordinate of the cropped image
-    """
-    if isinstance(output_size, int):
-        output_size = (output_size, output_size)
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-        output_size = (output_size[0], output_size[0])
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
-        output_size = output_size
-    else:
-        raise ValueError(f"Invalid output size: {output_size}")
-
-    crop_height, crop_width = output_size
-
-    if crop_width > image_width or crop_height > image_height:
-        padding_ltrb = [
-            (crop_width - image_width) // 2 if crop_width > image_width else 0,
-            (crop_height - image_height) // 2 if crop_height > image_height else 0,
-            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
-            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
-        ]
-        crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
-        return crop_top, crop_left
-
-    if crop_width == image_width and crop_height == image_height:
-        crop_top = 0
-        crop_left = 0
-        return crop_top, crop_left
-
-    crop_top = int(round((image_height - crop_height) / 2.0))
-    crop_left = int(round((image_width - crop_width) / 2.0))
-
-    return crop_top, crop_left
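
To show what the removed cropping helper did, here is a small illustrative call; the tensor shapes and keypoint values are made up, and the example assumes the deleted module is still importable (e.g. from a pre-commit checkout of the repo).

import torch

# img: C x H x W image; parts rows: <part_id, x, y, visible>; boxes: <image_id, x, y, w, h>
img = torch.rand(3, 500, 375)
parts = torch.tensor([[0.0, 120.0, 200.0, 1.0],
                      [1.0, 10.0, 15.0, 1.0]])
box = torch.tensor([0.0, 30.0, 40.0, 300.0, 400.0])

# Center-crop to 448x448 and shift the keypoints/box into the cropped coordinate frame.
cropped, parts_cc, box_cc = center_crop_boxes_kps(img, output_size=448,
                                                  parts=parts, boxes=box, num_keypoints=2)
print(cropped.shape)  # torch.Size([3, 448, 448])
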
utils/data_utils/reversible_affine_transform.py DELETED
@@ -1,82 +0,0 @@
-# Description: This file contains the code for the reversible affine transform
-import torchvision.transforms as transforms
-import torch
-from typing import List, Optional, Tuple, Any
-
-
-def generate_affine_trans_params(
-        degrees: List[float],
-        translate: Optional[List[float]],
-        scale_ranges: Optional[List[float]],
-        shears: Optional[List[float]],
-        img_size: List[int],
-) -> Tuple[float, Tuple[int, int], float, Any]:
-    """Get parameters for affine transformation
-
-    Returns:
-        params to be passed to the affine transformation
-    """
-    angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
-    if translate is not None:
-        max_dx = float(translate[0] * img_size[0])
-        max_dy = float(translate[1] * img_size[1])
-        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
-        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
-        translations = (tx, ty)
-    else:
-        translations = (0, 0)
-
-    if scale_ranges is not None:
-        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
-    else:
-        scale = 1.0
-
-    shear_x = shear_y = 0.0
-    if shears is not None:
-        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
-        if len(shears) == 4:
-            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
-
-    shear = (shear_x, shear_y)
-    if shear_x == 0.0 and shear_y == 0.0:
-        shear = 0.0
-
-    return angle, translations, scale, shear
-
-
-def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
-                    interpolation=transforms.InterpolationMode.BILINEAR):
-    """
-    Affine transforms input image
-    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
-    Parameters
-    ----------
-    img: Tensor
-        Input image
-    angle: int
-        Rotation angle between -180 and 180 degrees
-    translate: [int]
-        Sequence of horizontal/vertical translations
-    scale: float
-        How to scale the image
-    invert: bool
-        Whether to invert the transformation
-    shear: float
-        Shear angle in degrees
-    interpolation: InterpolationMode
-        Interpolation mode to calculate output values
-    Returns
-    ----------
-    img: Tensor
-        Transformed image
-
-    """
-    if not invert:
-        img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
-                                           interpolation=interpolation)
-    else:
-        translate = [-t for t in translate]
-        img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
-        img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)
-
-    return img
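
As a quick illustration of the removed round-trip augmentation, the sketch below samples parameters, applies the forward transform, and then inverts it; the image tensor and parameter ranges are arbitrary assumptions, and it presumes the deleted functions are still importable from an older checkout.

import torch

img = torch.rand(3, 224, 224)
angle, translations, scale, shear = generate_affine_trans_params(
    degrees=[-30.0, 30.0], translate=[0.1, 0.1], scale_ranges=[0.8, 1.2],
    shears=None, img_size=[224, 224])

augmented = rigid_transform(img, angle, list(translations), scale, invert=False, shear=shear)
# Inverting undoes the translation first, then the rotation and scale,
# so predicted part maps can be mapped back to the original frame.
restored = rigid_transform(augmented, angle, list(translations), scale, invert=True, shear=shear)
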
utils/{data_utils/transform_utils.py → transform_utils.py} RENAMED
File without changes
utils/visualize_att_maps.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
 import skimage
 import torch
 
-from utils.data_utils.transform_utils import inverse_normalize_w_resize
+from utils.transform_utils import inverse_normalize_w_resize
 
 # Define the colors to use for the attention maps
 colors = cc.glasbey_category10
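
For completeness, a sketch of how the re-imported helper might be used when preparing images for the attention-map overlay. The signature of inverse_normalize_w_resize is not shown in this diff; the example assumes it behaves like a torchvision-style transform that undoes the normalization applied by make_test_transforms and resizes the batch for display, which is an assumption rather than the confirmed API.

import torch
from utils.transform_utils import inverse_normalize_w_resize  # updated flat path

batch = torch.rand(1, 3, 518, 518)                 # hypothetical normalized model input
display_batch = inverse_normalize_w_resize(batch)  # assumed callable-transform behaviour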