Caleb Spradlin
initial commit
ab687e7
raw
history blame
1.26 kB
from .decoders.unet_decoder import UnetDecoder
from .decoders.unet_decoder import SegmentationHead
import torch.nn as nn
from typing import Tuple
class unet_swin(nn.Module):
    """
    PyTorch encoder-decoder segmentation model.

    Pairs an encoder (intended to be a Swin transformer) with a U-Net
    decoder and a segmentation head. The decoder is constructed with
    ``attention_type=None``, so no attention variant is requested from
    ``UnetDecoder``.

    Args:
        encoder: Backbone module. It must expose ``get_unet_feature(x)``;
            the value it returns is unpacked positionally into the decoder.
        num_classes: Number of output channels of the segmentation head.
            Defaults to 9.
    """

    # Passed to UnetDecoder as ``encoder_channels``.
    # NOTE(review): presumably one entry per tensor returned by
    # encoder.get_unet_feature — confirm against the encoder implementation.
    FEATURE_CHANNELS: Tuple[int, ...] = (3, 256, 512, 1024, 1024)
    # Passed to UnetDecoder as ``decoder_channels``; has N_BLOCKS entries.
    DECODE_CHANNELS: Tuple[int, ...] = (512, 256, 128, 64)
    # Input channels of the segmentation head. Equals DECODE_CHANNELS[-1],
    # presumably because the head consumes the final decoder block's output.
    IN_CHANNELS: int = 64
    N_BLOCKS: int = 4
    KERNEL_SIZE: int = 3
    # Passed to SegmentationHead as ``upsampling``.
    UPSAMPLING: int = 4

    def __init__(self, encoder, num_classes=9):
        super().__init__()
        self.encoder = encoder
        self.decoder = UnetDecoder(
            encoder_channels=self.FEATURE_CHANNELS,
            n_blocks=self.N_BLOCKS,
            decoder_channels=self.DECODE_CHANNELS,
            attention_type=None,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=self.IN_CHANNELS,
            out_channels=num_classes,
            kernel_size=self.KERNEL_SIZE,
            upsampling=self.UPSAMPLING,
        )

    def forward(self, x):
        """
        Run encoder -> decoder -> segmentation head.

        Args:
            x: Input passed unchanged to ``encoder.get_unet_feature``.

        Returns:
            The output of the segmentation head.
        """
        encoder_features = self.encoder.get_unet_feature(x)
        decoder_output = self.decoder(*encoder_features)
        return self.segmentation_head(decoder_output)