Spaces:
Runtime error
Runtime error
File size: 7,298 Bytes
b108d0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
from torch import nn
class UNet(nn.Module):
    """U-Net style encoder/decoder built from 3x3 conv stacks and skip connections.

    The encoder has four stages, each followed by a 2x2 max-pool; the decoder
    mirrors it with four bilinear-upsample + conv steps, concatenating the
    matching encoder output along the channel axis before each decoder stage.
    A final 1x1 convolution produces ``out_channels`` maps.

    Submodule names and the layer order inside every ``nn.Sequential`` are
    kept stable so that ``state_dict`` keys (e.g. ``enc_conv0.0.weight``)
    do not change.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it (bottleneck uses ``16 * base_channels``).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 9,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        # Per-channel statistics exposed as plain attributes; they are NOT
        # applied anywhere in this class.
        # NOTE(review): values match the commonly used ImageNet RGB mean/std,
        # presumably for caller-side input normalization -- confirm with callers.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        c0 = base_channels
        c1, c2, c3, c4 = c0 * 2, c0 * 4, c0 * 8, c0 * 16

        # Downsampler
        self.enc_conv0 = self._conv_block(in_channels, c0)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = self._conv_block(c0, c1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = self._conv_block(c1, c2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = self._conv_block(c2, c3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = self._conv_block(c3, c4)

        # Upsampler: each upsample halves the channel count so that, after
        # concatenation with the skip tensor, the decoder stage sees twice
        # its own output width.
        self.upsample0 = self._up_block(c4, c3)
        self.dec_conv0 = self._conv_block(c4, c3)
        self.upsample1 = self._up_block(c3, c2)
        self.dec_conv1 = self._conv_block(c3, c2)
        self.upsample2 = self._up_block(c2, c1)
        self.dec_conv2 = self._conv_block(c2, c1)
        self.upsample3 = self._up_block(c1, c0)
        # Last decoder stage: two conv units followed by the 1x1 output
        # projection (no activation / normalization after it).
        self.dec_conv3 = nn.Sequential(
            *self._conv_unit(c1, c0),
            *self._conv_unit(c0, c0),
            nn.Conv2d(in_channels=c0, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _conv_unit(in_ch: int, out_ch: int) -> list:
        """Return the layers of one 3x3 conv -> LeakyReLU -> BatchNorm unit."""
        return [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(out_ch),
        ]

    @classmethod
    def _conv_block(cls, in_ch: int, out_ch: int) -> nn.Sequential:
        """Return three stacked conv units; only the first changes the width."""
        return nn.Sequential(
            *cls._conv_unit(in_ch, out_ch),
            *cls._conv_unit(out_ch, out_ch),
            *cls._conv_unit(out_ch, out_ch),
        )

    @staticmethod
    def _up_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """Return a 2x bilinear upsample followed by a 3x3 conv."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
        )

    @staticmethod
    def _match_size(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``x`` so its last two dims equal those of ``ref``.

        Max-pooling floors odd sizes, so after the 2x upsample ``x`` can be
        one pixel short of the skip tensor. When sizes already agree the
        input is returned untouched.
        """
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                need not be multiples of 16, but must be large enough to
                survive four 2x poolings (at least 16).

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        # bottleneck
        b = self.bottleneck_conv(self.pool3(e3))

        # decoder: upsample, align to the skip tensor, concatenate, refine
        d0 = self._match_size(self.upsample0(b), e3)
        d0 = self.dec_conv0(torch.cat([d0, e3], dim=1))

        d1 = self._match_size(self.upsample1(d0), e2)
        d1 = self.dec_conv1(torch.cat([d1, e2], dim=1))

        d2 = self._match_size(self.upsample2(d1), e1)
        d2 = self.dec_conv2(torch.cat([d2, e1], dim=1))

        d3 = self._match_size(self.upsample3(d2), e0)
        d3 = self.dec_conv3(torch.cat([d3, e0], dim=1))
        return d3