File size: 1,261 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from .decoders.unet_decoder import UnetDecoder
from .decoders.unet_decoder import SegmentationHead
import torch.nn as nn
from typing import Tuple
class unet_swin(nn.Module):
    """Encoder-decoder segmentation model built around an injected encoder.

    The model chains three stages:

    1. ``encoder.get_unet_feature(x)`` produces a sequence of feature maps.
    2. ``UnetDecoder`` receives those feature maps as positional arguments.
    3. ``SegmentationHead`` turns the decoder output into the returned masks.

    The decoder is constructed with ``attention_type=None``.

    NOTE(review): the original description names the encoder as "swin"; this
    class itself only requires that the encoder expose ``get_unet_feature``.
    """

    # Passed to UnetDecoder as ``encoder_channels``.
    # NOTE(review): presumably the channel count of each feature map returned
    # by ``encoder.get_unet_feature`` -- confirm against the encoder in use.
    FEATURE_CHANNELS: Tuple[int, ...] = (3, 256, 512, 1024, 1024)
    # Passed to UnetDecoder as ``decoder_channels``.
    DECODE_CHANNELS: Tuple[int, ...] = (512, 256, 128, 64)
    # Passed to SegmentationHead as ``in_channels``; equals DECODE_CHANNELS[-1].
    # NOTE(review): presumably these two must stay in sync -- confirm.
    IN_CHANNELS: int = 64
    # Passed to UnetDecoder as ``n_blocks``.
    N_BLOCKS: int = 4
    # Passed to SegmentationHead as ``kernel_size``.
    KERNEL_SIZE: int = 3
    # Passed to SegmentationHead as ``upsampling``.
    UPSAMPLING: int = 4

    def __init__(self, encoder, num_classes=9):
        """Build the decoder and segmentation head around ``encoder``.

        Args:
            encoder: Object providing ``get_unet_feature(x)``; stored as
                ``self.encoder`` and called in :meth:`forward`.
            num_classes: Forwarded to ``SegmentationHead`` as ``out_channels``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = UnetDecoder(
            encoder_channels=self.FEATURE_CHANNELS,
            n_blocks=self.N_BLOCKS,
            decoder_channels=self.DECODE_CHANNELS,
            attention_type=None,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=self.IN_CHANNELS,
            out_channels=num_classes,
            kernel_size=self.KERNEL_SIZE,
            upsampling=self.UPSAMPLING,
        )

    def forward(self, x):
        """Run encoder, decoder and segmentation head on ``x``.

        Args:
            x: Input handed unchanged to ``encoder.get_unet_feature``.

        Returns:
            The output of ``self.segmentation_head`` applied to the decoder
            output.
        """
        encoder_features = self.encoder.get_unet_feature(x)
        # The decoder takes each feature map as a separate positional argument.
        decoder_output = self.decoder(*encoder_features)
        masks = self.segmentation_head(decoder_output)
        return masks
|