# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import torch
import torch.nn as nn
from typing import Tuple, Union, Sequence, Any
from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, Attention
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from utils.misc_utils import compute_attention


class BaselineViT(torch.nn.Module):
    """
    Baseline ViT wrapper around a timm VisionTransformer.

    Modifications:
    - Use BlockWQKVReturn instead of Block
    - Use AttentionWQKVReturn instead of Attention
    - Optionally return the QKV tensors of the last transformer block
    - Option to use only the class token, only the patch tokens, or both (concatenated) for classification
    """

    def __init__(self, init_model: torch.nn.Module, num_classes: int,
                 class_tokens_only: bool = False,
                 patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.class_tokens_only = class_tokens_only
        self.patch_tokens_only = patch_tokens_only
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token
        self.patch_embed = init_model.patch_embed
        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.part_embed = nn.Identity()
        self.patch_prune = nn.Identity()
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm
        self.fc_norm = init_model.fc_norm
        if class_tokens_only or patch_tokens_only:
            self.head = nn.Linear(init_model.embed_dim, num_classes)
        else:
            self.head = nn.Linear(init_model.embed_dim * 2, num_classes)
        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
        self.return_transformer_qkv = return_transformer_qkv
        self.convert_blocks_and_attention()
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        """Swap timm's Block/Attention classes in place for variants that can also return QKV."""
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def _init_weights_head(self):
        trunc_normal_(self.head.weight, std=.02)
        if self.head.bias is not None:
            nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        x = self.patch_embed(x)
        # Position embedding
        x = self._pos_embed(x)
        x = self.part_embed(x)
        x = self.patch_prune(x)
        # Forward pass through the transformer
        x = self.norm_pre(x)
        if self.return_transformer_qkv:
            # Collect QKV from every block; the QKV of the last block is returned
            for i, blk in enumerate(self.blocks):
                x, qkv = blk(x, return_qkv=True)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        # Classification head
        x = self.fc_norm(x)
        if self.class_tokens_only:  # only use the class token
            x = x[:, 0, :]
        elif self.patch_tokens_only:  # only use the patch tokens
            x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
        else:  # concatenate the class token with the mean of the patch tokens
            x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
        x = self.head(x)
        if self.return_transformer_qkv:
            return x, qkv
        else:
            return x

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")
        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)
        if n == -1:
            if return_qkv:
                raise ValueError("n cannot be -1 if return_qkv is True")
            else:
                return x
        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take the last n blocks if n is an int; if n is a sequence, select blocks by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)
        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out
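

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # It assumes a recent timm version whose VisionTransformer exposes
    # num_prefix_tokens / num_reg_tokens / has_class_token / reg_token; the model
    # name, class count, and input size below are just examples.
    import timm

    init_model = timm.create_model("vit_base_patch16_224", pretrained=False)
    model = BaselineViT(
        init_model,
        num_classes=10,
        class_tokens_only=False,
        patch_tokens_only=False,
        return_transformer_qkv=True,
    )
    model.eval()

    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        # With return_transformer_qkv=True, forward also returns the QKV of the last
        # block; its exact layout depends on AttentionWQKVReturn (not shown here).
        logits, qkv = model(dummy)
        # Features from the last two blocks, reshaped to (B, C, H, W) patch grids.
        feats, qkv_list = model.get_intermediate_layers(dummy, n=2, reshape=True, norm=True)
    print(logits.shape, [f.shape for f in feats])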