# Spaces:
# Runtime error
# Runtime error
import torch | |
from torch import nn | |
class UNet(nn.Module):
    """U-Net style encoder/decoder built from 3x3 convolution stacks.

    The encoder has four stages, each followed by a 2x2 max-pool, then a
    bottleneck, then four decoder stages.  Every decoder stage bilinearly
    upsamples by 2, halves the channel count with a 3x3 convolution,
    concatenates the matching encoder output along the channel axis and runs
    another convolution stack.  A final 1x1 convolution produces
    ``out_channels`` output maps at the input resolution.

    With the default arguments the layer layout, ``state_dict`` keys and
    parameter-initialisation order are identical to the original hard-coded
    network (3 input channels, 9 output channels, widths 64..1024).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it, so the bottleneck has ``16 * base_channels`` channels.

    Attributes:
        mean, std: Per-channel constants stored as plain Python lists.  They
            are NOT applied inside ``forward``; any normalisation has to be
            done by the caller.  NOTE(review): the values match the widely
            used ImageNet RGB statistics - confirm against the data pipeline.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 9,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Channel widths of the four encoder stages and the bottleneck.
        c0, c1, c2, c3, c4 = (base_channels * 2 ** i for i in range(5))

        # Modules are created in the same order as the original definition so
        # that parameter ordering and seeded initialisation stay the same.

        # Downsampler
        self.enc_conv0 = self._conv_block(in_channels, c0)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = self._conv_block(c0, c1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = self._conv_block(c1, c2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = self._conv_block(c2, c3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = self._conv_block(c3, c4)

        # Upsampler: each dec_conv takes twice the upsample output width
        # because the encoder skip is concatenated on the channel axis.
        self.upsample0 = self._upsample_block(c4, c3)
        self.dec_conv0 = self._conv_block(2 * c3, c3)
        self.upsample1 = self._upsample_block(c3, c2)
        self.dec_conv1 = self._conv_block(2 * c2, c2)
        self.upsample2 = self._upsample_block(c2, c1)
        self.dec_conv2 = self._conv_block(2 * c1, c1)
        self.upsample3 = self._upsample_block(c1, c0)
        # Last stage: two conv/activation/norm groups, then a 1x1 projection
        # to the output channels (no activation or normalisation after it).
        self.dec_conv3 = nn.Sequential(
            *self._conv_block(2 * c0, c0, num_layers=2),
            nn.Conv2d(in_channels=c0, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int, num_layers: int = 3) -> nn.Sequential:
        """Build ``num_layers`` groups of Conv3x3 -> LeakyReLU -> BatchNorm."""
        layers = []
        channels = in_channels
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(in_channels=channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.BatchNorm2d(out_channels))
            channels = out_channels
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build a 2x bilinear upsample followed by a 3x3 convolution."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
        )

    @staticmethod
    def _merge_skip(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Concatenate an upsampled map with its encoder skip on the channel axis.

        Max-pooling floors odd sizes, so after the 2x upsample the decoder map
        can be one pixel smaller than the skip.  In that case it is resized to
        the skip's spatial size; when the sizes already agree the tensors are
        concatenated untouched, exactly as before.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = nn.functional.interpolate(
                upsampled, size=skip.shape[-2:], mode='bilinear', align_corners=True
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: 4-D tensor ``(N, in_channels, H, W)``.  ``H`` and ``W`` no longer
                have to be multiples of 16, but each must be at least 16 so
                that the four pooling steps leave a non-empty bottleneck.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        # bottleneck
        b = self.bottleneck_conv(self.pool3(e3))

        # decoder
        d0 = self.dec_conv0(self._merge_skip(self.upsample0(b), e3))
        d1 = self.dec_conv1(self._merge_skip(self.upsample1(d0), e2))
        d2 = self.dec_conv2(self._merge_skip(self.upsample2(d1), e1))
        d3 = self.dec_conv3(self._merge_skip(self.upsample3(d2), e0))
        return d3