Spaces:
Sleeping
Sleeping
# Compostion of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py | |
from pathlib import Path | |
import os | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from typing import Any, Union, Sequence, Optional, Dict | |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download | |
from timm.models import create_model | |
from timm.models.vision_transformer import Block, Attention | |
from utils.misc_utils import compute_attention | |
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn | |
from layers.independent_mlp import IndependentMLPs | |
SAFETENSORS_SINGLE_FILE = "model.safetensors" | |
class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin, | |
pipeline_tag='image-classification', | |
repo_url='https://github.com/ananthu-aniraj/pdiscoformer'): | |
def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200, | |
part_dropout: float = 0.3, return_transformer_qkv: bool = False, | |
modulation_type: str = "original", gumbel_softmax: bool = False, | |
gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False, | |
classifier_type: str = "linear") -> None: | |
super().__init__() | |
self.num_landmarks = num_landmarks | |
self.num_classes = num_classes | |
self.num_prefix_tokens = init_model.num_prefix_tokens | |
self.num_reg_tokens = init_model.num_reg_tokens | |
self.has_class_token = init_model.has_class_token | |
self.no_embed_class = init_model.no_embed_class | |
self.cls_token = init_model.cls_token | |
self.reg_token = init_model.reg_token | |
self.feature_dim = init_model.embed_dim | |
self.patch_embed = init_model.patch_embed | |
self.pos_embed = init_model.pos_embed | |
self.pos_drop = init_model.pos_drop | |
self.norm_pre = init_model.norm_pre | |
self.blocks = init_model.blocks | |
self.norm = init_model.norm | |
self.return_transformer_qkv = return_transformer_qkv | |
self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]) | |
self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]) | |
self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap)) | |
self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False) | |
self.gumbel_softmax = gumbel_softmax | |
self.gumbel_softmax_temperature = gumbel_softmax_temperature | |
self.gumbel_softmax_hard = gumbel_softmax_hard | |
self.modulation_type = modulation_type | |
if modulation_type == "layer_norm": | |
self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1]) | |
elif modulation_type == "original": | |
self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1)) | |
elif modulation_type == "parallel_mlp": | |
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, | |
num_lin_layers=1, act_layer=True, bias=True) | |
elif modulation_type == "parallel_mlp_no_bias": | |
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, | |
num_lin_layers=1, act_layer=True, bias=False) | |
elif modulation_type == "parallel_mlp_no_act": | |
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, | |
num_lin_layers=1, act_layer=False, bias=True) | |
elif modulation_type == "parallel_mlp_no_act_no_bias": | |
self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, | |
num_lin_layers=1, act_layer=False, bias=False) | |
elif modulation_type == "none": | |
self.modulation = torch.nn.Identity() | |
else: | |
raise ValueError("modulation_type not implemented") | |
self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout) | |
self.classifier_type = classifier_type | |
if classifier_type == "independent_mlp": | |
self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim, | |
num_lin_layers=1, act_layer=False, out_dim=num_classes, | |
bias=False, stack_dim=1) | |
elif classifier_type == "linear": | |
self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes, | |
bias=False) | |
else: | |
raise ValueError("classifier_type not implemented") | |
self.convert_blocks_and_attention() | |
self._init_weights() | |
def _init_weights_head(self): | |
# Initialize weights with a truncated normal distribution | |
if self.classifier_type == "independent_mlp": | |
self.fc_class_landmarks.reset_weights() | |
else: | |
torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02) | |
if self.fc_class_landmarks.bias is not None: | |
torch.nn.init.zeros_(self.fc_class_landmarks.bias) | |
def _init_weights(self): | |
self._init_weights_head() | |
def convert_blocks_and_attention(self): | |
for module in self.modules(): | |
if isinstance(module, Block): | |
module.__class__ = BlockWQKVReturn | |
elif isinstance(module, Attention): | |
module.__class__ = AttentionWQKVReturn | |
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: | |
pos_embed = self.pos_embed | |
to_cat = [] | |
if self.cls_token is not None: | |
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) | |
if self.reg_token is not None: | |
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) | |
if self.no_embed_class: | |
# deit-3, updated JAX (big vision) | |
# position embedding does not overlap with class token, add then concat | |
x = x + pos_embed | |
if to_cat: | |
x = torch.cat(to_cat + [x], dim=1) | |
else: | |
# original timm, JAX, and deit vit impl | |
# pos_embed has entry for class token, concat then add | |
if to_cat: | |
x = torch.cat(to_cat + [x], dim=1) | |
x = x + pos_embed | |
return self.pos_drop(x) | |
def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, int | Any] | tuple[Any, Any, Any, Any, int | Any]: | |
x = self.patch_embed(x) | |
# Position Embedding | |
x = self._pos_embed(x) | |
# Forward pass through transformer | |
x = self.norm_pre(x) | |
x = self.blocks(x) | |
x = self.norm(x) | |
# Compute per landmark attention maps | |
# (b - a)^2 = b^2 - 2ab + a^2, b = feature maps vit, a = convolution kernel | |
batch_size = x.shape[0] | |
x = x[:, self.num_prefix_tokens:, :] # [B, num_patch_tokens, embed_dim] | |
x = self.unflatten(x) # [B, H, W, embed_dim] | |
x = x.permute(0, 3, 1, 2).contiguous() # [B, embed_dim, H, W] | |
ab = self.fc_landmarks(x) # [B, num_landmarks + 1, H, W] | |
b_sq = x.pow(2).sum(1, keepdim=True) | |
b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous() | |
a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2], | |
x.shape[-1]).contiguous() | |
a_sq = a_sq.permute(1, 0, 2, 3).contiguous() | |
dist = b_sq - 2 * ab + a_sq | |
maps = -dist | |
# Softmax so that the attention maps for each pixel add up to 1 | |
if self.gumbel_softmax: | |
maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature, | |
hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W] | |
else: | |
maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W] | |
# Use maps to get weighted average features per landmark | |
all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous() | |
all_features = all_features.mean(-1).mean(-1).contiguous() # [B, embed_dim, num_landmarks + 1] | |
# Modulate the features | |
if self.modulation_type == "original": | |
all_features_mod = all_features * self.modulation # [B, embed_dim, num_landmarks + 1] | |
else: | |
all_features_mod = self.modulation(all_features) # [B, embed_dim, num_landmarks + 1] | |
# Classification based on the landmark features | |
scores = self.fc_class_landmarks( | |
self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2, | |
1).contiguous() | |
scores = scores.mean(dim=-1) # [B, num_classes] | |
return maps, scores | |
def get_specific_intermediate_layer( | |
self, | |
x: torch.Tensor, | |
n: int = 1, | |
return_qkv: bool = False, | |
return_att_weights: bool = False, | |
): | |
num_blocks = len(self.blocks) | |
attn_weights = [] | |
if n >= num_blocks: | |
raise ValueError(f"n must be less than {num_blocks}") | |
# forward pass | |
x = self.patch_embed(x) | |
x = self._pos_embed(x) | |
x = self.norm_pre(x) | |
if n == -1: | |
if return_qkv: | |
raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True") | |
else: | |
return x | |
for i, blk in enumerate(self.blocks): | |
if self.return_transformer_qkv: | |
x, qkv = blk(x, return_qkv=True) | |
if return_att_weights: | |
attn_weight, _ = compute_attention(qkv) | |
attn_weights.append(attn_weight.detach()) | |
else: | |
x = blk(x) | |
if i == n: | |
output = x.clone() | |
if self.return_transformer_qkv and return_qkv: | |
qkv_output = qkv.clone() | |
break | |
if self.return_transformer_qkv and return_qkv and return_att_weights: | |
return output, qkv_output, attn_weights | |
elif self.return_transformer_qkv and return_qkv: | |
return output, qkv_output | |
elif self.return_transformer_qkv and return_att_weights: | |
return output, attn_weights | |
else: | |
return output | |
def _intermediate_layers( | |
self, | |
x: torch.Tensor, | |
n: Union[int, Sequence] = 1, | |
): | |
outputs, num_blocks = [], len(self.blocks) | |
if self.return_transformer_qkv: | |
qkv_outputs = [] | |
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) | |
# forward pass | |
x = self.patch_embed(x) | |
x = self._pos_embed(x) | |
x = self.norm_pre(x) | |
for i, blk in enumerate(self.blocks): | |
if self.return_transformer_qkv: | |
x, qkv = blk(x, return_qkv=True) | |
else: | |
x = blk(x) | |
if i in take_indices: | |
outputs.append(x) | |
if self.return_transformer_qkv: | |
qkv_outputs.append(qkv) | |
if self.return_transformer_qkv: | |
return outputs, qkv_outputs | |
else: | |
return outputs | |
def get_intermediate_layers( | |
self, | |
x: torch.Tensor, | |
n: Union[int, Sequence] = 1, | |
reshape: bool = False, | |
return_prefix_tokens: bool = False, | |
norm: bool = False, | |
) -> tuple[tuple, Any]: | |
""" Intermediate layer accessor (NOTE: This is a WIP experiment). | |
Inspired by DINO / DINOv2 interface | |
""" | |
# take last n blocks if n is an int, if in is a sequence, select by matching indices | |
if self.return_transformer_qkv: | |
outputs, qkv = self._intermediate_layers(x, n) | |
else: | |
outputs = self._intermediate_layers(x, n) | |
if norm: | |
outputs = [self.norm(out) for out in outputs] | |
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] | |
outputs = [out[:, self.num_prefix_tokens:] for out in outputs] | |
if reshape: | |
grid_size = self.patch_embed.grid_size | |
outputs = [ | |
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() | |
for out in outputs | |
] | |
if return_prefix_tokens: | |
return_out = tuple(zip(outputs, prefix_tokens)) | |
else: | |
return_out = tuple(outputs) | |
if self.return_transformer_qkv: | |
return return_out, qkv | |
else: | |
return return_out | |
def _from_pretrained( | |
cls, | |
*, | |
model_id: str, | |
revision: Optional[str], | |
cache_dir: Optional[Union[str, Path]], | |
force_download: bool, | |
proxies: Optional[Dict], | |
resume_download: Optional[bool], | |
local_files_only: bool, | |
token: Union[str, bool, None], | |
map_location: str = "cpu", | |
strict: bool = False, | |
timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m", | |
input_size: int = 518, | |
**model_kwargs): | |
base_model = create_model(timm_backbone, pretrained=False, img_size=input_size) | |
model = cls(base_model, **model_kwargs) | |
if os.path.isdir(model_id): | |
print("Loading weights from local directory") | |
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) | |
return cls._load_as_safetensor(model, model_file, map_location, strict) | |
else: | |
model_file = hf_hub_download( | |
repo_id=model_id, | |
filename=SAFETENSORS_SINGLE_FILE, | |
revision=revision, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
proxies=proxies, | |
resume_download=resume_download, | |
token=token, | |
local_files_only=local_files_only, | |
) | |
return cls._load_as_safetensor(model, model_file, map_location, strict) | |
def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs): | |
base_model = create_model( | |
backbone, | |
pretrained=False, | |
img_size=img_size, | |
) | |
model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls, | |
modulation_type="layer_norm", gumbel_softmax=True, | |
modulation_orth=True) | |
return model | |
def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs): | |
base_model = create_model( | |
backbone, | |
pretrained=False, | |
img_size=img_size, | |
) | |
model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls, | |
modulation_type="original") | |
return model | |