# -*- coding: utf-8 -*-
from einops import rearrange
from models.encoders.VIT.SAM.image_encoder import ImageEncoderViT
from models.encoders.VIT.vits_histo import VisionTransformer
import torch
import torch.nn as nn
from typing import Callable, Tuple, Type, List
class Conv2DBlock(nn.Module):
"""Conv2DBlock with convolution followed by batch-normalisation, ReLU activation and dropout
Args:
in_channels (int): Number of input channels for convolution
out_channels (int): Number of output channels for convolution
kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
dropout (float, optional): Dropout. Defaults to 0.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
dropout: float = 0,
) -> None:
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=((kernel_size - 1) // 2),
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Dropout(dropout),
)
def forward(self, x):
return self.block(x)
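
# A minimal usage sketch (illustrative only): with stride 1 and
# padding = (kernel_size - 1) // 2, Conv2DBlock preserves the spatial resolution
# for the default odd kernel size, e.g.
#   block = Conv2DBlock(in_channels=3, out_channels=64, kernel_size=3, dropout=0.1)
#   y = block(torch.zeros(2, 3, 224, 224))  # -> shape (2, 64, 224, 224)
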
class Deconv2DBlock(nn.Module):
"""Deconvolution block with ConvTranspose2d followed by Conv2d, batch-normalisation, ReLU activation and dropout
Args:
in_channels (int): Number of input channels for deconv block
out_channels (int): Number of output channels for deconv and convolution.
kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
dropout (float, optional): Dropout. Defaults to 0.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
dropout: float = 0,
) -> None:
super().__init__()
self.block = nn.Sequential(
nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2,
stride=2,
padding=0,
output_padding=0,
),
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=((kernel_size - 1) // 2),
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Dropout(dropout),
)
def forward(self, x):
return self.block(x)
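
# A minimal usage sketch (illustrative only): the ConvTranspose2d with kernel_size=2
# and stride=2 doubles the spatial resolution before the resolution-preserving
# convolution, e.g.
#   up = Deconv2DBlock(in_channels=64, out_channels=32)
#   y = up(torch.zeros(2, 64, 14, 14))  # -> shape (2, 32, 28, 28)
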
class ViTCellViT(VisionTransformer):
def __init__(
self,
extract_layers: List[int],
img_size: List[int] = [224],
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 0,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4,
qkv_bias: bool = False,
qk_scale: float = None,
drop_rate: float = 0,
attn_drop_rate: float = 0,
drop_path_rate: float = 0,
norm_layer: Callable = nn.LayerNorm,
**kwargs
):
"""Vision Transformer with 1D positional embedding
Args:
            extract_layers (List[int]): List of Transformer blocks whose outputs should be returned in addition to the tokens. The first block starts at 1, and the maximum is N=depth.
            img_size (List[int], optional): Input image size. Defaults to [224].
            patch_size (int, optional): Patch token size (one dimension only, because patches are square). Defaults to 16.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            num_classes (int, optional): Number of output classes. If num_classes = 0, raw tokens are returned (nn.Identity head).
                Defaults to 0.
embed_dim (int, optional): Embedding dimension. Defaults to 768.
depth(int, optional): Number of Transformer Blocks. Defaults to 12.
num_heads (int, optional): Number of attention heads per Transformer Block. Defaults to 12.
mlp_ratio (float, optional): MLP ratio for hidden MLP dimension (Bottleneck = dim*mlp_ratio).
Defaults to 4.0.
            qkv_bias (bool, optional): Whether to use a bias for the query (q), key (k), and value (v) projections. Defaults to False.
qk_scale (float, optional): Scaling parameter. Defaults to None.
drop_rate (float, optional): Dropout in MLP. Defaults to 0.0.
attn_drop_rate (float, optional): Dropout for attention layer. Defaults to 0.0.
            drop_path_rate (float, optional): Stochastic depth (drop path) rate for the residual branches. Defaults to 0.0.
norm_layer (Callable, optional): Normalization layer. Defaults to nn.LayerNorm.
"""
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
num_classes=num_classes,
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
)
self.extract_layers = extract_layers
    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""Forward pass with returning intermediate outputs for skip connections
Args:
x (torch.Tensor): Input batch
Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                torch.Tensor: Classification output (class token passed through the head; raw class token if num_classes = 0)
                torch.Tensor: Class token of the final layer
                List[torch.Tensor]: Skip connection outputs (all tokens) from the extract_layers selection
"""
extracted_layers = []
x = self.prepare_tokens(x)
for depth, blk in enumerate(self.blocks):
x = blk(x)
if depth + 1 in self.extract_layers:
extracted_layers.append(x)
x = self.norm(x)
output = self.head(x[:, 0])
return output, x[:, 0], extracted_layers
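
# A minimal usage sketch (illustrative only; the shapes assume the DINO-style
# VisionTransformer backbone from vits_histo with a class token and embed_dim=768):
#   encoder = ViTCellViT(extract_layers=[3, 6, 9, 12], img_size=[224], patch_size=16)
#   cls_out, cls_token, skips = encoder(torch.zeros(1, 3, 224, 224))
#   # len(skips) == 4; each entry holds all tokens of that block,
#   # here (1, 1 + 14 * 14, 768) for 224 / 16 = 14 patches per side
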
class ViTCellViTDeit(ImageEncoderViT):
def __init__(
self,
extract_layers: List[int],
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
super().__init__(
img_size,
patch_size,
in_chans,
embed_dim,
depth,
num_heads,
mlp_ratio,
out_chans,
qkv_bias,
norm_layer,
act_layer,
use_abs_pos,
use_rel_pos,
rel_pos_zero_init,
window_size,
global_attn_indexes,
)
self.extract_layers = extract_layers
    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward pass returning the spatially pooled neck output, the neck feature map, and skip connection outputs from extract_layers."""
extracted_layers = []
x = self.patch_embed(x)
        if self.pos_embed is not None:
            # crop the absolute positional embedding to the current (square) token grid
            token_size = x.shape[1]
            x = x + self.pos_embed[:, :token_size, :token_size, :]
for depth, blk in enumerate(self.blocks):
x = blk(x)
if depth + 1 in self.extract_layers:
extracted_layers.append(x)
output = self.neck(x.permute(0, 3, 1, 2))
_output = rearrange(output, "b c h w -> b c (h w)")
        return torch.mean(_output, dim=-1), output, extracted_layers
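
# A minimal smoke test (illustrative only; it assumes the vits_histo and SAM
# backbones imported above are available, and uses a reduced image size for the
# SAM variant to keep the run cheap).
if __name__ == "__main__":
    # 1D positional embedding variant: classification output, class token, token skips
    vit = ViTCellViT(extract_layers=[3, 6, 9, 12], img_size=[224], patch_size=16)
    cls_out, cls_token, skips = vit(torch.zeros(1, 3, 224, 224))
    print(cls_out.shape, cls_token.shape, len(skips))

    # SAM/DeiT variant: pooled neck output, neck feature map, spatial skip maps
    vit_sam = ViTCellViTDeit(extract_layers=[3, 6, 9, 12], img_size=256, patch_size=16)
    pooled, neck_out, sam_skips = vit_sam(torch.zeros(1, 3, 256, 256))
    print(pooled.shape, neck_out.shape, len(sam_skips))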