# Source: eksemyashkina — "Upload 9 files" (Hugging Face commit b108d0c, verified)
import torch
from torch import nn
class UNet(nn.Module):
    """U-Net for semantic segmentation.

    Encoder: four triple-conv stages (64 -> 128 -> 256 -> 512 channels),
    each followed by 2x2 max pooling. Bottleneck: a 1024-channel triple-conv.
    Decoder: four stages of 2x bilinear upsampling + 3x3 conv, channel-wise
    concatenation with the matching encoder feature map (skip connection),
    then a triple-conv. A final 1x1 conv projects to ``num_classes`` logits.

    Input spatial dimensions must be divisible by 16 (four 2x poolings).
    Output has the same spatial size as the input.

    Args:
        in_channels: number of channels in the input image (default 3, RGB).
        num_classes: number of output segmentation classes (default 9).

    NOTE: ``self.mean`` / ``self.std`` hold ImageNet normalization constants
    but are NOT applied inside ``forward`` — callers are expected to
    normalize inputs themselves.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 9) -> None:
        super().__init__()
        # ImageNet channel statistics, kept for callers that normalize inputs.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # --- Encoder (downsampling path) ---
        self.enc_conv0 = self._triple_conv(in_channels, 64)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = self._triple_conv(64, 128)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = self._triple_conv(128, 256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = self._triple_conv(256, 512)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # --- Bottleneck ---
        self.bottleneck_conv = self._triple_conv(512, 1024)
        # --- Decoder (upsampling path) ---
        # Each dec_conv consumes 2x channels because of the skip concatenation.
        self.upsample0 = self._up(1024, 512)
        self.dec_conv0 = self._triple_conv(1024, 512)
        self.upsample1 = self._up(512, 256)
        self.dec_conv1 = self._triple_conv(512, 256)
        self.upsample2 = self._up(256, 128)
        self.dec_conv2 = self._triple_conv(256, 128)
        self.upsample3 = self._up(128, 64)
        # Final stage: two conv/act/norm units, then a 1x1 projection to
        # class logits (no activation/norm after the head).
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _triple_conv(c_in: int, c_out: int) -> nn.Sequential:
        """Three (3x3 Conv -> LeakyReLU -> BatchNorm) units; first conv maps c_in -> c_out.

        Module ordering/indices match the original hand-written Sequentials
        (conv at 0/3/6, norm at 2/5/8) so state_dict keys stay compatible
        with checkpoints saved before this refactor.
        """
        layers: list[nn.Module] = []
        for cin in (c_in, c_out, c_out):
            layers += [
                nn.Conv2d(in_channels=cin, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.BatchNorm2d(c_out),
            ]
        return nn.Sequential(*layers)

    @staticmethod
    def _up(c_in: int, c_out: int) -> nn.Sequential:
        """2x bilinear upsample followed by a 3x3 conv that halves the channels."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: input batch of shape (N, in_channels, H, W); H and W must be
               divisible by 16.

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        # Encoder: keep pre-pool activations (e0..e3) for skip connections.
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))
        # Bottleneck at 1/16 resolution.
        b = self.bottleneck_conv(self.pool3(e3))
        # Decoder: upsample, concatenate the matching encoder map, refine.
        d0 = self.dec_conv0(torch.cat([self.upsample0(b), e3], dim=1))
        d1 = self.dec_conv1(torch.cat([self.upsample1(d0), e2], dim=1))
        d2 = self.dec_conv2(torch.cat([self.upsample2(d1), e1], dim=1))
        d3 = self.dec_conv3(torch.cat([self.upsample3(d2), e0], dim=1))
        return d3