# Spaces: Running on Zero
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# import util.util as util
from lib.models.networks.architecture import get_nonspade_norm_layer
class MultiscaleDiscriminator(nn.Module):
    """Run ``num_D`` sub-discriminators over an image pyramid.

    Sub-discriminator ``i`` receives the input after ``i`` applications of
    :meth:`downsample`, so later sub-discriminators see coarser scales.
    """

    def __init__(self, opt):
        super(MultiscaleDiscriminator, self).__init__()
        self.opt = opt
        # Registered as discriminator_0, discriminator_1, ...; forward() visits
        # them in this registration order.
        for i in range(opt.model.net_discriminator.num_D):
            subnetD = self.create_single_discriminator(opt)
            self.add_module('discriminator_%d' % i, subnetD)

    def create_single_discriminator(self, opt):
        """Build one sub-discriminator of type ``netD_subarch``.

        Raises:
            ValueError: if ``netD_subarch`` is anything other than ``'n_layer'``.
        """
        subarch = opt.model.net_discriminator.netD_subarch
        if subarch == 'n_layer':
            netD = NLayerDiscriminator(opt)
        else:
            raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
        return netD

    def downsample(self, input):
        """Halve spatial resolution with a 3x3, stride-2 average pool.

        ``count_include_pad=False`` excludes the zero padding from the border
        averages.
        """
        return F.avg_pool2d(input, kernel_size=3,
                            stride=2, padding=[1, 1],
                            count_include_pad=False)

    def forward(self, input):
        """Run every sub-discriminator on its scale of ``input``.

        Returns:
            A list with one entry per sub-discriminator, finest scale first.
            When ``no_ganFeat_loss`` is false, each entry is exactly what the
            sub-discriminator returned. Otherwise the sub-discriminator's
            output is wrapped in a one-element list.
        """
        get_intermediate_features = not self.opt.model.net_discriminator.no_ganFeat_loss
        discriminators = list(self.children())
        result = []
        for i, D in enumerate(discriminators):
            out = D(input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            # Only pool when a further (coarser) sub-discriminator will use it.
            if i + 1 < len(discriminators):
                input = self.downsample(input)
        return result
# PatchGAN discriminator whose layer groups are registered one by one so that
# intermediate activations can be returned alongside the final score map.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator made of consecutively registered conv groups.

    The groups are registered as ``model0``, ``model1``, ... and are:
    * a first conv (no normalisation) with LeakyReLU;
    * one normalised conv group per index 1 .. n_layers_D - 1;
    * a 1-channel output conv.
    """

    def __init__(self, opt):
        super(NLayerDiscriminator, self).__init__()
        self.opt = opt
        d_opt = opt.model.net_discriminator

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))
        channels = d_opt.ndf
        in_channels = self.compute_D_input_nc(opt)
        norm_layer = get_nonspade_norm_layer(opt, d_opt.norm_D)

        groups = [[nn.Conv2d(in_channels, channels, kernel_size=kernel, stride=2, padding=pad),
                   nn.LeakyReLU(0.2, False)]]

        # Channel count doubles per group (capped at 512). Every hidden group
        # uses stride 2 except the last one, which uses stride 1.
        final_hidden = d_opt.n_layers_D - 1
        for n in range(1, d_opt.n_layers_D):
            prev_channels, channels = channels, min(channels * 2, 512)
            conv = nn.Conv2d(prev_channels, channels, kernel_size=kernel,
                             stride=1 if n == final_hidden else 2, padding=pad)
            groups.append([norm_layer(conv), nn.LeakyReLU(0.2, False)])

        groups.append([nn.Conv2d(channels, 1, kernel_size=kernel, stride=1, padding=pad)])

        for idx, group in enumerate(groups):
            self.add_module('model' + str(idx), nn.Sequential(*group))

    def compute_D_input_nc(self, opt):
        """Return the channel count expected by the first conv.

        With ``D_input == "concat"`` this is ``label_nc + output_nc``. One is
        added when ``contain_dontcare_label`` is set, and one more when
        ``no_instance`` is not set. Any other ``D_input`` gives 3.
        """
        d_opt = opt.model.net_discriminator
        if d_opt.D_input != "concat":
            return 3
        channels = d_opt.label_nc + d_opt.output_nc
        if d_opt.contain_dontcare_label:
            channels += 1
        if not d_opt.no_instance:
            channels += 1
        return channels

    def forward(self, input):
        """Apply the registered groups in order.

        Returns the last group's output when ``no_ganFeat_loss`` is set.
        Otherwise returns the list ``[input, out_group0, out_group1, ...]``.
        """
        # NOTE(review): element 0 of the returned list is the raw ``input``
        # itself. If only layer activations are meant to feed a
        # feature-matching loss, the list should start at index 1. Confirm
        # against the loss code before changing.
        outputs = [input]
        for group in self.children():
            outputs.append(group(outputs[-1]))
        if self.opt.model.net_discriminator.no_ganFeat_loss:
            return outputs[-1]
        return list(outputs)
class AudioSubDiscriminator(nn.Module):
    """Cosine similarity between a pooled crop of a visual tensor and an audio embedding.

    NOTE(review): ``self.conv`` and ``self.avgpool`` are built in ``__init__``
    but ``forward`` does not use them. Confirm whether ``conv`` is meant to
    process one of the inputs.
    """

    def __init__(self, opt, nc, audio_nc):
        """
        Args:
            opt: options; ``opt.model.net_discriminator.norm_D`` selects the
                wrapper returned by ``get_nonspade_norm_layer``.
            nc: input/output channels of the first Conv1d.
            audio_nc: output channels of the conv stack and the in/out size of
                ``mapping`` (so the last dim of ``audio`` in ``forward``).
        """
        super(AudioSubDiscriminator, self).__init__()
        norm_layer = get_nonspade_norm_layer(opt, opt.model.net_discriminator.norm_D)
        # Kept for backward compatibility with code that references it;
        # forward() averages with an explicit mean instead.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Two Conv1d layers (kernel 3, stride 2, padding 1): nc -> nc -> audio_nc,
        # each followed by ReLU.
        # NOTE(review): confirm get_nonspade_norm_layer supports Conv1d layers.
        sequence = [
            norm_layer(nn.Conv1d(nc, nc, 3, 2, 1)),
            nn.ReLU(),
            norm_layer(nn.Conv1d(nc, audio_nc, 3, 2, 1)),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*sequence)
        # nn.CosineSimilarity() compares along dim 1.
        self.cosine = nn.CosineSimilarity()
        self.mapping = nn.Linear(audio_nc, audio_nc)

    def forward(self, result, audio):
        """Cosine similarity (along dim 1) of a cropped, averaged region and mapped audio.

        Args:
            result: tensor with at least 5 dims; dims 3 and 4 are cropped and
                averaged (assumed (B, C, T, H, W) -- TODO confirm).
            audio: tensor whose last dim is ``audio_nc``; after ``mapping`` it
                must broadcast against the averaged visual tensor.
        """
        h, w = result.shape[3], result.shape[4]
        # Crop rows [h // 2, h - 2) and the middle third of columns. The bounds
        # come from dims 3 and 4, so they are applied to those same dims.
        region = result[:, :, :, h // 2:h - 2, w // 3:2 * w // 3]
        # Global average over the cropped dims. This matches
        # AdaptiveAvgPool2d((1, 1)), but also works on a 5-D tensor, which the
        # 2-D adaptive pool does not accept.
        visual = region.mean(dim=(3, 4), keepdim=True)
        return self.cosine(visual, self.mapping(audio))
class ImageDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d, ndf=64):
        """Construct a PatchGAN discriminator

        Parameters:
            opt -- options; ``opt.model.net_discriminator`` provides ``D_input``
                   and ``label_nc`` (plus ``output_nc`` when ``D_input`` is
                   ``"concat"``), which fix the number of input channels
            n_layers (int) -- the number of stride-2 conv layers, counting the first one
            norm_layer -- normalization layer class, or a functools.partial wrapping one
            ndf (int) -- the number of filters in the first conv layer; later
                         layers use up to 8x this
        """
        super(ImageDiscriminator, self).__init__()
        # Only InstanceNorm2d keeps the conv bias; BatchNorm's affine shift
        # makes it redundant. Unwrap functools.partial so that, e.g.,
        # partial(nn.InstanceNorm2d, affine=False) is recognised as well.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        d_opt = opt.model.net_discriminator
        if d_opt.D_input == "concat":
            input_nc = d_opt.label_nc + d_opt.output_nc
        else:
            # NOTE(review): NLayerDiscriminator.compute_D_input_nc uses 3
            # channels in this case; confirm label_nc is intended here.
            input_nc = d_opt.label_nc
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters (capped at 8 * ndf)
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        # One further normalised conv at stride 1 before the output layer.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: return the 1-channel patch prediction map."""
        return self.model(input)
class FeatureDiscriminator(nn.Module):
    """Linear classifier with dropout over flattened feature vectors.

    Each group of ``in_features`` consecutive input values is one sample;
    the output has ``opt.model.net_discriminator.num_labels`` logits per sample.
    """

    def __init__(self, opt, in_features=512):
        super(FeatureDiscriminator, self).__init__()
        self.opt = opt
        self.fc = nn.Linear(in_features, opt.model.net_discriminator.num_labels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)``, apply dropout, then the linear layer.

        ``x.numel()`` must be a multiple of ``in_features``.
        """
        # reshape rather than view, so that non-contiguous inputs (e.g. permuted
        # or transposed feature maps) are accepted.
        x0 = x.reshape(-1, self.fc.in_features)
        net = self.dropout(x0)
        net = self.fc(net)
        return net