|
from .decoders.unet_decoder import UnetDecoder |
|
from .decoders.unet_decoder import SegmentationHead |
|
|
|
import torch.nn as nn |
|
|
|
from typing import Tuple |
|
|
|
|
|
class unet_swin(nn.Module):
    """Encoder-decoder segmentation model: a (Swin) encoder plus a U-Net decoder.

    The supplied ``encoder`` is paired with a ``UnetDecoder`` and a
    ``SegmentationHead``. The decoder is constructed with
    ``attention_type=None``, so no attention variant is requested from it
    (NOTE(review): confirm what ``None`` means inside ``UnetDecoder``).

    The channel constants below are hard-coded for one particular encoder's
    feature pyramid. They are read through ``self``, so a subclass can
    override them to fit a different backbone.
    """

    # Channel count of each feature map produced by
    # ``encoder.get_unet_feature``; passed to the decoder as
    # ``encoder_channels``. Presumably the leading 3 is the raw input-image
    # stage and the rest are encoder stage outputs -- TODO confirm.
    FEATURE_CHANNELS: Tuple[int, ...] = (3, 256, 512, 1024, 1024)

    # Passed to the decoder as ``decoder_channels``; one entry per decoder
    # block, so its length must equal N_BLOCKS.
    DECODE_CHANNELS: Tuple[int, ...] = (512, 256, 128, 64)

    # Input channels of the segmentation head. Equals DECODE_CHANNELS[-1]
    # (the width of the last decoder stage); keep the two in sync.
    IN_CHANNELS: int = 64

    # Number of decoder blocks; equals len(DECODE_CHANNELS).
    N_BLOCKS: int = 4

    # Kernel size passed to the segmentation head.
    KERNEL_SIZE: int = 3

    # Upsampling factor passed to the segmentation head (presumably to bring
    # the decoder output back to input resolution -- confirm against the
    # encoder's stride).
    UPSAMPLING: int = 4

    def __init__(self, encoder, num_classes=9):
        """Build the decoder and segmentation head around ``encoder``.

        Args:
            encoder: Backbone module. ``forward`` calls its
                ``get_unet_feature(x)`` method and unpacks the result into
                the decoder. The feature widths must match
                ``FEATURE_CHANNELS``.
            num_classes: Output channels of the segmentation head.
                Defaults to 9.
        """
        super().__init__()

        self.encoder = encoder

        self.decoder = UnetDecoder(
            encoder_channels=self.FEATURE_CHANNELS,
            n_blocks=self.N_BLOCKS,
            decoder_channels=self.DECODE_CHANNELS,
            attention_type=None,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.IN_CHANNELS,
            out_channels=num_classes,
            kernel_size=self.KERNEL_SIZE,
            upsampling=self.UPSAMPLING,
        )

    def forward(self, x):
        """Run the encoder, the decoder and the segmentation head on ``x``.

        Args:
            x: Input batch for the encoder (presumably an image tensor of
                shape (B, 3, H, W) -- confirm against the encoder).

        Returns:
            The segmentation head's output. The head is configured with
            ``out_channels=num_classes``.
        """
        encoder_features = self.encoder.get_unet_feature(x)
        # get_unet_feature returns an iterable of feature maps, and the
        # decoder takes them as separate positional arguments.
        decoder_output = self.decoder(*encoder_features)
        masks = self.segmentation_head(decoder_output)
        return masks
|
|