import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import util.util as util
from lib.models.networks.architecture import get_nonspade_norm_layer
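
# Multi-scale discriminator (pix2pix-HD / SPADE style): several PatchGAN
# discriminators, each operating on a progressively downsampled copy of the input.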
class MultiscaleDiscriminator(nn.Module):
def __init__(self, opt):
super(MultiscaleDiscriminator, self).__init__()
self.opt = opt
for i in range(opt.model.net_discriminator.num_D): # num_D = 2
subnetD = self.create_single_discriminator(opt)
self.add_module('discriminator_%d' % i, subnetD)
def create_single_discriminator(self, opt):
subarch = opt.model.net_discriminator.netD_subarch # netD_subarch = n_layer
if subarch == 'n_layer':
netD = NLayerDiscriminator(opt)
else:
raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
return netD
def downsample(self, input):
return F.avg_pool2d(input, kernel_size=3,
stride=2, padding=[1, 1],
count_include_pad=False)
# Returns list of lists of discriminator outputs.
# The final result is of size opt.model.net_discriminator.num_D x opt.model.net_discriminator.n_layers_D
def forward(self, input):
result = []
get_intermediate_features = not self.opt.model.net_discriminator.no_ganFeat_loss
for name, D in self.named_children():
out = D(input)
if not get_intermediate_features:
out = [out]
result.append(out)
input = self.downsample(input)
return result
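
# Example usage (a minimal sketch; assumes `opt` is a config object exposing the
# model.net_discriminator.* fields used above, e.g. num_D=2, n_layers_D=4, ndf=64,
# D_input="concat", no_ganFeat_loss=False):
#   netD = MultiscaleDiscriminator(opt)
#   preds = netD(torch.randn(1, input_nc, 256, 256))
#   # len(preds) == num_D; preds[i] lists the intermediate feature maps of scale i,
#   # ending with its patch-level real/fake prediction map.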
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, opt):
super(NLayerDiscriminator, self).__init__()
self.opt = opt
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
nf = opt.model.net_discriminator.ndf
input_nc = self.compute_D_input_nc(opt)
norm_layer = get_nonspade_norm_layer(opt, opt.model.net_discriminator.norm_D)
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, False)]]
for n in range(1, opt.model.net_discriminator.n_layers_D): # n_layers_D = 4
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == opt.model.net_discriminator.n_layers_D - 1 else 2
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
stride=stride, padding=padw)),
nn.LeakyReLU(0.2, False)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
# We divide the layers into groups to extract intermediate layer outputs
for n in range(len(sequence)):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
def compute_D_input_nc(self, opt):
if opt.model.net_discriminator.D_input == "concat":
input_nc = opt.model.net_discriminator.label_nc + opt.model.net_discriminator.output_nc
if opt.model.net_discriminator.contain_dontcare_label:
input_nc += 1
if not opt.model.net_discriminator.no_instance:
input_nc += 1
else:
input_nc = 3
return input_nc
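
    # For example (illustrative values, not taken from the config): with
    # D_input == "concat", label_nc = 3, output_nc = 3, contain_dontcare_label = False
    # and no_instance = True, the discriminator input has 3 + 3 = 6 channels.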
def forward(self, input):
results = [input]
for submodel in self.children():
# intermediate_output = checkpoint(submodel, results[-1])
intermediate_output = submodel(results[-1])
results.append(intermediate_output)
get_intermediate_features = not self.opt.model.net_discriminator.no_ganFeat_loss
        if get_intermediate_features:
            # Exclude the raw input; return every intermediate feature map plus
            # the final patch-level prediction.
            return results[1:]
else:
return results[-1]
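
# Scores audio-visual correspondence: pools a cropped region of a discriminator
# feature map and compares it, via cosine similarity, with a linearly mapped
# audio feature.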
class AudioSubDiscriminator(nn.Module):
def __init__(self, opt, nc, audio_nc):
super(AudioSubDiscriminator, self).__init__()
norm_layer = get_nonspade_norm_layer(opt, opt.model.net_discriminator.norm_D)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
sequence = []
sequence += [norm_layer(nn.Conv1d(nc, nc, 3, 2, 1)),
nn.ReLU()
]
sequence += [norm_layer(nn.Conv1d(nc, audio_nc, 3, 2, 1)),
nn.ReLU()
]
self.conv = nn.Sequential(*sequence)
self.cosine = nn.CosineSimilarity()
self.mapping = nn.Linear(audio_nc, audio_nc)
    def forward(self, result, audio):
        # Crop the lower-central part of the feature map (assumed layout N, C, H, W),
        # i.e. roughly the mouth region, before pooling it to a single vector.
        h, w = result.shape[2], result.shape[3]
        region = result[:, :, h // 2:h - 2, w // 3:2 * w // 3]
        # Pool and flatten to (N, C); this assumes the channel count matches the
        # mapped audio feature size so that the cosine similarity is well defined.
        visual = self.avgpool(region).flatten(1)
        cos = self.cosine(visual, self.mapping(audio))
        return cos
class ImageDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(ImageDiscriminator, self).__init__()
use_bias = norm_layer == nn.InstanceNorm2d
if opt.model.net_discriminator.D_input == "concat":
input_nc = opt.model.net_discriminator.label_nc + opt.model.net_discriminator.output_nc
else:
input_nc = opt.model.net_discriminator.label_nc
ndf = 64
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.model(input)
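
# Example usage (a minimal sketch; `input_nc` depends on the D_input setting above):
#   netD = ImageDiscriminator(opt)
#   patch_logits = netD(torch.randn(1, input_nc, 256, 256))
#   # patch_logits is a one-channel map of per-patch real/fake scores.

# Classifies a 512-dim feature vector into opt.model.net_discriminator.num_labels
# classes; acts as a simple feature-level discriminator.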
class FeatureDiscriminator(nn.Module):
def __init__(self, opt):
super(FeatureDiscriminator, self).__init__()
self.opt = opt
self.fc = nn.Linear(512, opt.model.net_discriminator.num_labels)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x0 = x.view(-1, 512)
net = self.dropout(x0)
net = self.fc(net)
return net