import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import util.util as util
from lib.models.networks.architecture import get_nonspade_norm_layer
class MultiscaleDiscriminator(nn.Module):
def __init__(self, opt):
super(MultiscaleDiscriminator, self).__init__()
self.opt = opt
for i in range(opt.model.net_discriminator.num_D): # num_D = 2
subnetD = self.create_single_discriminator(opt)
self.add_module('discriminator_%d' % i, subnetD)
def create_single_discriminator(self, opt):
subarch = opt.model.net_discriminator.netD_subarch # netD_subarch = n_layer
if subarch == 'n_layer':
netD = NLayerDiscriminator(opt)
else:
raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
return netD
def downsample(self, input):
return F.avg_pool2d(input, kernel_size=3,
stride=2, padding=[1, 1],
count_include_pad=False)
    # Returns a list with one entry per discriminator scale (opt.model.net_discriminator.num_D entries).
    # When the GAN feature-matching loss is enabled, each entry is itself the list of that
    # discriminator's intermediate outputs; otherwise it is a single-element list with the final prediction.
def forward(self, input):
result = []
get_intermediate_features = not self.opt.model.net_discriminator.no_ganFeat_loss
for name, D in self.named_children():
out = D(input)
if not get_intermediate_features:
out = [out]
result.append(out)
input = self.downsample(input)
return result
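
# Illustrative sketch (not part of the original training code): builds a
# MultiscaleDiscriminator from a minimal mock config and shows how its nested
# outputs are organised. All config values below are assumptions made only so the
# module can be instantiated, and norm_D='instance' assumes the SPADE-style
# behaviour of get_nonspade_norm_layer; the real values come from the experiment config.
def _example_multiscale_outputs():
    from types import SimpleNamespace
    netD_cfg = SimpleNamespace(
        num_D=2, netD_subarch='n_layer', no_ganFeat_loss=False,
        ndf=64, norm_D='instance', n_layers_D=4,
        D_input='single', label_nc=3, output_nc=3,
        contain_dontcare_label=False, no_instance=True)
    opt = SimpleNamespace(model=SimpleNamespace(net_discriminator=netD_cfg))
    netD = MultiscaleDiscriminator(opt)
    outputs = netD(torch.randn(1, 3, 256, 256))
    # One entry per scale; with feature matching enabled each entry is the list of
    # that scale's intermediate feature maps plus the final patch prediction.
    assert len(outputs) == netD_cfg.num_D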
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, opt):
super(NLayerDiscriminator, self).__init__()
self.opt = opt
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
nf = opt.model.net_discriminator.ndf
input_nc = self.compute_D_input_nc(opt)
norm_layer = get_nonspade_norm_layer(opt, opt.model.net_discriminator.norm_D)
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, False)]]
for n in range(1, opt.model.net_discriminator.n_layers_D): # n_layers_D = 4
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == opt.model.net_discriminator.n_layers_D - 1 else 2
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
stride=stride, padding=padw)),
nn.LeakyReLU(0.2, False)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
# We divide the layers into groups to extract intermediate layer outputs
for n in range(len(sequence)):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
def compute_D_input_nc(self, opt):
if opt.model.net_discriminator.D_input == "concat":
input_nc = opt.model.net_discriminator.label_nc + opt.model.net_discriminator.output_nc
if opt.model.net_discriminator.contain_dontcare_label:
input_nc += 1
if not opt.model.net_discriminator.no_instance:
input_nc += 1
else:
input_nc = 3
return input_nc
def forward(self, input):
results = [input]
for submodel in self.children():
# intermediate_output = checkpoint(submodel, results[-1])
intermediate_output = submodel(results[-1])
results.append(intermediate_output)
get_intermediate_features = not self.opt.model.net_discriminator.no_ganFeat_loss
        if get_intermediate_features:
            # results[0] is the raw input; the remaining entries are the outputs of each layer group.
            return results[0:]
else:
return results[-1]
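
# Illustrative sketch (not from the original repo): because NLayerDiscriminator
# registers each layer group as a separate submodule, its forward pass can return
# every intermediate feature map. A GAN feature-matching loss typically compares
# those maps between real and generated images, as in the hypothetical helper below.
def _example_feature_matching_loss(pred_fake, pred_real):
    # pred_fake / pred_real: lists of feature maps from one NLayerDiscriminator,
    # ordered from the input up to the final patch prediction (excluded here).
    loss = 0.0
    for feat_fake, feat_real in zip(pred_fake[:-1], pred_real[:-1]):
        loss = loss + F.l1_loss(feat_fake, feat_real.detach())
    return loss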
class AudioSubDiscriminator(nn.Module):
def __init__(self, opt, nc, audio_nc):
super(AudioSubDiscriminator, self).__init__()
norm_layer = get_nonspade_norm_layer(opt, opt.model.net_discriminator.norm_D)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
sequence = []
sequence += [norm_layer(nn.Conv1d(nc, nc, 3, 2, 1)),
nn.ReLU()
]
sequence += [norm_layer(nn.Conv1d(nc, audio_nc, 3, 2, 1)),
nn.ReLU()
]
self.conv = nn.Sequential(*sequence)
self.cosine = nn.CosineSimilarity()
self.mapping = nn.Linear(audio_nc, audio_nc)
def forward(self, result, audio):
        # Crop the lower-central region of the feature map (assumed to cover the mouth)
        # over the last two (spatial) dimensions before pooling.
        h, w = result.shape[-2], result.shape[-1]
        region = result[..., h // 2:h - 2, w // 3:2 * w // 3]
visual = self.avgpool(region)
cos = self.cosine(visual, self.mapping(audio))
return cos
class ImageDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(ImageDiscriminator, self).__init__()
use_bias = norm_layer == nn.InstanceNorm2d
if opt.model.net_discriminator.D_input == "concat":
input_nc = opt.model.net_discriminator.label_nc + opt.model.net_discriminator.output_nc
else:
input_nc = opt.model.net_discriminator.label_nc
ndf = 64
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.model(input)
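
# Illustrative sketch (not part of the original file): ImageDiscriminator is a
# standard PatchGAN, so its output is a map of per-patch real/fake scores rather
# than a single scalar. The config values below are assumptions used only to build
# the network; with n_layers=3, a 256x256 input yields a 30x30 prediction map.
def _example_patch_prediction_map():
    from types import SimpleNamespace
    netD_cfg = SimpleNamespace(D_input='single', label_nc=3, output_nc=3)
    opt = SimpleNamespace(model=SimpleNamespace(net_discriminator=netD_cfg))
    netD = ImageDiscriminator(opt, n_layers=3)
    score_map = netD(torch.randn(1, 3, 256, 256))
    print(score_map.shape)  # torch.Size([1, 1, 30, 30])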
class FeatureDiscriminator(nn.Module):
def __init__(self, opt):
super(FeatureDiscriminator, self).__init__()
self.opt = opt
self.fc = nn.Linear(512, opt.model.net_discriminator.num_labels)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x0 = x.view(-1, 512)
net = self.dropout(x0)
net = self.fc(net)
return net
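
# Illustrative sketch (not part of the original file): FeatureDiscriminator is a
# plain linear classifier over 512-dimensional features; view(-1, 512) flattens any
# extra leading dimensions. The num_labels value below is an assumption.
def _example_feature_classifier():
    from types import SimpleNamespace
    netD_cfg = SimpleNamespace(num_labels=10)
    opt = SimpleNamespace(model=SimpleNamespace(net_discriminator=netD_cfg))
    netD = FeatureDiscriminator(opt)
    logits = netD(torch.randn(4, 512))  # -> shape (4, 10)
    print(logits.shape)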