# -*- coding: utf-8 -*-
from typing import Callable, List, Tuple, Type

import torch
import torch.nn as nn
from einops import rearrange

from models.encoders.VIT.SAM.image_encoder import ImageEncoderViT
from models.encoders.VIT.vits_histo import VisionTransformer


class Conv2DBlock(nn.Module):
"""Conv2DBlock with convolution followed by batch-normalisation, ReLU activation and dropout
Args:
in_channels (int): Number of input channels for convolution
out_channels (int): Number of output channels for convolution
kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
dropout (float, optional): Dropout. Defaults to 0.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
dropout: float = 0,
) -> None:
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=((kernel_size - 1) // 2),
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Dropout(dropout),
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
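

# Minimal usage sketch for Conv2DBlock (not a prescribed configuration): the channel sizes,
# resolution and dropout value below are arbitrary illustration values.
if __name__ == "__main__":
    _feat = torch.rand(2, 64, 32, 32)  # (batch, channels, height, width)
    _conv_block = Conv2DBlock(in_channels=64, out_channels=128, kernel_size=3, dropout=0.1)
    # the 3x3 convolution uses stride 1 and "same" padding, so the spatial size is preserved
    print(_conv_block(_feat).shape)  # torch.Size([2, 128, 32, 32])

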
class Deconv2DBlock(nn.Module):
"""Deconvolution block with ConvTranspose2d followed by Conv2d, batch-normalisation, ReLU activation and dropout
Args:
in_channels (int): Number of input channels for deconv block
out_channels (int): Number of output channels for deconv and convolution.
kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
dropout (float, optional): Dropout. Defaults to 0.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
dropout: float = 0,
) -> None:
super().__init__()
self.block = nn.Sequential(
nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2,
stride=2,
padding=0,
output_padding=0,
),
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=((kernel_size - 1) // 2),
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Dropout(dropout),
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
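

# Minimal usage sketch for Deconv2DBlock (arbitrary illustration values): the transposed
# convolution doubles the spatial resolution before the 3x3 convolution refines the features.
if __name__ == "__main__":
    _feat = torch.rand(2, 256, 16, 16)
    _deconv_block = Deconv2DBlock(in_channels=256, out_channels=128)
    print(_deconv_block(_feat).shape)  # upsampled by 2x: torch.Size([2, 128, 32, 32])

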
class ViTCellViT(VisionTransformer):
def __init__(
self,
extract_layers: List[int],
img_size: List[int] = [224],
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 0,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4,
qkv_bias: bool = False,
qk_scale: float = None,
drop_rate: float = 0,
attn_drop_rate: float = 0,
drop_path_rate: float = 0,
norm_layer: Callable = nn.LayerNorm,
**kwargs
):
"""Vision Transformer with 1D positional embedding
Args:
extract_layers: (List[int]): List of Transformer Blocks whose outputs should be returned in addition to the tokens. First blocks starts with 1, and maximum is N=depth.
img_size (int, optional): Input image size. Defaults to 224.
patch_size (int, optional): Patch Token size (one dimension only, cause tokens are squared). Defaults to 16.
in_chans (int, optional): Number of input channels. Defaults to 3.
num_classes (int, optional): Number of output classes. if num classes = 0, raw tokens are returned (nn.Identity).
Default to 0.
embed_dim (int, optional): Embedding dimension. Defaults to 768.
depth(int, optional): Number of Transformer Blocks. Defaults to 12.
num_heads (int, optional): Number of attention heads per Transformer Block. Defaults to 12.
mlp_ratio (float, optional): MLP ratio for hidden MLP dimension (Bottleneck = dim*mlp_ratio).
Defaults to 4.0.
qkv_bias (bool, optional): If bias should be used for query (q), key (k), and value (v). Defaults to False.
qk_scale (float, optional): Scaling parameter. Defaults to None.
drop_rate (float, optional): Dropout in MLP. Defaults to 0.0.
attn_drop_rate (float, optional): Dropout for attention layer. Defaults to 0.0.
drop_path_rate (float, optional): Dropout for skip connection. Defaults to 0.0.
norm_layer (Callable, optional): Normalization layer. Defaults to nn.LayerNorm.
"""
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
num_classes=num_classes,
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
)
self.extract_layers = extract_layers

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward pass returning intermediate outputs for skip connections

        Args:
            x (torch.Tensor): Input batch

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                torch.Tensor: Classification head output for the CLS token (the raw CLS token if num_classes = 0)
                torch.Tensor: CLS token after the final normalisation layer
                List[torch.Tensor]: Skip-connection outputs of the blocks selected in extract_layers
        """
extracted_layers = []
x = self.prepare_tokens(x)
for depth, blk in enumerate(self.blocks):
x = blk(x)
            # extract_layers is 1-indexed, hence the depth + 1 offset
            if depth + 1 in self.extract_layers:
extracted_layers.append(x)
x = self.norm(x)
output = self.head(x[:, 0])
return output, x[:, 0], extracted_layers
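

# Minimal usage sketch for ViTCellViT, assuming the VisionTransformer base class imported above
# provides the prepare_tokens/blocks/norm/head members used in forward. The hyperparameters
# mirror a ViT-Base setup purely for illustration.
if __name__ == "__main__":
    _vit = ViTCellViT(
        extract_layers=[3, 6, 9, 12],
        img_size=[224],
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=True,
    )
    _logits, _cls_token, _skips = _vit(torch.rand(1, 3, 224, 224))
    # with num_classes = 0 the head is an identity, so _logits equals the CLS token
    print(_cls_token.shape, len(_skips))  # torch.Size([1, 768]) and 4 skip tensors

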
class ViTCellViTDeit(ImageEncoderViT):
def __init__(
self,
extract_layers: List[int],
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
super().__init__(
img_size,
patch_size,
in_chans,
embed_dim,
depth,
num_heads,
mlp_ratio,
out_chans,
qkv_bias,
norm_layer,
act_layer,
use_abs_pos,
use_rel_pos,
rel_pos_zero_init,
window_size,
global_attn_indexes,
)
self.extract_layers = extract_layers

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward pass returning the mean-pooled neck features, the neck feature map
        and the skip-connection outputs of the blocks selected in extract_layers.
        """
        extracted_layers = []
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # crop the absolute positional embedding to the current token grid
            token_size = x.shape[1]
            x = x + self.pos_embed[:, :token_size, :token_size, :]
        for depth, blk in enumerate(self.blocks):
            x = blk(x)
            if depth + 1 in self.extract_layers:
                extracted_layers.append(x)
        # the neck expects channels-first input: (B, H', W', C) -> (B, C, H', W')
        output = self.neck(x.permute(0, 3, 1, 2))
        _output = rearrange(output, "b c h w -> b c (h w)")
        return torch.mean(_output, dim=-1), output, extracted_layers
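

# Minimal usage sketch for ViTCellViTDeit, assuming the SAM-style ImageEncoderViT imported above.
# A ViT-B configuration with global attention at 256 x 256 input is shown purely for illustration;
# the neck keeps its default of out_chans = 256.
if __name__ == "__main__":
    _encoder = ViTCellViTDeit(
        extract_layers=[3, 6, 9, 12],
        img_size=256,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
    )
    _pooled, _features, _skips = _encoder(torch.rand(1, 3, 256, 256))
    print(_pooled.shape, _features.shape, len(_skips))
    # expected: torch.Size([1, 256]) torch.Size([1, 256, 16, 16]) 4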