Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import swapae.util as util | |
from swapae.models.networks import BaseNetwork | |
from swapae.models.networks.stylegan2_layers import ResBlock, ConvLayer, ToRGB, EqualLinear, Blur, Upsample, make_kernel | |
from swapae.models.networks.stylegan2_op import upfirdn2d | |
class ToSpatialCode(torch.nn.Module):
    """Map a feature map to ``outch`` channels, then repeatedly upsample it.

    Two 1x1 ``ConvLayer``s (``inch -> inch // 2 -> outch``) are applied.
    ``self.upsample`` (built as ``Upsample([1, 3, 3, 1], 2)``) is then applied
    ``int(np.log2(scale))`` times.

    Args:
        inch: number of input channels.
        outch: number of output channels.
        scale: target upsampling factor. Only ``int(np.log2(scale))``
            upsample steps are run, so a non-power-of-two value is truncated
            (e.g. ``scale=3`` gives a single step).
    """

    def __init__(self, inch, outch, scale):
        super().__init__()
        bottleneck_ch = inch // 2
        self.conv1 = ConvLayer(inch, bottleneck_ch, 1, activate=True, bias=True)
        self.conv2 = ConvLayer(bottleneck_ch, outch, 1, activate=False, bias=True)
        self.scale = scale
        self.upsample = Upsample([1, 3, 3, 1], 2)
        # NOTE(review): `blur` and the `kernel` buffer are never used in
        # `forward`. They are presumably kept so that parameter/buffer names
        # match saved checkpoints; confirm before removing.
        self.blur = Blur([1, 3, 3, 1], pad=(2, 1))
        self.register_buffer('kernel', make_kernel([1, 3, 3, 1]))

    def forward(self, x):
        """Project ``x`` with conv1 then conv2, then apply the upsample steps."""
        h = self.conv2(self.conv1(x))
        num_steps = int(np.log2(self.scale))
        for _ in range(num_steps):
            h = self.upsample(h)
        return h
class StyleGAN2ResnetEncoder(BaseNetwork):
    """Encode an image into a spatial code and a global code.

    The spatial code corresponds to the paper's Structure Code, and the
    global code to its Texture Code.

    Submodules built in ``__init__`` from ``opt``:

    * ``FromRGB``: 1x1 ``ConvLayer`` from 3 input channels to ``nc(0)``.
    * ``DownToSpatialCode``: ``netE_num_downsampling_sp`` ``ResBlock``
      stages (named ``ResBlockDownBy{2**i}``), stage ``i`` mapping
      ``nc(i) -> nc(i + 1)`` channels.
    * ``ToSpatialCode``: two 1x1 ``ConvLayer``s ending in
      ``opt.spatial_code_ch`` channels.
    * ``DownToGlobalCode``: ``netE_num_downsampling_gl`` 3x3 ``ConvLayer``s
      built with ``downsample=True``, continuing the ``nc`` schedule.
    * ``ToGlobalCode``: ``EqualLinear`` to ``opt.global_code_ch``, applied to
      the spatial mean of the ``DownToGlobalCode`` output.

    Submodule names become ``state_dict`` keys. Do not rename them without
    migrating existing checkpoints.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the encoder's command-line options on ``parser``.

        Declared as a staticmethod because it takes no ``self``. This keeps
        it correct whether it is called on the class or on an instance.

        Args:
            parser: object exposing ``add_argument`` (e.g. an argparse parser).
            is_train: accepted for signature compatibility; not used here.

        Returns:
            The same ``parser``.
        """
        parser.add_argument("--netE_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netE_num_downsampling_sp", default=4, type=int)
        parser.add_argument("--netE_num_downsampling_gl", default=2, type=int)
        parser.add_argument("--netE_nc_steepness", default=2.0, type=float)
        return parser

    def __init__(self, opt):
        super().__init__(opt)

        # If antialiasing is used, create a very lightweight Gaussian kernel.
        blur_kernel = [1, 2, 1] if self.opt.use_antialias else [1]

        self.add_module("FromRGB", ConvLayer(3, self.nc(0), 1))

        self.DownToSpatialCode = nn.Sequential()
        for i in range(self.opt.netE_num_downsampling_sp):
            self.DownToSpatialCode.add_module(
                "ResBlockDownBy%d" % (2 ** i),
                ResBlock(self.nc(i), self.nc(i + 1), blur_kernel,
                         reflection_pad=True)
            )

        # Spatial Code refers to the Structure Code, and
        # Global Code refers to the Texture Code of the paper.
        nchannels = self.nc(self.opt.netE_num_downsampling_sp)
        self.add_module(
            "ToSpatialCode",
            nn.Sequential(
                ConvLayer(nchannels, nchannels, 1, activate=True, bias=True),
                ConvLayer(nchannels, self.opt.spatial_code_ch, kernel_size=1,
                          activate=False, bias=True)
            )
        )

        # The global branch continues the channel schedule where the
        # spatial branch stopped, hence the offset index.
        self.DownToGlobalCode = nn.Sequential()
        for i in range(self.opt.netE_num_downsampling_gl):
            idx_from_beginning = self.opt.netE_num_downsampling_sp + i
            self.DownToGlobalCode.add_module(
                "ConvLayerDownBy%d" % (2 ** idx_from_beginning),
                ConvLayer(self.nc(idx_from_beginning),
                          self.nc(idx_from_beginning + 1), kernel_size=3,
                          blur_kernel=[1], downsample=True, pad=0)
            )

        nchannels = self.nc(self.opt.netE_num_downsampling_sp +
                            self.opt.netE_num_downsampling_gl)
        self.add_module(
            "ToGlobalCode",
            nn.Sequential(
                EqualLinear(nchannels, self.opt.global_code_ch)
            )
        )

    def nc(self, idx):
        """Return the channel width of encoder stage ``idx``.

        Computed as ``round(netE_nc_steepness ** (5 + idx) *
        netE_scale_capacity)``. With the default options (2.0 and 1.0) this
        gives 32, 64, 128, ... for idx = 0, 1, 2, ...
        """
        nc = self.opt.netE_nc_steepness ** (5 + idx)
        nc = nc * self.opt.netE_scale_capacity
        return round(nc)

    def forward(self, x, extract_features=False, *, feature_size=(7, 7)):
        """Compute the spatial and global codes for ``x``.

        Args:
            x: input batch fed to ``FromRGB`` (a conv with 3 input channels).
                Presumably (batch, 3, H, W); confirm against callers.
            extract_features: if True, also return an intermediate feature
                map produced by the first ``DownToGlobalCode`` layer.
            feature_size: ``size`` passed to ``F.interpolate`` when resizing
                the extracted feature map. Defaults to ``(7, 7)``, the
                previously hard-coded value. Ignored unless
                ``extract_features`` is True.

        Returns:
            ``(sp, gl)``, or ``(sp, gl, feature)`` when ``extract_features`` is
            True. ``sp`` (a 4-D map) and ``gl`` are both passed through
            ``util.normalize`` before being returned.

        Raises:
            ValueError: if ``extract_features`` is True but the encoder was
                built with ``netE_num_downsampling_gl == 0``. In that case
                there is no ``DownToGlobalCode`` layer to take features from.
        """
        x = self.FromRGB(x)
        midpoint = self.DownToSpatialCode(x)
        sp = self.ToSpatialCode(midpoint)

        if extract_features:
            if len(self.DownToGlobalCode) == 0:
                raise ValueError(
                    "extract_features=True requires "
                    "netE_num_downsampling_gl >= 1"
                )
            # Reflect-pad one pixel on the left and top only, then run the
            # first global-branch layer. The check below expects the result to
            # be half the spatial code's resolution.
            padded_midpoint = F.pad(midpoint, (1, 0, 1, 0), mode='reflect')
            feature = self.DownToGlobalCode[0](padded_midpoint)
            expected_hw = (sp.size(2) // 2, sp.size(3) // 2)
            assert (feature.size(2), feature.size(3)) == expected_hw, (
                "unexpected feature size %s; expected %s"
                % ((feature.size(2), feature.size(3)), expected_hw))
            feature = F.interpolate(
                feature, size=feature_size, mode='bilinear',
                align_corners=False)

        x = self.DownToGlobalCode(midpoint)
        # Average over the two spatial dimensions before the linear layer.
        x = x.mean(dim=(2, 3))
        gl = self.ToGlobalCode(x)

        sp = util.normalize(sp)
        gl = util.normalize(gl)
        if extract_features:
            return sp, gl, feature
        return sp, gl