Commit 20239f9
Parent(s): b507f8e
add initial files
- .gitignore +40 -0
- files/images/Laysan_Albatross_0050_870.jpg +0 -0
- layers/__init__.py +2 -0
- layers/independent_mlp.py +69 -0
- layers/transformer_layers.py +54 -0
- load_model.py +226 -0
- models/__init__.py +4 -0
- models/individual_landmark_convnext.py +110 -0
- models/individual_landmark_resnet.py +141 -0
- models/individual_landmark_vit.py +366 -0
- models/vit_baseline.py +239 -0
- requirements.txt +5 -1
- utils/__init__.py +6 -0
- utils/data_utils/__init__.py +5 -0
- utils/data_utils/class_balanced_distributed_sampler.py +100 -0
- utils/data_utils/class_balanced_sampler.py +31 -0
- utils/data_utils/dataset_utils.py +161 -0
- utils/data_utils/reversible_affine_transform.py +82 -0
- utils/data_utils/transform_utils.py +118 -0
- utils/get_landmark_coordinates.py +41 -0
- utils/misc_utils.py +135 -0
- utils/visualize_att_maps.py +135 -0
.gitignore
ADDED
@@ -0,0 +1,40 @@
# editor settings
.idea
.vscode
_darcs

# compilation and distribution
__pycache__
_ext
*.pyc
*.pyd
*.so
*.dll
*.egg-info/
build/
dist/
wheels/

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
*.ts
*.pt

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Results temporary
*.png
*.txt
*.tsv
wandb/
exps/
files/images/Laysan_Albatross_0050_870.jpg
ADDED
layers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transformer_layers import *
from .independent_mlp import *
layers/independent_mlp.py
ADDED
@@ -0,0 +1,69 @@
# This file contains the implementation of the IndependentMLPs class
import torch


class IndependentMLPs(torch.nn.Module):
    """
    This class implements the MLP used for classification with the option to use an additional independent MLP layer
    """

    def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1):
        """

        :param part_dim: Number of parts
        :param latent_dim: Latent dimension
        :param bias: Whether to use bias
        :param num_lin_layers: Number of linear layers
        :param act_layer: Whether to use activation layer
        :param out_dim: Output dimension (default: None)
        :param stack_dim: Dimension to stack the outputs (default: -1)
        """

        super().__init__()

        self.bias = bias
        self.latent_dim = latent_dim
        if out_dim is None:
            out_dim = latent_dim
        self.out_dim = out_dim
        self.part_dim = part_dim
        self.stack_dim = stack_dim

        layer_stack = torch.nn.ModuleList()
        for i in range(part_dim):
            layer_stack.append(torch.nn.Sequential())
            for j in range(num_lin_layers):
                layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(latent_dim, self.out_dim, bias=bias))
                if act_layer:
                    layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
        self.feature_layers = layer_stack
        self.reset_weights()

    def __repr__(self):
        return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}), bias={self.bias}"

    def reset_weights(self):
        """ Initialize weights with a truncated normal distribution """
        for layer in self.feature_layers:
            for m in layer.modules():
                if isinstance(m, torch.nn.Linear):
                    # Initialize weights with a truncated normal distribution
                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """ Input X has the dimensions batch x latent_dim x part_dim """

        outputs = []
        for i, layer in enumerate(self.feature_layers):
            if self.stack_dim == -1:
                in_ = x[..., i]
            else:
                in_ = x[:, i, ...]  # Select feature i
            out = layer(in_)  # Apply MLP to feature i
            outputs.append(out)

        x = torch.stack(outputs, dim=self.stack_dim)  # Stack the outputs

        return x
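A minimal usage sketch of IndependentMLPs (not part of the commit; the sizes below are illustrative and assume the default stack_dim=-1, i.e. input laid out as batch x latent_dim x part_dim):

import torch
from layers.independent_mlp import IndependentMLPs

# Hypothetical sizes: 4 images, 768-dim features, 9 part slots (8 landmarks + 1 extra slot)
mlp = IndependentMLPs(part_dim=9, latent_dim=768, bias=True, num_lin_layers=1, act_layer=False)
feats = torch.randn(4, 768, 9)   # batch x latent_dim x part_dim
out = mlp(feats)                 # each part slot goes through its own Linear(768, 768)
print(out.shape)                 # torch.Size([4, 768, 9]); out_dim defaults to latent_dim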
layers/transformer_layers.py
ADDED
@@ -0,0 +1,54 @@
# Attention Block with option to return the mean of k over heads from attention

import torch
from timm.models.vision_transformer import Attention, Block
import torch.nn.functional as F
from typing import Tuple


class AttentionWQKVReturn(Attention):
    """
    Modifications:
    - Return the qkv tensors from the attention
    """

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, torch.stack((q, k, v), dim=0)


class BlockWQKVReturn(Block):
    """
    Modifications:
    - Use AttentionWQKVReturn instead of Attention
    - Return the qkv tensors from the attention
    """

    def forward(self, x: torch.Tensor, return_qkv: bool = False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        # Note: this is copied from timm.models.vision_transformer.Block with modifications.
        x_attn, qkv = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(x_attn))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        if return_qkv:
            return x, qkv
        else:
            return x
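The two classes above are drop-in replacements for timm's Block and Attention; the ViT models later in this commit swap them in by reassigning __class__ (see convert_blocks_and_attention). A minimal sketch of that pattern, assuming a recent timm version; the backbone name is chosen here only for illustration:

import torch
from timm.models import create_model
from timm.models.vision_transformer import Attention, Block
from layers.transformer_layers import AttentionWQKVReturn, BlockWQKVReturn

vit = create_model("vit_base_patch16_224", pretrained=False)  # illustrative backbone choice
for module in vit.modules():
    if isinstance(module, Block):
        module.__class__ = BlockWQKVReturn
    elif isinstance(module, Attention):
        module.__class__ = AttentionWQKVReturn

tokens = vit.patch_embed(torch.randn(1, 3, 224, 224))
tokens = vit._pos_embed(tokens)
out, qkv = vit.blocks[0](tokens, return_qkv=True)
print(qkv.shape)  # [3, B, num_heads, N, head_dim], stacked as (q, k, v)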
load_model.py
ADDED
@@ -0,0 +1,226 @@
import copy
import os
from pathlib import Path

import torch
from timm.models import create_model
from torchvision.models import get_model

from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
from models.individual_landmark_resnet import IndividualLandmarkResNet
from models.individual_landmark_convnext import IndividualLandmarkConvNext
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_pdisco


def load_model_arch(args, num_cls):
    """
    Function to load the model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    if 'resnet' in args.model_arch:
        num_layers_split = [int(s) for s in args.model_arch if s.isdigit()]
        num_layers = int(''.join(map(str, num_layers_split)))
        if num_layers >= 100:
            timm_model_arch = args.model_arch + ".a1h_in1k"
        else:
            timm_model_arch = args.model_arch + ".a1_in1k"

    if "resnet" in args.model_arch and args.use_torchvision_resnet_model:
        weights = "DEFAULT" if args.pretrained_start_weights else None
        base_model = get_model(args.model_arch, weights=weights)
    elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model:
        if args.eval_only:
            base_model = create_model(
                timm_model_arch,
                pretrained=args.pretrained_start_weights,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
        else:
            base_model = create_model(
                timm_model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )

    elif "convnext" in args.model_arch:
        if args.eval_only:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
        else:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
    elif "vit" in args.model_arch:
        if args.eval_only:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                img_size=args.image_size,
            )
        else:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                img_size=args.image_size,
            )
        vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
        if args.image_size % vit_patch_size != 0:
            raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
    else:
        raise ValueError('Model not supported.')

    return base_model


def init_pdisco_model(base_model, args, num_cls):
    """
    Function to initialize the model
    :param base_model: Base model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    # Initialize the network
    if 'convnext' in args.model_arch:
        sl_channels = base_model.stages[-1].downsample[-1].in_channels
        fl_channels = base_model.head.in_features
        model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
                                           sl_channels=sl_channels, fl_channels=fl_channels,
                                           part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                           gumbel_softmax=args.gumbel_softmax,
                                           gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                           gumbel_softmax_hard=args.gumbel_softmax_hard,
                                           modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                           noise_variance=args.noise_variance)
    elif 'resnet' in args.model_arch:
        sl_channels = base_model.layer4[0].conv1.in_channels
        fl_channels = base_model.fc.in_features
        model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
                                         sl_channels=sl_channels, fl_channels=fl_channels,
                                         use_torchvision_model=args.use_torchvision_resnet_model,
                                         part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                         gumbel_softmax=args.gumbel_softmax,
                                         gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                         gumbel_softmax_hard=args.gumbel_softmax_hard,
                                         modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                         noise_variance=args.noise_variance)
    elif 'vit' in args.model_arch:
        model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
                                      part_dropout=args.part_dropout,
                                      modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
                                      gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                      gumbel_softmax_hard=args.gumbel_softmax_hard,
                                      modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                      noise_variance=args.noise_variance)
    else:
        raise ValueError('Model not supported.')

    return model


def load_model_pdisco(args, num_cls):
    """
    Function to load the model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    base_model = load_model_arch(args, num_cls)
    model = init_pdisco_model(base_model, args, num_cls)

    return model


def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Function to load the PDiscoFormer model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoFormer model with ViT backbone
    """
    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Function to load the PDiscoNet model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ViT backbone
    """
    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Function to load the PDiscoNet model with ResNet-101 backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ResNet-101 backbone
    """
    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model
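A usage sketch for the hub-style entry points above (not part of the commit). With pretrained=True these functions expect the caller, e.g. a hubconf, to pass the checkpoint URL prefix via model_url, so only the architecture is built here:

import torch
from load_model import pdiscoformer_vit

model = pdiscoformer_vit(pretrained=False, model_dataset="cub", k=8, img_size=518, num_cls=200)
model.eval()
with torch.no_grad():
    feats, maps, scores, dist = model(torch.randn(1, 3, 518, 518))
print(maps.shape)    # [1, k + 1, 37, 37] part assignment maps (518 / 14 = 37 patches per side)
print(scores.shape)  # [1, num_cls, k] per-part class scores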
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .individual_landmark_resnet import *
from .individual_landmark_convnext import *
from .vit_baseline import *
from .individual_landmark_vit import *
models/individual_landmark_convnext.py
ADDED
@@ -0,0 +1,110 @@
import torch
from torch import Tensor
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs


# Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkConvNext(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]:
        # Pretrained ConvNeXt part of the model
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = ConvNeXt feature maps, a = convolution kernel
        batch_size = x.shape[0]
        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())
        ).permute(0, 2, 1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist
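The landmark assignment above relies on the expansion (b - a)^2 = b^2 - 2ab + a^2, which lets a 1x1 convolution produce the cross term between each pixel feature and each landmark prototype. A small self-contained check of that identity with toy shapes (not from the commit):

import torch

feat_dim, n_proto, h, w = 16, 5, 5, 5
x = torch.randn(2, feat_dim, h, w)
fc = torch.nn.Conv2d(feat_dim, n_proto, 1, bias=False)

# Distance computed the way forward() does it: b^2 - 2ab + a^2
ab = fc(x)
b_sq = x.pow(2).sum(1, keepdim=True).expand(-1, n_proto, -1, -1)
a_sq = fc.weight.pow(2).sum(1).reshape(1, n_proto, 1, 1)
dist_fast = b_sq - 2 * ab + a_sq

# Explicit squared distances between every pixel feature and every prototype
protos = fc.weight.reshape(n_proto, feat_dim)
pixels = x.permute(0, 2, 3, 1).reshape(2, h * w, feat_dim)
dist_ref = torch.cdist(pixels, protos.unsqueeze(0).expand(2, -1, -1)).pow(2)
dist_ref = dist_ref.permute(0, 2, 1).reshape(2, n_proto, h, w)

print(torch.allclose(dist_fast, dist_ref, atol=1e-4))  # True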
models/individual_landmark_resnet.py
ADDED
@@ -0,0 +1,141 @@
# Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py
import torch
from torch import Tensor
from timm.models import create_model
from torchvision.models import get_model
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs


# Baseline model, a modified ResNet with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkResNet(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
                 use_torchvision_model: bool = False, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        if use_torchvision_model:
            self.act1 = init_model.relu
        else:
            self.act1 = init_model.act1
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4
        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")

        self.modulation_orth = modulation_orth

        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]:
        # Pretrained ResNet part of the model
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        l3 = self.layer3(x)
        x = self.layer4(l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
        batch_size = x.shape[0]

        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())
        ).permute(0, 2, 1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist


def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs):
    base_model = get_model(backbone)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    use_torchvision_model=True,  # torchvision ResNets expose .relu, not .act1
                                    modulation_type="original")


def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs):
    base_model = create_model(backbone, pretrained=True, output_stride=output_stride)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original")
models/individual_landmark_vit.py
ADDED
@@ -0,0 +1,366 @@
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from pathlib import Path
import os
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Union, Sequence, Optional, Dict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

from timm.models import create_model
from timm.models.vision_transformer import Block, Attention
from utils.misc_utils import compute_attention

from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from layers.independent_mlp import IndependentMLPs

SAFETENSORS_SINGLE_FILE = "model.safetensors"


class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
                            pipeline_tag='image-classification',
                            repo_url='https://github.com/ananthu-aniraj/pdiscoformer'):

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
                 part_dropout: float = 0.3, return_transformer_qkv: bool = False,
                 modulation_type: str = "original", gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 modulation_orth: bool = False, classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token

        self.feature_dim = init_model.embed_dim
        self.patch_embed = init_model.patch_embed
        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm
        self.return_transformer_qkv = return_transformer_qkv
        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap))
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")
        self.convert_blocks_and_attention()
        self._init_weights()

    def _init_weights_head(self):
        # Initialize weights with a truncated normal distribution
        if self.classifier_type == "independent_mlp":
            self.fc_class_landmarks.reset_weights()
        else:
            torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02)
            if self.fc_class_landmarks.bias is not None:
                torch.nn.init.zeros_(self.fc_class_landmarks.bias)

    def _init_weights(self):
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, int | Any] | tuple[Any, Any, Any, Any, int | Any]:

        x = self.patch_embed(x)

        # Position Embedding
        x = self._pos_embed(x)

        # Forward pass through transformer
        x = self.norm_pre(x)

        x = self.blocks(x)
        x = self.norm(x)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps vit, a = convolution kernel
        batch_size = x.shape[0]
        x = x[:, self.num_prefix_tokens:, :]  # [B, num_patch_tokens, embed_dim]
        x = self.unflatten(x)  # [B, H, W, embed_dim]
        x = x.permute(0, 3, 1, 2).contiguous()  # [B, embed_dim, H, W]
        ab = self.fc_landmarks(x)  # [B, num_landmarks + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2],
                                                                           x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        all_features = all_features.mean(-1).mean(-1).contiguous()  # [B, embed_dim, num_landmarks + 1]

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation  # [B, embed_dim, num_landmarks + 1]
        else:
            all_features_mod = self.modulation(all_features)  # [B, embed_dim, num_landmarks + 1]

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())
        ).permute(0, 2, 1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)

        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)

        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out

    @classmethod
    def _from_pretrained(
            cls,
            *,
            model_id: str,
            revision: Optional[str],
            cache_dir: Optional[Union[str, Path]],
            force_download: bool,
            proxies: Optional[Dict],
            resume_download: Optional[bool],
            local_files_only: bool,
            token: Union[str, bool, None],
            map_location: str = "cpu",
            strict: bool = False,
            timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
            input_size: int = 518,
            **model_kwargs):
        base_model = create_model(timm_backbone, pretrained=False, img_size=input_size)
        model = cls(base_model, **model_kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=SAFETENSORS_SINGLE_FILE,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
            return cls._load_as_safetensor(model, model_file, map_location, strict)


def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
    base_model = create_model(
        backbone,
        pretrained=False,
        img_size=img_size,
    )

    model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
                                  modulation_type="layer_norm", gumbel_softmax=True,
                                  modulation_orth=True)
    return model


def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
    base_model = create_model(
        backbone,
        pretrained=False,
        img_size=img_size,
    )

    model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
                                  modulation_type="original")
    return model
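Downstream, the maps output behaves as a soft assignment of each patch to one of k parts plus one extra slot that the classifier drops (consistent with a background slot). A short sketch, not from the commit, of turning the outputs into a hard part segmentation and a single class prediction; the pooling over parts is an assumed readout:

import torch
from models.individual_landmark_vit import pdiscoformer_vit_bb

model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", img_size=518, num_cls=200, k=8).eval()
with torch.no_grad():
    _, maps, scores, _ = model(torch.randn(1, 3, 518, 518))

part_seg = maps.argmax(dim=1)       # [1, 37, 37] hard part index per patch location
class_logits = scores.mean(dim=-1)  # [1, 200] average the per-part class scores (assumed readout)
print(part_seg.shape, class_logits.shape)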
models/vit_baseline.py
ADDED
@@ -0,0 +1,239 @@
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import torch
import torch.nn as nn
from typing import Tuple, Union, Sequence, Any
from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, Attention
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn

from utils.misc_utils import compute_attention


class BaselineViT(torch.nn.Module):
    """
    Modifications:
    - Use PDiscoBlock instead of Block
    - Use PDiscoAttention instead of Attention
    - Return the mean of k over heads from attention
    - Option to use only class tokens or only patch tokens or both (concat) for classification
    """

    def __init__(self, init_model: torch.nn.Module, num_classes: int,
                 class_tokens_only: bool = False,
                 patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.class_tokens_only = class_tokens_only
        self.patch_tokens_only = patch_tokens_only
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token

        self.patch_embed = init_model.patch_embed

        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.part_embed = nn.Identity()
        self.patch_prune = nn.Identity()
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm

        self.fc_norm = init_model.fc_norm
        if class_tokens_only or patch_tokens_only:
            self.head = nn.Linear(init_model.embed_dim, num_classes)
        else:
            self.head = nn.Linear(init_model.embed_dim * 2, num_classes)

        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.return_transformer_qkv = return_transformer_qkv
        self.convert_blocks_and_attention()
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def _init_weights_head(self):
        trunc_normal_(self.head.weight, std=.02)
        if self.head.bias is not None:
            nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:

        x = self.patch_embed(x)

        # Position Embedding
        x = self._pos_embed(x)

        x = self.part_embed(x)
        x = self.patch_prune(x)

        # Forward pass through transformer
        x = self.norm_pre(x)

        if self.return_transformer_qkv:
            # Return keys of last attention layer
            for i, blk in enumerate(self.blocks):
                x, qkv = blk(x, return_qkv=True)
        else:
            x = self.blocks(x)

        x = self.norm(x)

        # Classification head
        x = self.fc_norm(x)
        if self.class_tokens_only:  # only use class token
            x = x[:, 0, :]
        elif self.patch_tokens_only:  # only use patch tokens
            x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
        else:
            x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
        x = self.head(x)
        if self.return_transformer_qkv:
            return x, qkv
        else:
            return x

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)

        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)

        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out
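Note: a sketch of how BaselineViT wraps a timm backbone, assuming a recent timm version that exposes reg_token/num_prefix_tokens and that the repo's BlockWQKVReturn behaves like a plain Block when return_qkv is not requested; the backbone name is an assumption.

# Hypothetical usage sketch for the baseline wrapper.
import torch
from timm import create_model

backbone = create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
model = BaselineViT(backbone, num_classes=200)     # concat of class token + mean patch token
logits = model(torch.rand(2, 3, 224, 224))
print(logits.shape)                                # torch.Size([2, 200])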
requirements.txt
CHANGED
@@ -3,4 +3,8 @@ timm
 colorcet
 matplotlib
 torchvision
-streamlit
+streamlit
+numpy
+pillow
+scikit-image
+huggingface-hub
utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .data_utils import *
from .visualize_att_maps import *
from .misc_utils import *
from .get_landmark_coordinates import *
utils/data_utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .dataset_utils import *
from .reversible_affine_transform import *
from .transform_utils import *
from .class_balanced_distributed_sampler import *
from .class_balanced_sampler import *
utils/data_utils/class_balanced_distributed_sampler.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from typing import Optional
|
| 4 |
+
import math
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
|
| 9 |
+
"""
|
| 10 |
+
A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
|
| 11 |
+
Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
|
| 15 |
+
shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:
|
| 16 |
+
|
| 17 |
+
if not shuffle:
|
| 18 |
+
raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")
|
| 19 |
+
|
| 20 |
+
# Check if the dataset has a generate_class_balanced_indices method
|
| 21 |
+
if not hasattr(dataset, 'generate_class_balanced_indices'):
|
| 22 |
+
raise ValueError("Dataset does not have a generate_class_balanced_indices method")
|
| 23 |
+
|
| 24 |
+
self.shuffle = shuffle
|
| 25 |
+
self.seed = seed
|
| 26 |
+
if num_replicas is None:
|
| 27 |
+
if not dist.is_available():
|
| 28 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 29 |
+
num_replicas = dist.get_world_size()
|
| 30 |
+
if rank is None:
|
| 31 |
+
if not dist.is_available():
|
| 32 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 33 |
+
rank = dist.get_rank()
|
| 34 |
+
if rank >= num_replicas or rank < 0:
|
| 35 |
+
raise ValueError(
|
| 36 |
+
f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
|
| 37 |
+
self.dataset = dataset
|
| 38 |
+
self.num_replicas = num_replicas
|
| 39 |
+
self.rank = rank
|
| 40 |
+
self.epoch = 0
|
| 41 |
+
self.drop_last = drop_last
|
| 42 |
+
|
| 43 |
+
# Calculate the number of samples
|
| 44 |
+
g = torch.Generator()
|
| 45 |
+
g.manual_seed(self.seed + self.epoch)
|
| 46 |
+
self.num_samples_per_class = num_samples_per_class
|
| 47 |
+
indices = dataset.generate_class_balanced_indices(torch.Generator(),
|
| 48 |
+
num_samples_per_class=num_samples_per_class)
|
| 49 |
+
dataset_size = len(indices)
|
| 50 |
+
|
| 51 |
+
# If the dataset length is evenly divisible by # of replicas, then there
|
| 52 |
+
# is no need to drop any data, since the dataset will be split equally.
|
| 53 |
+
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
| 54 |
+
# Split to nearest available length that is evenly divisible.
|
| 55 |
+
# This is to ensure each rank receives the same amount of data when
|
| 56 |
+
# using this Sampler.
|
| 57 |
+
self.num_samples = math.ceil(
|
| 58 |
+
(dataset_size - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
| 59 |
+
)
|
| 60 |
+
else:
|
| 61 |
+
self.num_samples = math.ceil(dataset_size / self.num_replicas) # type: ignore[arg-type]
|
| 62 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 63 |
+
|
| 64 |
+
def __iter__(self):
|
| 65 |
+
# deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
|
| 66 |
+
g = torch.Generator()
|
| 67 |
+
g.manual_seed(self.seed + self.epoch)
|
| 68 |
+
indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)
|
| 69 |
+
|
| 70 |
+
if not self.drop_last:
|
| 71 |
+
# add extra samples to make it evenly divisible
|
| 72 |
+
padding_size = self.total_size - len(indices)
|
| 73 |
+
if padding_size <= len(indices):
|
| 74 |
+
indices += indices[:padding_size]
|
| 75 |
+
else:
|
| 76 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
| 77 |
+
else:
|
| 78 |
+
# remove tail of data to make it evenly divisible.
|
| 79 |
+
indices = indices[:self.total_size]
|
| 80 |
+
|
| 81 |
+
# subsample
|
| 82 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
| 83 |
+
|
| 84 |
+
return iter(indices)
|
| 85 |
+
|
| 86 |
+
def __len__(self) -> int:
|
| 87 |
+
return self.num_samples
|
| 88 |
+
|
| 89 |
+
def set_epoch(self, epoch: int) -> None:
|
| 90 |
+
r"""
|
| 91 |
+
Set the epoch for this sampler.
|
| 92 |
+
|
| 93 |
+
When :attr:`shuffle=True`, this ensures all replicas
|
| 94 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
| 95 |
+
sampler will yield the same ordering.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
epoch (int): Epoch number.
|
| 99 |
+
"""
|
| 100 |
+
self.epoch = epoch
|
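Note: a sketch of the intended DDP wiring, assuming torch.distributed is already initialized and that the training dataset implements generate_class_balanced_indices; the dataset and loop variables below are placeholders.

# Hypothetical DDP usage sketch.
from torch.utils.data import DataLoader

sampler = ClassBalancedDistributedSampler(train_dataset, shuffle=True, num_samples_per_class=50)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # reshuffle the balanced subset consistently across replicas
    for batch in loader:
        ...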
utils/data_utils/class_balanced_sampler.py
ADDED
@@ -0,0 +1,31 @@
import torch
from torch.utils.data import Dataset


class ClassBalancedRandomSampler(torch.utils.data.Sampler):
    """
    A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
    This is essentially the non-ddp version of ClassBalancedDistributedSampler
    Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
    """

    def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
        self.dataset = dataset
        self.seed = seed
        # Calculate the number of samples
        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)
        self.num_samples_per_class = num_samples_per_class
        indices = dataset.generate_class_balanced_indices(self.generator,
                                                          num_samples_per_class=num_samples_per_class)
        self.num_samples = len(indices)

    def __iter__(self):
        # Change seed for every function call
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self.generator.manual_seed(seed)
        indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples
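Note: the single-process counterpart of the sketch above, under the same dataset assumption.

# Hypothetical single-process usage sketch.
from torch.utils.data import DataLoader

sampler = ClassBalancedRandomSampler(train_dataset, num_samples_per_class=50, seed=0)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)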
utils/data_utils/dataset_utils.py
ADDED
@@ -0,0 +1,161 @@
from PIL import Image
from torch import Tensor
from typing import List, Optional
import numpy as np
import torchvision
import json


def load_json(path: str):
    """
    Load json file from path and return the data
    :param path: Path to the json file
    :return:
    data: Data in the json file
    """
    with open(path, 'r') as f:
        data = json.load(f)
    return data


def save_json(data: dict, path: str):
    """
    Save data to a json file
    :param data: Data to be saved
    :param path: Path to save the data
    :return:
    """
    with open(path, "w") as f:
        json.dump(data, f)


def pil_loader(path):
    """
    Load image from path using PIL
    :param path: Path to the image
    :return:
    img: PIL Image
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def get_dimensions(image: Tensor):
    """
    Get the dimensions of the image
    :param image: Tensor or PIL Image or np.ndarray
    :return:
    h: Height of the image
    w: Width of the image
    """
    if isinstance(image, Tensor):
        _, h, w = image.shape
    elif isinstance(image, np.ndarray):
        h, w, _ = image.shape
    elif isinstance(image, Image.Image):
        w, h = image.size
    else:
        raise ValueError(f"Invalid image type: {type(image)}")
    return h, w


def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
    """
    Calculate the center crop parameters for the bounding boxes and landmarks and update them
    :param img: Image
    :param output_size: Output size of the cropped image
    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
    :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
    :param num_keypoints: Number of keypoints
    :return:
    cropped_img: Center cropped image
    parts: Updated locations of the landmarks
    boxes: Updated bounding boxes of the landmarks
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        output_size = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size
    image_height, image_width = get_dimensions(img)
    img = torchvision.transforms.functional.center_crop(img, output_size)

    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)

    if parts is not None:
        for j in range(num_keypoints):
            # Skip if part is invisible
            if parts[j][-1] == 0:
                continue
            parts[j][1] -= crop_left
            parts[j][2] -= crop_top

            # Skip if part is outside the crop
            if parts[j][1] > crop_width or parts[j][2] > crop_height:
                parts[j][-1] = 0
            if parts[j][1] < 0 or parts[j][2] < 0:
                parts[j][-1] = 0

            parts[j][1] = min(crop_width, parts[j][1])
            parts[j][2] = min(crop_height, parts[j][2])
            parts[j][1] = max(0, parts[j][1])
            parts[j][2] = max(0, parts[j][2])

    if boxes is not None:
        boxes[1] -= crop_left
        boxes[2] -= crop_top
        boxes[1] = max(0, boxes[1])
        boxes[2] = max(0, boxes[2])
        boxes[1] = min(crop_width, boxes[1])
        boxes[2] = min(crop_height, boxes[2])

    return img, parts, boxes


def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
    """
    Get the parameters for center cropping the image
    :param image_height: Height of the image
    :param image_width: Width of the image
    :param output_size: Output size of the cropped image
    :return:
    crop_top: Top coordinate of the cropped image
    crop_left: Left coordinate of the cropped image
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        output_size = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
        return crop_top, crop_left

    if crop_width == image_width and crop_height == image_height:
        crop_top = 0
        crop_left = 0
        return crop_top, crop_left

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))

    return crop_top, crop_left
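Note: an illustrative sketch of center_crop_boxes_kps with random tensors standing in for a real image and its CUB-style keypoint annotations (<part_id> <x> <y> <visible>); all values are made up.

# Illustrative sketch only (random data, not a real annotation file).
import torch

img = torch.rand(3, 500, 600)                                         # C x H x W image
parts = torch.tensor([[j, 300.0, 250.0, 1.0] for j in range(15)])     # 15 visible keypoints
cropped, parts, _ = center_crop_boxes_kps(img, output_size=448, parts=parts, num_keypoints=15)
print(cropped.shape)                                                  # torch.Size([3, 448, 448])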
utils/data_utils/reversible_affine_transform.py
ADDED
@@ -0,0 +1,82 @@
# Description: This file contains the code for the reversible affine transform
import torchvision.transforms as transforms
import torch
from typing import List, Optional, Tuple, Any


def generate_affine_trans_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
) -> Tuple[float, Tuple[int, int], float, Any]:
    """Get parameters for affine transformation

    Returns:
        params to be passed to the affine transformation
    """
    angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
    if translate is not None:
        max_dx = float(translate[0] * img_size[0])
        max_dy = float(translate[1] * img_size[1])
        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
        translations = (tx, ty)
    else:
        translations = (0, 0)

    if scale_ranges is not None:
        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
    else:
        scale = 1.0

    shear_x = shear_y = 0.0
    if shears is not None:
        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
        if len(shears) == 4:
            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

    shear = (shear_x, shear_y)
    if shear_x == 0.0 and shear_y == 0.0:
        shear = 0.0

    return angle, translations, scale, shear


def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
                    interpolation=transforms.InterpolationMode.BILINEAR):
    """
    Affine transforms input image
    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
    Parameters
    ----------
    img: Tensor
        Input image
    angle: int
        Rotation angle between -180 and 180 degrees
    translate: [int]
        Sequence of horizontal/vertical translations
    scale: float
        How to scale the image
    invert: bool
        Whether to invert the transformation
    shear: float
        Shear angle in degrees
    interpolation: InterpolationMode
        Interpolation mode to calculate output values
    Returns
    ----------
    img: Tensor
        Transformed image

    """
    if not invert:
        img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
                                           interpolation=interpolation)
    else:
        translate = [-t for t in translate]
        img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
        img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)

    return img
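Note: a small round-trip sketch of the reversible transform, sampling parameters once, warping, and then undoing the warp; tensor shapes and parameter ranges are illustrative.

# Illustrative round-trip sketch.
import torch

img = torch.rand(1, 3, 224, 224)
angle, translations, scale, shear = generate_affine_trans_params(
    degrees=[-30, 30], translate=[0.1, 0.1], scale_ranges=[0.9, 1.1], shears=None, img_size=[224, 224])

warped = rigid_transform(img, angle, list(translations), scale)
recovered = rigid_transform(warped, angle, list(translations), scale, invert=True)
# `recovered` approximates `img` up to interpolation and border effects.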
utils/data_utils/transform_utils.py
ADDED
@@ -0,0 +1,118 @@
import torch
from torchvision import transforms as transforms
from torchvision.transforms import Compose

from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform


def make_train_transforms(args):
    train_transforms: Compose = transforms.Compose([
        transforms.Resize(size=args.image_size, antialias=True),
        transforms.RandomHorizontalFlip(p=args.hflip),
        transforms.RandomVerticalFlip(p=args.vflip),
        transforms.ColorJitter(),
        transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
        transforms.RandomCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

    ])
    return train_transforms


def make_test_transforms(args):
    test_transforms: Compose = transforms.Compose([
        transforms.Resize(size=args.image_size, antialias=True),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

    ])
    return test_transforms


def build_transform_timm(args, is_train=True):
    resize_im = args.image_size > 32
    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.image_size,
            is_training=True,
            color_jitter=args.color_jitter,
            hflip=args.hflip,
            vflip=args.vflip,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            transform.transforms[0] = transforms.RandomCrop(
                args.image_size, padding=4)
        return transform

    t = []
    if resize_im:
        # warping (no cropping) when evaluated at 384 or larger
        if args.image_size >= 384:
            t.append(
                transforms.Resize((args.image_size, args.image_size),
                                  interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            )
            print(f"Warping {args.image_size} size input images...")
        else:
            if args.crop_pct is None:
                args.crop_pct = 224 / 256
            size = int(args.image_size / args.crop_pct)
            t.append(
                # to maintain same ratio w.r.t. 224 images
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            )
            t.append(transforms.CenterCrop(args.image_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)
    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    return un_normalize


def normalize_only(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize


def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
                               resize_resolution=(256, 256)):
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)
    resize_unnorm = transforms.Compose([
        transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
        transforms.Resize(size=resize_resolution, antialias=True)])
    return resize_unnorm


def load_transforms(args):
    # Get the transforms and load the dataset
    if args.augmentations_to_use == 'timm':
        train_transforms = build_transform_timm(args, is_train=True)
    elif args.augmentations_to_use == 'cub_original':
        train_transforms = make_train_transforms(args)
    else:
        raise ValueError('Augmentations not supported.')
    test_transforms = make_test_transforms(args)
    return train_transforms, test_transforms
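Note: a sketch of driving load_transforms without the project's argument parser; the Namespace fields below simply mirror the attributes the 'cub_original' path reads, and their values are assumptions.

# Hypothetical usage sketch with a hand-built args object.
from types import SimpleNamespace

args = SimpleNamespace(image_size=224, hflip=0.5, vflip=0.0,
                       augmentations_to_use='cub_original')
train_tf, test_tf = load_transforms(args)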
utils/get_landmark_coordinates.py
ADDED
@@ -0,0 +1,41 @@
# This file contains the function to generate the center coordinates as tensor for the current net.
import torch


def landmark_coordinates(maps, grid_x=None, grid_y=None):
    """
    Generate the center coordinates as tensor for the current net.
    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19
    Parameters
    ----------
    maps: torch.Tensor
        Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability
    grid_x: torch.Tensor
        The grid x coordinates
    grid_y: torch.Tensor
        The grid y coordinates
    Returns
    ----------
    loc_x: Tensor
        The centroid x coordinates
    loc_y: Tensor
        The centroid y coordinates
    grid_x: Tensor
    grid_y: Tensor
    """
    return_grid = False
    if grid_x is None or grid_y is None:
        return_grid = True
        grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
                                        torch.arange(maps.shape[3]), indexing='ij')
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
    map_sums = maps.sum(3).sum(2).detach()
    maps_x = grid_x * maps
    maps_y = grid_y * maps
    loc_x = maps_x.sum(3).sum(2) / map_sums
    loc_y = maps_y.sum(3).sum(2) / map_sums
    if return_grid:
        return loc_x, loc_y, grid_x, grid_y
    else:
        return loc_x, loc_y
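Note: a toy example of the centroid computation on random, softmax-normalized part maps; the shapes are illustrative only.

# Toy shape check for landmark_coordinates.
import torch

maps = torch.rand(2, 9, 14, 14).softmax(dim=1)     # batch of 2, e.g. 8 parts + background
loc_x, loc_y, grid_x, grid_y = landmark_coordinates(maps)
print(loc_x.shape, loc_y.shape)                     # torch.Size([2, 9]) torch.Size([2, 9])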
utils/misc_utils.py
ADDED
@@ -0,0 +1,135 @@
import math
from functools import reduce

import torch
import numpy as np
import os
from pathlib import Path


def factors(n):
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))


def file_line_count(filename: str) -> int:
    """Count the number of lines in a file"""
    with open(filename, 'rb') as f:
        return sum(1 for _ in f)


def compute_attention(qkv, scale=None):
    """
    Compute attention matrix (same as in the pytorch scaled dot product attention)
    Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors concatenated along the first dimension
    :param scale: Scale factor for the attention computation
    :return:
    """
    if isinstance(qkv, torch.Tensor):
        query, key, value = qkv.unbind(0)
    else:
        query, key, value = qkv
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out


def compute_dot_product_similarity(a, b):
    scores = a @ b.transpose(-1, -2)
    return scores


def compute_cross_entropy(p, q):
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return - loss.mean()


def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout,
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation
    Returns
    -------
    mask : np.ndarray
        Mask of shape (width, width), where width is the square root of the number of patches
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1).values
            else:
                raise ValueError("Attention head fusion type Not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1)

            result = torch.matmul(a, result)

    # Normalize the result by max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result


def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm to SyncBatchNorm (used for DDP)
    :param model: PyTorch model
    :return:
    model: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model


def check_snapshot(args):
    """
    Create directory to save training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create directory to save training checkpoints, otherwise load the existing checkpoint
        if not os.path.exists(args.snapshot_dir):
            if ".pt" not in args.snapshot_dir or ".pth" not in args.snapshot_dir:
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
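Note: a shape sketch for compute_attention, with qkv stacked as (3, batch, heads, tokens, head_dim); the sizes are illustrative.

# Shape sketch for compute_attention.
import torch

qkv = torch.rand(3, 2, 12, 197, 64)
attn_weight, attn_out = compute_attention(qkv)
print(attn_weight.shape)   # torch.Size([2, 12, 197, 197])
print(attn_out.shape)      # torch.Size([2, 12, 197, 64])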
utils/visualize_att_maps.py
ADDED
@@ -0,0 +1,135 @@
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import colorcet as cc
import numpy as np
import skimage
from pathlib import Path
import os
import torch

from utils.data_utils.transform_utils import inverse_normalize_w_resize
from utils.misc_utils import factors

# Define the colors to use for the attention maps
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
                 plot_landmark_amaps=False):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Size of the images to save
        :param alpha: The transparency of the attention maps
        :param sub_path_test: The sub-path of the test dataset
        :param dataset_name: The name of the dataset
        :param bg_label: The background label index in the attention maps
        :param batch_size: The batch size
        :param num_parts: The number of parts in the attention maps
        :param plot_ims_separately: Whether to plot the images separately
        :param plot_landmark_amaps: Whether to plot the landmark attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.sub_path_test = sub_path_test
        self.dataset_name = dataset_name
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir

        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.batch_size = batch_size
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        self.num_parts = num_parts
        self.req_colors = colors[:num_parts]
        self.plot_ims_separately = plot_ims_separately
        self.plot_landmark_amaps = plot_landmark_amaps
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def recalculate_nrows_ncols(self):
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    @torch.no_grad()
    def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
        """
        Plot images, attention maps and landmark centroids.
        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        epoch: int
            The epoch number
        curr_iter: int
            The current iteration number
        extra_info: str
            Any extra information to add to the file name
        """
        ims = self.resize_unnorm(ims)
        if ims.shape[0] != self.batch_size:
            self.batch_size = ims.shape[0]
            self.recalculate_nrows_ncols()
        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()
        for i, ax in enumerate(axs.ravel()):
            curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
                                               bg_label=self.bg_label, alpha=self.alpha)
            ax.imshow(curr_map)
            ax.axis('off')
        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
        fig.tight_layout()
        if self.snapshot_dir != "":
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        else:
            plt.show()
        plt.close('all')

        if self.plot_ims_separately:
            fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
            for i, ax in enumerate(axs.ravel()):
                ax.imshow(ims[i])
                ax.axis('off')
            save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
            fig.tight_layout()
            if self.snapshot_dir != "":
                plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            else:
                plt.show()
            plt.close('all')

        if self.plot_landmark_amaps:
            if self.batch_size > 1:
                raise ValueError('Not implemented for batch size > 1')
            for i in range(self.num_parts):
                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes('right', size='5%', pad=0.05)
                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
                fig.colorbar(im, cax=cax, orientation='vertical')
                ax.axis('off')
                save_path = os.path.join(save_dir,
                                         f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
                fig.tight_layout()
                if self.snapshot_dir != "":
                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
                else:
                    plt.show()
                plt.close()

        plt.close('all')
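Note: a hypothetical sketch of visualizing part assignment maps with this class; the random tensors stand in for normalized inputs and model-produced maps, and the choice of 8 parts plus a background channel is an assumption.

# Hypothetical usage sketch (random data in place of real model outputs).
import torch

viz = VisualizeAttentionMaps(snapshot_dir="", batch_size=4, num_parts=8, bg_label=8)
ims = torch.rand(4, 3, 224, 224)                 # stands in for a normalized input batch
maps = torch.rand(4, 9, 14, 14).softmax(dim=1)   # 8 parts + background channel
viz.show_maps(ims, maps)                          # shows the overlays since snapshot_dir is empty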