# pdiscoformer/models/individual_landmark_vit.py
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from pathlib import Path
import os
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Union, Sequence, Optional, Dict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from timm.models import create_model
from timm.models.vision_transformer import Block, Attention
from utils.misc_utils import compute_attention
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from layers.independent_mlp import IndependentMLPs
SAFETENSORS_SINGLE_FILE = "model.safetensors"
class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
pipeline_tag='image-classification',
repo_url='https://github.com/ananthu-aniraj/pdiscoformer'):
def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
part_dropout: float = 0.3, return_transformer_qkv: bool = False,
modulation_type: str = "original", gumbel_softmax: bool = False,
gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
classifier_type: str = "linear") -> None:
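        """Wrap a timm VisionTransformer for part discovery and part-based classification.

        Args:
            init_model: timm VisionTransformer whose patch embedding, positional embedding, blocks
                and norms are reused as the backbone.
            num_landmarks: Number of foreground parts; num_landmarks + 1 maps are predicted, with the
                last one excluded from classification.
            num_classes: Number of output classes.
            part_dropout: Probability of dropping whole part feature vectors (``torch.nn.Dropout1d``).
            return_transformer_qkv: If True, the intermediate-layer helpers also return QKV tensors.
            modulation_type: One of "original", "layer_norm", "parallel_mlp", "parallel_mlp_no_bias",
                "parallel_mlp_no_act", "parallel_mlp_no_act_no_bias" or "none".
            gumbel_softmax: If True, use Gumbel-Softmax instead of softmax over the part maps.
            gumbel_softmax_temperature: Temperature of the Gumbel-Softmax.
            gumbel_softmax_hard: If True, use hard (straight-through) Gumbel-Softmax samples.
            classifier_type: "linear" for a shared linear head or "independent_mlp" for per-part MLPs.
        """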
super().__init__()
self.num_landmarks = num_landmarks
self.num_classes = num_classes
self.num_prefix_tokens = init_model.num_prefix_tokens
self.num_reg_tokens = init_model.num_reg_tokens
self.has_class_token = init_model.has_class_token
self.no_embed_class = init_model.no_embed_class
self.cls_token = init_model.cls_token
self.reg_token = init_model.reg_token
self.feature_dim = init_model.embed_dim
self.patch_embed = init_model.patch_embed
self.pos_embed = init_model.pos_embed
self.pos_drop = init_model.pos_drop
self.norm_pre = init_model.norm_pre
self.blocks = init_model.blocks
self.norm = init_model.norm
self.return_transformer_qkv = return_transformer_qkv
self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap))
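        # 1x1 conv whose (num_landmarks + 1) filters act as learned part prototypes; forward() converts
        # the conv response into squared-distance-based part assignment maps (last channel dropped at classification).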
self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
self.gumbel_softmax = gumbel_softmax
self.gumbel_softmax_temperature = gumbel_softmax_temperature
self.gumbel_softmax_hard = gumbel_softmax_hard
self.modulation_type = modulation_type
if modulation_type == "layer_norm":
self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
elif modulation_type == "original":
self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
elif modulation_type == "parallel_mlp":
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
num_lin_layers=1, act_layer=True, bias=True)
elif modulation_type == "parallel_mlp_no_bias":
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
num_lin_layers=1, act_layer=True, bias=False)
elif modulation_type == "parallel_mlp_no_act":
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
num_lin_layers=1, act_layer=False, bias=True)
elif modulation_type == "parallel_mlp_no_act_no_bias":
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
num_lin_layers=1, act_layer=False, bias=False)
elif modulation_type == "none":
self.modulation = torch.nn.Identity()
else:
raise ValueError("modulation_type not implemented")
self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
self.classifier_type = classifier_type
if classifier_type == "independent_mlp":
self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
num_lin_layers=1, act_layer=False, out_dim=num_classes,
bias=False, stack_dim=1)
elif classifier_type == "linear":
self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
bias=False)
else:
raise ValueError("classifier_type not implemented")
self.convert_blocks_and_attention()
self._init_weights()
def _init_weights_head(self):
# Initialize weights with a truncated normal distribution
if self.classifier_type == "independent_mlp":
self.fc_class_landmarks.reset_weights()
else:
torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02)
if self.fc_class_landmarks.bias is not None:
torch.nn.init.zeros_(self.fc_class_landmarks.bias)
def _init_weights(self):
self._init_weights_head()
def convert_blocks_and_attention(self):
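        """Swap timm ``Block``/``Attention`` modules for subclasses that can also return their QKV tensors.

        Reassigning ``__class__`` keeps all pretrained weights in place; only the forward logic gains
        an optional ``return_qkv`` path.
        """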
for module in self.modules():
if isinstance(module, Block):
module.__class__ = BlockWQKVReturn
elif isinstance(module, Attention):
module.__class__ = AttentionWQKVReturn
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
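        """Add positional embeddings and prepend the class/register tokens, following the timm
        convention selected by ``no_embed_class`` (add-then-concat vs. concat-then-add)."""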
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
x = x + pos_embed
return self.pos_drop(x)
    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
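        """Compute part assignment maps and per-class scores.

        Returns:
            maps: [B, num_landmarks + 1, H, W] soft part assignments (softmax or Gumbel-Softmax over parts).
            scores: [B, num_classes] logits, averaged over the foreground part classifiers.
        """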
x = self.patch_embed(x)
# Position Embedding
x = self._pos_embed(x)
# Forward pass through transformer
x = self.norm_pre(x)
x = self.blocks(x)
x = self.norm(x)
# Compute per landmark attention maps
# (b - a)^2 = b^2 - 2ab + a^2, b = feature maps vit, a = convolution kernel
batch_size = x.shape[0]
x = x[:, self.num_prefix_tokens:, :] # [B, num_patch_tokens, embed_dim]
x = self.unflatten(x) # [B, H, W, embed_dim]
x = x.permute(0, 3, 1, 2).contiguous() # [B, embed_dim, H, W]
ab = self.fc_landmarks(x) # [B, num_landmarks + 1, H, W]
b_sq = x.pow(2).sum(1, keepdim=True)
b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2],
x.shape[-1]).contiguous()
a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
dist = b_sq - 2 * ab + a_sq
maps = -dist
# Softmax so that the attention maps for each pixel add up to 1
if self.gumbel_softmax:
maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W]
else:
maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W]
# Use maps to get weighted average features per landmark
all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
all_features = all_features.mean(-1).mean(-1).contiguous() # [B, embed_dim, num_landmarks + 1]
# Modulate the features
if self.modulation_type == "original":
all_features_mod = all_features * self.modulation # [B, embed_dim, num_landmarks + 1]
else:
all_features_mod = self.modulation(all_features) # [B, embed_dim, num_landmarks + 1]
# Classification based on the landmark features
scores = self.fc_class_landmarks(
self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
1).contiguous()
scores = scores.mean(dim=-1) # [B, num_classes]
return maps, scores
def get_specific_intermediate_layer(
self,
x: torch.Tensor,
n: int = 1,
return_qkv: bool = False,
return_att_weights: bool = False,
):
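        """Return the output of transformer block ``n`` (or the patch+pos-embedded input for ``n == -1``).

        If ``self.return_transformer_qkv`` is set, ``return_qkv`` additionally returns that block's QKV
        tensor and ``return_att_weights`` the attention weights of every block up to and including ``n``.
        """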
num_blocks = len(self.blocks)
attn_weights = []
if n >= num_blocks:
raise ValueError(f"n must be less than {num_blocks}")
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
if n == -1:
if return_qkv:
raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
else:
return x
for i, blk in enumerate(self.blocks):
if self.return_transformer_qkv:
x, qkv = blk(x, return_qkv=True)
if return_att_weights:
attn_weight, _ = compute_attention(qkv)
attn_weights.append(attn_weight.detach())
else:
x = blk(x)
if i == n:
output = x.clone()
if self.return_transformer_qkv and return_qkv:
qkv_output = qkv.clone()
break
if self.return_transformer_qkv and return_qkv and return_att_weights:
return output, qkv_output, attn_weights
elif self.return_transformer_qkv and return_qkv:
return output, qkv_output
elif self.return_transformer_qkv and return_att_weights:
return output, attn_weights
else:
return output
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
):
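        """Collect the outputs (and, if ``self.return_transformer_qkv`` is set, the QKV tensors) of the
        last ``n`` blocks when ``n`` is an int, or of the blocks whose indices are listed in ``n``."""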
outputs, num_blocks = [], len(self.blocks)
if self.return_transformer_qkv:
qkv_outputs = []
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
if self.return_transformer_qkv:
x, qkv = blk(x, return_qkv=True)
else:
x = blk(x)
if i in take_indices:
outputs.append(x)
if self.return_transformer_qkv:
qkv_outputs.append(qkv)
if self.return_transformer_qkv:
return outputs, qkv_outputs
else:
return outputs
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
    ) -> Union[tuple, tuple[tuple, Any]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
"""
        # take the last n blocks if n is an int; if n is a sequence, select blocks by matching indices
if self.return_transformer_qkv:
outputs, qkv = self._intermediate_layers(x, n)
else:
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_prefix_tokens:
return_out = tuple(zip(outputs, prefix_tokens))
else:
return_out = tuple(outputs)
if self.return_transformer_qkv:
return return_out, qkv
else:
return return_out
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Union[str, bool, None],
map_location: str = "cpu",
strict: bool = False,
timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
input_size: int = 518,
**model_kwargs):
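        """Build the timm ViT backbone, wrap it in ``cls`` and load weights from ``model_id``.

        ``model_id`` may be a local directory containing ``model.safetensors`` or a Hugging Face Hub
        repo id; ``timm_backbone`` and ``input_size`` control which backbone is instantiated.
        """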
base_model = create_model(timm_backbone, pretrained=False, img_size=input_size)
model = cls(base_model, **model_kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
return cls._load_as_safetensor(model, model_file, map_location, strict)
else:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
return cls._load_as_safetensor(model, model_file, map_location, strict)
def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
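    """Build the PDiscoFormer variant: LayerNorm modulation with Gumbel-Softmax part maps."""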
base_model = create_model(
backbone,
pretrained=False,
img_size=img_size,
)
    model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
                                  modulation_type="layer_norm", gumbel_softmax=True)
return model
def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
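    """Build the PDiscoNet-style variant: original (learned per-part scaling) modulation with plain softmax."""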
base_model = create_model(
backbone,
pretrained=False,
img_size=img_size,
)
model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
modulation_type="original")
return model
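

# Minimal usage sketch, assuming a recent timm that registers the `vit_base_patch14_reg4_dinov2`
# architecture and that the repo's `layers`/`utils` packages are on the import path. The backbone is
# built without pretrained weights, so the printed shapes are for sanity-checking only.
if __name__ == "__main__":
    _backbone = create_model("vit_base_patch14_reg4_dinov2", pretrained=False, img_size=518)
    _net = IndividualLandmarkViT(_backbone, num_landmarks=8, num_classes=200)
    _maps, _scores = _net(torch.randn(2, 3, 518, 518))
    print(_maps.shape)    # torch.Size([2, 9, 37, 37]) -> (num_landmarks + 1) part assignment maps
    print(_scores.shape)  # torch.Size([2, 200])       -> class logits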