# File size: 4,836 Bytes | ab687e7
import torch
import torch.nn as nn
import torch.nn.functional as F
from segmentation_models_pytorch.base import modules as md
class DecoderBlock(nn.Module):
    """One decoder stage: upsample, merge an optional skip tensor, refine.

    The stage holds two ``md.Conv2dReLU`` layers (3x3, padding 1) and two
    ``md.Attention`` modules selected by ``attention_type``. The first
    conv/attention pair is sized for ``in_channels + skip_channels`` input
    channels, the second pair for ``out_channels``.

    Args:
        in_channels: Channel count of the tensor coming from the previous
            (deeper) stage.
        skip_channels: Channel count of the skip tensor concatenated to it
            (``0`` when the stage is used without a skip).
        out_channels: Channel count produced by both convolutions.
        use_batchnorm: Forwarded unchanged to ``md.Conv2dReLU``.
        attention_type: Forwarded unchanged to ``md.Attention``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        self.conv1 = md.Conv2dReLU(
            merged_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(
            attention_type, in_channels=merged_channels
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(
            attention_type, in_channels=out_channels
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels

    def forward(self, x, skip=None):
        """Run the stage on ``x``, optionally merging ``skip``.

        Args:
            x: Input tensor; assumed to be laid out as (N, C, H, W) since the
                last two dimensions are treated as spatial -- TODO confirm
                against callers.
            skip: Optional tensor concatenated to ``x`` along dim 1.

        Returns:
            The tensor after ``conv1``, ``conv2`` and ``attention2``.

        Without a skip, ``x`` is upsampled by a fixed factor of 2 (nearest).
        With a skip, ``x`` is resized (nearest) to the skip's last two
        dimensions whenever they differ, then concatenated and passed
        through ``attention1``.
        """
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            target_size = tuple(skip.shape[-2:])
            # Compare height AND width, and resize to the skip's exact size:
            # checking only the last dim missed height-only mismatches, and a
            # fixed 2x factor cannot reach odd-sized skips (e.g. 7 -> 13),
            # which made the concatenation below fail.
            if tuple(x.shape[-2:]) != target_size:
                x = F.interpolate(x, size=target_size, mode="nearest")
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.attention2(x)
class CenterBlock(nn.Sequential):
    """Two stacked ``md.Conv2dReLU`` layers (3x3 kernel, padding 1).

    The first layer maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. ``use_batchnorm`` is forwarded to both layers.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        # Settings shared by both convolutions.
        conv_kwargs = dict(
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        super().__init__(
            md.Conv2dReLU(in_channels, out_channels, **conv_kwargs),
            md.Conv2dReLU(out_channels, out_channels, **conv_kwargs),
        )
class UnetDecoder(nn.Module):
    """Chain of ``DecoderBlock`` stages fed by encoder feature maps.

    Args:
        encoder_channels: Channel counts of the encoder outputs. The first
            entry is dropped and the remainder is consumed in reverse order,
            so the last entry becomes the head (input to the first stage).
        decoder_channels: Output channel count of each decoder stage; its
            length must equal ``n_blocks``.
        n_blocks: Expected number of decoder stages.
        use_batchnorm: Forwarded to every ``DecoderBlock`` and to the
            ``CenterBlock`` when one is built.
        attention_type: Forwarded to every ``DecoderBlock``.
        center: When true, a ``CenterBlock`` (head channels in and out) is
            applied to the head feature first; otherwise ``nn.Identity``.

    Raises:
        ValueError: If ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(self,
                 encoder_channels,
                 decoder_channels,
                 n_blocks=5,
                 use_batchnorm=True,
                 attention_type=None,
                 center=False):
        super().__init__()

        if len(decoder_channels) != n_blocks:
            raise ValueError(
                f"Model depth is {n_blocks}, but you provided "
                f"decoder_channels for {len(decoder_channels)} blocks."
            )

        # Skip index 0 (the original comment describes it as the skip with
        # the same spatial resolution) and order the rest from last to first.
        deep_to_shallow = encoder_channels[:0:-1]
        head_channels = deep_to_shallow[0]

        # Per-stage channel plan: each stage consumes the previous stage's
        # output; the final stage has no skip, hence the trailing 0.
        stage_inputs = [head_channels, *decoder_channels[:-1]]
        stage_skips = [*deep_to_shallow[1:], 0]

        self.center = (
            CenterBlock(head_channels, head_channels,
                        use_batchnorm=use_batchnorm)
            if center
            else nn.Identity()
        )

        stages = []
        for in_ch, skip_ch, out_ch in zip(stage_inputs,
                                          stage_skips,
                                          decoder_channels):
            stages.append(
                DecoderBlock(in_ch, skip_ch, out_ch,
                             use_batchnorm=use_batchnorm,
                             attention_type=attention_type)
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, *features):
        """Decode encoder features into a single tensor.

        ``features`` are given in the same order as ``encoder_channels``:
        the first one is ignored, the last one is the head, and the ones in
        between are handed to the stages as skips, last to first. Stages
        beyond the available skips receive ``None``.
        """
        # Drop the first feature and reverse the remainder (head first).
        ordered = features[:0:-1]
        head = ordered[0]
        skips = ordered[1:]
        skip_count = len(skips)

        x = self.center(head)
        for idx, stage in enumerate(self.blocks):
            if idx < skip_count:
                x = stage(x, skips[idx])
            else:
                x = stage(x, None)
        return x
class SegmentationHead(nn.Sequential):
    """A 2D convolution followed by optional bilinear upsampling.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        upsampling: Scale factor. Values greater than 1 append an
            ``nn.UpsamplingBilinear2d`` layer; otherwise ``nn.Identity``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 upsampling=1):
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)
|