import torch
from torch import nn


def _conv_lrelu_bn(in_channels: int, out_channels: int) -> list:
    """Return the layers of one 3x3 Conv -> LeakyReLU -> BatchNorm unit."""
    return [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
    ]


def _conv_block(in_channels: int, out_channels: int, depth: int = 3) -> nn.Sequential:
    """Stack ``depth`` Conv -> LeakyReLU -> BatchNorm units into one Sequential.

    Only the first unit changes the channel count; the remaining units map
    ``out_channels`` to ``out_channels``.
    """
    layers = []
    current = in_channels
    for _ in range(depth):
        layers.extend(_conv_lrelu_bn(current, out_channels))
        current = out_channels
    return nn.Sequential(*layers)


def _up_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """Bilinear 2x upsampling followed by a 3x3 convolution."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, padding=1),
    )


class UNet(nn.Module):
    """U-Net style encoder/decoder with four 2x poolings and skip connections.

    Each encoder stage, the bottleneck and the first three decoder stages are
    three Conv(3x3) -> LeakyReLU -> BatchNorm units. The last decoder stage is
    two such units followed by a 1x1 convolution producing ``out_channels``
    maps. Channel widths double at every encoder stage
    (``base_channels`` * 1, 2, 4, 8) and the bottleneck uses
    ``base_channels * 16``.

    With the default arguments the layer layout, attribute names and
    ``state_dict`` keys are the same as the previously hard-coded
    3 -> 64/128/256/512/1024 -> 9 network.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        base_channels: Width of the first encoder stage.

    Raises:
        ValueError: If any of the arguments is not a positive integer.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 9,
                 base_channels: int = 64) -> None:
        super().__init__()
        for name, value in (("in_channels", in_channels),
                            ("out_channels", out_channels),
                            ("base_channels", base_channels)):
            if not isinstance(value, int) or value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")

        # Per-channel statistics kept as plain attributes. forward() does NOT
        # apply them; NOTE(review): the values match the commonly used
        # ImageNet RGB mean/std, presumably read by callers for input
        # normalization -- confirm against the call sites.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        c1 = base_channels
        c2, c3, c4, c5 = c1 * 2, c1 * 4, c1 * 8, c1 * 16

        # Modules are created in the same order as before so that parameter
        # registration order and random initialization are unchanged.

        # Downsampler
        self.enc_conv0 = _conv_block(in_channels, c1)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = _conv_block(c1, c2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = _conv_block(c2, c3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = _conv_block(c3, c4)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck_conv = _conv_block(c4, c5)

        # Upsampler: every decoder block takes twice the upsampled width
        # because the matching encoder output is concatenated on dim 1.
        self.upsample0 = _up_block(c5, c4)
        self.dec_conv0 = _conv_block(c4 * 2, c4)
        self.upsample1 = _up_block(c4, c3)
        self.dec_conv1 = _conv_block(c3 * 2, c3)
        self.upsample2 = _up_block(c3, c2)
        self.dec_conv2 = _conv_block(c2 * 2, c2)
        self.upsample3 = _up_block(c2, c1)
        # Final stage: two conv units, then a 1x1 projection (no activation
        # or normalization after it).
        self.dec_conv3 = nn.Sequential(
            *_conv_lrelu_bn(c1 * 2, c1),
            *_conv_lrelu_bn(c1, c1),
            nn.Conv2d(in_channels=c1, out_channels=out_channels,
                      kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _cat_skip(up: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Concatenate ``up`` and ``skip`` along the channel dimension.

        MaxPool2d floors odd spatial sizes, so after 2x upsampling ``up`` can
        be one pixel smaller than ``skip``. In that case ``up`` is bilinearly
        resized to the skip's height/width first; when the sizes already match
        this is a plain ``torch.cat``.
        """
        if up.shape[-2:] != skip.shape[-2:]:
            up = nn.functional.interpolate(
                up, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([up, skip], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder.

        Args:
            x: Tensor laid out as (batch, in_channels, height, width), as
                required by the Conv2d/MaxPool2d layers.

        Returns:
            Tensor with ``out_channels`` channels and the same height and
            width as ``x``.
        """
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        # bottleneck
        b = self.bottleneck_conv(self.pool3(e3))

        # decoder
        d0 = self.dec_conv0(self._cat_skip(self.upsample0(b), e3))
        d1 = self.dec_conv1(self._cat_skip(self.upsample1(d0), e2))
        d2 = self.dec_conv2(self._cat_skip(self.upsample2(d1), e1))
        d3 = self.dec_conv3(self._cat_skip(self.upsample3(d2), e0))
        return d3