from .decoders.unet_decoder import UnetDecoder from .decoders.unet_decoder import SegmentationHead import torch.nn as nn from typing import Tuple class unet_swin(nn.Module): """ Pytorch encoder-decoder model which pairs an encoder (swin) with the attention unet decoder. """ FEATURE_CHANNELS: Tuple[int] = (3, 256, 512, 1024, 1024) DECODE_CHANNELS: Tuple[int] = (512, 256, 128, 64) IN_CHANNELS: int = 64 N_BLOCKS: int = 4 KERNEL_SIZE: int = 3 UPSAMPLING: int = 4 def __init__(self, encoder, num_classes=9): super().__init__() self.encoder = encoder self.decoder = UnetDecoder( encoder_channels=self.FEATURE_CHANNELS, n_blocks=self.N_BLOCKS, decoder_channels=self.DECODE_CHANNELS, attention_type=None) self.segmentation_head = SegmentationHead( in_channels=self.IN_CHANNELS, out_channels=num_classes, kernel_size=self.KERNEL_SIZE, upsampling=self.UPSAMPLING) def forward(self, x): encoder_featrue = self.encoder.get_unet_feature(x) decoder_output = self.decoder(*encoder_featrue) masks = self.segmentation_head(decoder_output) return masks