# -*- coding: utf-8 -*-
from einops import rearrange

from models.encoders.VIT.SAM.image_encoder import ImageEncoderViT
from models.encoders.VIT.vits_histo import VisionTransformer

import torch
import torch.nn as nn
from typing import Callable, Tuple, Type, List


class Conv2DBlock(nn.Module):
    """Conv2DBlock with convolution followed by batch-normalisation, ReLU activation and dropout

    Args:
        in_channels (int): Number of input channels for convolution
        out_channels (int): Number of output channels for convolution
        kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
        dropout (float, optional): Dropout. Defaults to 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=((kernel_size - 1) // 2),
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.block(x)


class Deconv2DBlock(nn.Module):
    """Deconvolution block with ConvTranspose2d followed by Conv2d, batch-normalisation, ReLU activation and dropout

    Args:
        in_channels (int): Number of input channels for the deconv block
        out_channels (int): Number of output channels for deconv and convolution
        kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
        dropout (float, optional): Dropout. Defaults to 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            ),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=((kernel_size - 1) // 2),
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.block(x)


class ViTCellViT(VisionTransformer):
    def __init__(
        self,
        extract_layers: List[int],
        img_size: List[int] = [224],
        patch_size: int = 16,
        in_chans: int = 3,
        num_classes: int = 0,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4,
        qkv_bias: bool = False,
        qk_scale: float = None,
        drop_rate: float = 0,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0,
        norm_layer: Callable = nn.LayerNorm,
        **kwargs
    ):
        """Vision Transformer with 1D positional embedding

        Args:
            extract_layers (List[int]): List of Transformer blocks whose outputs should be returned
                in addition to the tokens. The first block starts at 1; the maximum is N=depth.
            img_size (List[int], optional): Input image size. Defaults to [224].
            patch_size (int, optional): Patch token size (one dimension only, because tokens are square). Defaults to 16.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            num_classes (int, optional): Number of output classes. If num_classes = 0, raw tokens are
                returned (the head is nn.Identity). Defaults to 0.
            embed_dim (int, optional): Embedding dimension. Defaults to 768.
            depth (int, optional): Number of Transformer blocks. Defaults to 12.
            num_heads (int, optional): Number of attention heads per Transformer block. Defaults to 12.
            mlp_ratio (float, optional): MLP ratio for the hidden MLP dimension (bottleneck = dim * mlp_ratio). Defaults to 4.0.
            qkv_bias (bool, optional): Whether a bias should be used for query (q), key (k), and value (v). Defaults to False.
            qk_scale (float, optional): Scaling parameter. Defaults to None.
            drop_rate (float, optional): Dropout in MLP. Defaults to 0.0.
            attn_drop_rate (float, optional): Dropout for the attention layer. Defaults to 0.0.
            drop_path_rate (float, optional): Dropout for the skip connection. Defaults to 0.0.
            norm_layer (Callable, optional): Normalization layer. Defaults to nn.LayerNorm.
        """
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
        )
        self.extract_layers = extract_layers

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward pass that also returns intermediate outputs for skip connections

        Args:
            x (torch.Tensor): Input batch

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                torch.Tensor: Classification output (head applied to the class token;
                    identical to the class token if num_classes = 0)
                torch.Tensor: Class token of the last layer
                List[torch.Tensor]: Skip connection outputs from the extract_layers selection
        """
        extracted_layers = []
        x = self.prepare_tokens(x)

        for depth, blk in enumerate(self.blocks):
            x = blk(x)
            if depth + 1 in self.extract_layers:
                extracted_layers.append(x)

        x = self.norm(x)
        output = self.head(x[:, 0])

        return output, x[:, 0], extracted_layers


class ViTCellViTDeit(ImageEncoderViT):
    def __init__(
        self,
        extract_layers: List[int],
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """SAM-style ViT image encoder that additionally returns the outputs of the
        Transformer blocks listed in extract_layers (1-based, maximum is depth)."""
        super().__init__(
            img_size,
            patch_size,
            in_chans,
            embed_dim,
            depth,
            num_heads,
            mlp_ratio,
            out_chans,
            qkv_bias,
            norm_layer,
            act_layer,
            use_abs_pos,
            use_rel_pos,
            rel_pos_zero_init,
            window_size,
            global_attn_indexes,
        )
        self.extract_layers = extract_layers

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        extracted_layers = []
        x = self.patch_embed(x)

        if self.pos_embed is not None:
            # Crop the absolute positional embedding to the current token grid so that
            # inputs smaller than the pretraining resolution can still be processed.
            token_size = x.shape[1]
            x = x + self.pos_embed[:, :token_size, :token_size, :]

        for depth, blk in enumerate(self.blocks):
            x = blk(x)
            if depth + 1 in self.extract_layers:
                extracted_layers.append(x)

        output = self.neck(x.permute(0, 3, 1, 2))
        _output = rearrange(output, "b c h w -> b c (h w)")

        # Return globally averaged neck features, the neck feature map, and the
        # selected intermediate block outputs.
        return torch.mean(_output, dim=-1), output, extracted_layers
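

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definitions).
# It smoke-tests the self-contained Conv2DBlock/Deconv2DBlock above and shows
# how `extract_layers` selects skip-connection outputs from ViTCellViT. The
# hyperparameters below are illustrative, and the ViT part assumes the
# VisionTransformer in models.encoders.VIT.vits_histo follows the usual
# DINO-style interface (prepare_tokens, blocks, norm, head) used by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 3, 64, 64)

    conv = Conv2DBlock(in_channels=3, out_channels=32, dropout=0.1)
    deconv = Deconv2DBlock(in_channels=32, out_channels=16)
    feats = conv(x)      # (2, 32, 64, 64): spatial size preserved
    up = deconv(feats)   # (2, 16, 128, 128): spatial size doubled
    print(feats.shape, up.shape)

    # Return the outputs of blocks 3, 6, 9 and 12 (1-based, depth=12) as skip
    # connections for a U-Net-style decoder; small embed_dim keeps this cheap.
    vit = ViTCellViT(extract_layers=[3, 6, 9, 12], embed_dim=192, num_heads=3)
    logits, cls_token, skips = vit(torch.randn(1, 3, 224, 224))
    # Each skip keeps all tokens, e.g. (1, 197, 192) for this configuration.
    print(cls_token.shape, [s.shape for s in skips])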