|
import torch |
|
import torch.nn as nn |
|
|
|
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with ``kernel_size=3``, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``out_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the block.

        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels produced by both convolutions.
        """
        super().__init__()
        # Kept as a single ``nn.Sequential`` named ``conv`` so parameter
        # names stay ``conv.0.*`` and ``conv.2.*``.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> ReLU -> conv -> ReLU to ``x`` and return the result."""
        return self.conv(x)
|
|
|
class UpConv(nn.Module):
    """Learned 2x upsampling via a transposed convolution.

    Uses ``nn.ConvTranspose2d`` with ``kernel_size=2`` and ``stride=2``, so
    the output height and width are exactly twice those of the input.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the upsampling layer.

        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels of the upsampled feature map.
        """
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` upsampled by a factor of two in both spatial dims."""
        return self.up(x)
|
|
|
class UNet(nn.Module):
    """U-Net style encoder/decoder with four down- and four up-sampling stages.

    The contracting path applies ``ConvBlock`` at 64, 128, 256 and 512
    channels with a 2x2 max-pool between stages, followed by a 1024-channel
    bottleneck. The expanding path mirrors it: each stage upsamples with
    ``UpConv``, concatenates the matching encoder feature map along the
    channel dimension (encoder first), and refines the result with a
    ``ConvBlock``. A final 1x1 convolution maps 64 channels to
    ``out_channels``; no activation is applied to the output.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1) -> None:
        """Build the network.

        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels of the output map.
        """
        super().__init__()

        # Contracting path.
        self.encoder1 = ConvBlock(in_channels, 64)
        self.encoder2 = ConvBlock(64, 128)
        self.encoder3 = ConvBlock(128, 256)
        self.encoder4 = ConvBlock(256, 512)

        # One parameter-free pooling layer shared by every encoder stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = ConvBlock(512, 1024)

        # Expanding path. Each decoder block takes twice the channels its
        # UpConv emits because the skip connection is concatenated first.
        self.upconv4 = UpConv(1024, 512)
        self.decoder4 = ConvBlock(1024, 512)
        self.upconv3 = UpConv(512, 256)
        self.decoder3 = ConvBlock(512, 256)
        self.upconv2 = UpConv(256, 128)
        self.decoder2 = ConvBlock(256, 128)
        self.upconv1 = UpConv(128, 64)
        self.decoder1 = ConvBlock(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``x`` so its last two dimensions equal those of ``ref``."""
        diff_h = ref.size(-2) - x.size(-2)
        diff_w = ref.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        # Max-pooling floors odd sizes, so after 2x upsampling the decoder
        # map can be one pixel short of the skip connection. Split the
        # padding as evenly as possible between the two sides.
        return nn.functional.pad(
            x,
            (
                diff_w // 2,
                diff_w - diff_w // 2,
                diff_h // 2,
                diff_h - diff_h // 2,
            ),
        )

    @staticmethod
    def _decode(
        upconv: nn.Module,
        decoder: nn.Module,
        skip: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run one decoder stage: upsample, align to ``skip``, concat, refine."""
        x = UNet._match_size(upconv(x), skip)
        # Encoder features go first, matching the channel layout the
        # decoder blocks' weights expect.
        return decoder(torch.cat((skip, x), dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full encoder/decoder on ``x``.

        Args:
            x: Input batch whose channel dimension (dim 1) has
                ``in_channels`` entries and whose last two dimensions are
                spatial.

        Returns:
            A tensor with ``out_channels`` channels and the same spatial
            size as ``x``. Spatial sizes that are not multiples of 16 are
            supported: upsampled maps are zero-padded to match their skip
            connections.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))

        out = self.bottleneck(self.pool(enc4))

        out = self._decode(self.upconv4, self.decoder4, enc4, out)
        out = self._decode(self.upconv3, self.decoder3, enc3, out)
        out = self._decode(self.upconv2, self.decoder2, enc2, out)
        out = self._decode(self.upconv1, self.decoder1, enc1, out)

        return self.final_conv(out)