import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
import models.basicblock as B
import functools
import numpy as np


"""
# --------------------------------------------
# Discriminator_PatchGAN
# Discriminator_UNet
# Discriminator_VGG_96 / _128 / _192 / _128_SN
# --------------------------------------------
"""


# --------------------------------------------
# PatchGAN discriminator
# If n_layers = 3, then the receptive field is 70x70
# --------------------------------------------
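# Receptive-field sketch (added): with the default n_layers = 3 the model stacks
# three 4x4 stride-2 convs followed by two 4x4 stride-1 convs; unrolling
# RF = (RF - 1) * stride + kernel from the output layer gives
#   4 -> 7 -> 16 -> 34 -> 70,
# i.e. each output score sees a 70x70 patch of the input.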
class Discriminator_PatchGAN(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_type='spectral'):
        '''PatchGAN discriminator, receptive field = 70x70 if n_layers = 3
        Args:
            input_nc: number of input channels
            ndf: base channel number
            n_layers: number of conv layers with stride 2
            norm_type: 'batch', 'instance', 'spectral', 'batchspectral', 'instancespectral'
        Returns:
            tensor: patch score map (one score per receptive-field patch)
        '''
        super(Discriminator_PatchGAN, self).__init__()
        self.n_layers = n_layers
        norm_layer = self.get_norm_layer(norm_type=norm_type)

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[self.use_spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), norm_type), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[self.use_spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), norm_type),
                          norm_layer(nf),
                          nn.LeakyReLU(0.2, True)]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[self.use_spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), norm_type),
                      norm_layer(nf),
                      nn.LeakyReLU(0.2, True)]]

        sequence += [[self.use_spectral_norm(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), norm_type)]]

        self.model = nn.Sequential()
        for n in range(len(sequence)):
            self.model.add_module('child' + str(n), nn.Sequential(*sequence[n]))

        self.model.apply(self.weights_init)

    def use_spectral_norm(self, module, norm_type='spectral'):
        if 'spectral' in norm_type:
            return spectral_norm(module)
        return module

    def get_norm_layer(self, norm_type='instance'):
        if 'batch' in norm_type:
            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        elif 'instance' in norm_type:
            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
        else:
            norm_layer = functools.partial(nn.Identity)
        return norm_layer

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        return self.model(x)


class Discriminator_UNet(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)"""
    def __init__(self, input_nc=3, ndf=64):
        super(Discriminator_UNet, self).__init__()
        norm = spectral_norm

        self.conv0 = nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)

        self.conv1 = norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False))

        # upsample
        self.conv4 = norm(nn.Conv2d(ndf * 8, ndf * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(ndf * 4, ndf * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(ndf * 2, ndf, 3, 1, 1, bias=False))

        # extra
        self.conv7 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(ndf, 1, 3, 1, 1)

        print('using the UNet discriminator')

    def forward(self, x):
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
        x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
        x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
        x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
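
# Note (added): the U-Net discriminator downsamples three times by 2 and then
# upsamples back with skip additions, so it returns a per-pixel realness map at
# the input resolution; H and W should be divisible by 8 so the skip additions
# (x4 + x2, x5 + x1, x6 + x0) line up.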


# --------------------------------------------
# VGG style Discriminator with 96x96 input
# --------------------------------------------
class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 48, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 24, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 12, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 6, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# --------------------------------------------
# VGG style Discriminator with 128x128 input
# --------------------------------------------
class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 64, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 32, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 16, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 8, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        self.classifier = nn.Sequential(nn.Linear(512 * 4 * 4, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# --------------------------------------------
# VGG style Discriminator with 192x192 input
# --------------------------------------------
class Discriminator_VGG_192(nn.Module):
    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 96, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 48, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 24, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 12, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 6, 512
        conv10 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv11 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5,
                                     conv6, conv7, conv8, conv9, conv10, conv11)

        # classifier
        self.classifier = nn.Sequential(nn.Linear(512 * 3 * 3, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# --------------------------------------------
# SN-VGG style Discriminator with 128x128 input
# --------------------------------------------
class Discriminator_VGG_128_SN(nn.Module):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.conv0 = spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x


if __name__ == '__main__':

    x = torch.rand(1, 3, 96, 96)
    net = Discriminator_VGG_96()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())

    x = torch.rand(1, 3, 128, 128)
    net = Discriminator_VGG_128()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())

    x = torch.rand(1, 3, 192, 192)
    net = Discriminator_VGG_192()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())

    x = torch.rand(1, 3, 128, 128)
    net = Discriminator_VGG_128_SN()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())
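
    # Added sketch: quick shape checks for the PatchGAN and U-Net discriminators
    # defined above, which the original tests do not exercise. The 128x128 input
    # size is an assumption chosen so the U-Net skip connections align
    # (H and W divisible by 8).
    x = torch.rand(1, 3, 128, 128)
    net = Discriminator_PatchGAN()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())

    net = Discriminator_UNet()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())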

# run models/network_discriminator.py