import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from pathlib import Path
import requests
import pickle
import sys

import numpy as np


class MyLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He init std
        if use_wscale:
            # Store the weight at std 1/lrmul and rescale at runtime
            # (equalized learning rate).
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)


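# A minimal sketch of why the scaling above matters: with use_wscale=True the
# parameter is stored at std 1/lrmul but multiplied by he_std * lrmul in
# forward(), so the layer always computes with He-initialized statistics while
# gradient updates are effectively rescaled by lrmul. Illustrative only:
#
#   lin = MyLinear(512, 512, use_wscale=True, lrmul=0.01)
#   y = lin(torch.randn(4, 512))
#   # y.std() is close to sqrt(2) (the He gain), independent of lrmul

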
class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                 intermediate=None, upscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init std
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # Fuse the upscale and the convolution into one strided transposed
            # convolution; this is only done at high resolutions, where it is faster.
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # Smear the 3x3 kernel across the four sub-pixel positions by
            # summing four shifted copies, matching the fused-scale trick.
            w = F.pad(w, (1, 1, 1, 1))
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        if not have_convolution and self.intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x


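# A hedged sketch of the two upscale paths above (sizes are illustrative):
#
#   conv = MyConv2d(512, 256, kernel_size=3, use_wscale=True, upscale=True)
#   small = torch.randn(1, 512, 32, 32)   # 32*2 < 128: separate upscale + conv
#   large = torch.randn(1, 512, 64, 64)   # 64*2 >= 128: fused conv_transpose2d
#   assert conv(small).shape[-1] == 64 and conv(large).shape[-1] == 128

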
class NoiseLayer(nn.Module):
    """Adds per-pixel noise (constant over channels), scaled by a learned per-channel weight."""

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None and self.noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        elif noise is None:
            # Use the fixed noise stored on the module, e.g. to keep the noise
            # pattern constant across several forward passes.
            noise = self.noise
        x = x + self.weight.view(1, -1, 1, 1) * noise
        return x


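# Hedged usage sketch: pinning self.noise makes generation deterministic for a
# fixed latent, which is handy when comparing styles. Illustrative only:
#
#   layer = NoiseLayer(256)
#   layer.noise = torch.randn(1, 1, 32, 32)   # reused on every call
#   out = layer(torch.randn(1, 256, 32, 32))  # same noise pattern each time now

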
class StyleMod(nn.Module):
    def __init__(self, latent_size, channels, use_wscale):
        super().__init__()
        self.lin = MyLinear(latent_size,
                            channels * 2,
                            gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        style = self.lin(latent)  # [batch_size, channels * 2]
        shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
        style = style.view(shape)  # [batch_size, 2, channels, 1, 1] for 4d x
        x = x * (style[:, 0] + 1.) + style[:, 1]
        return x


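# The AdaIN-style modulation above predicts a per-channel (scale, bias) pair
# from the latent; the "+ 1." keeps the scale centred at identity. A shape
# sketch with made-up sizes:
#
#   mod = StyleMod(latent_size=512, channels=256, use_wscale=True)
#   x = torch.randn(2, 256, 16, 16)
#   w = torch.randn(2, 512)
#   assert mod(x, w).shape == x.shape

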
class PixelNormLayer(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)


class BlurLayer(nn.Module):
    def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
        super().__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]  # outer product -> 2d kernel
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Tensors do not support negative-step slicing; use torch.flip.
            kernel = torch.flip(kernel, dims=(2, 3))
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # Expand the kernel over channels and blur each channel separately
        # (depthwise convolution via groups).
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=x.size(1)
        )
        return x


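# A quick hedged check that the default [1, 2, 1] kernel acts as a normalized
# binomial smoother (illustrative):
#
#   blur = BlurLayer()
#   assert torch.isclose(blur.kernel.sum(), torch.tensor(1.0))
#   ones = torch.ones(1, 3, 8, 8)
#   # interior pixels of a constant image are unchanged by a normalized blur
#   assert torch.allclose(blur(ones)[:, :, 1:-1, 1:-1], ones[:, :, 1:-1, 1:-1])

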
def upscale2d(x, factor=2, gain=1):
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor != 1:
        # Nearest-neighbour upsampling by pixel replication.
        shape = x.shape
        x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
        x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
    return x


class Upscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return upscale2d(x, factor=self.factor, gain=self.gain)


class G_mapping(nn.Sequential):
    def __init__(self, nonlinearity='lrelu', use_wscale=True):
        # Activations must be nn.Modules so they can live inside nn.Sequential.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [
            ('pixel_norm', PixelNormLayer()),
            ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense0_act', act),
            ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense1_act', act),
            ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense2_act', act),
            ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense3_act', act),
            ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense4_act', act),
            ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense5_act', act),
            ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense6_act', act),
            ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense7_act', act)
        ]
        super().__init__(OrderedDict(layers))


class Truncation(nn.Module):
    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # Only truncate the first max_layer style layers; keep the arange on
        # the input's device so this also works on GPU.
        do_trunc = (torch.arange(x.size(1), device=x.device) < self.max_layer).view(1, -1, 1)
        return torch.where(do_trunc, interp, x)


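# The truncation trick pulls w toward the average latent, trading diversity
# for fidelity. A hedged sketch with an illustrative averaged latent:
#
#   avg_w = torch.zeros(512)             # normally estimated from many mapped z
#   trunc = Truncation(avg_w, max_layer=8, threshold=0.7)
#   w = torch.randn(2, 18, 512)
#   w_trunc = trunc(w)                   # layers 0..7 moved toward avg_w

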
class LayerEpilogue(nn.Module):
    """Things to do at the end of each layer."""

    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            layers.append(('pixel_norm', PixelNormLayer()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None):
        x = self.top_epi(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)
        else:
            assert dlatents_in_slice is None
        return x


class InputBlock(nn.Module):
    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if self.const_input_layer:
            # Learned constant 4x4 input tensor plus a per-channel bias.
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        else:
            self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4, use_wscale=use_wscale)
        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, dlatents_in_range):
        batch_size = dlatents_in_range.size(0)
        if self.const_input_layer:
            x = self.const.expand(batch_size, -1, -1, -1)
            x = x + self.bias.view(1, -1, 1, 1)
        else:
            x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x


class GSynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        # The blur runs between the upscaling conv and its bias/epilogue, so it
        # is handed to MyConv2d as the 'intermediate' op.
        super().__init__()
        if blur_filter:
            blur = BlurLayer(blur_filter)
        else:
            blur = None
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, x, dlatents_in_range):
        x = self.conv0_up(x)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv1(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x


class G_synthesis(nn.Module):
    def __init__(self,
                 dlatent_size      = 512,        # Disentangled latent (W) dimensionality.
                 num_channels      = 3,          # Number of output color channels.
                 resolution        = 1024,       # Output resolution.
                 fmap_base         = 8192,       # Overall multiplier for the number of feature maps.
                 fmap_decay        = 1.0,        # log2 feature map reduction when doubling the resolution.
                 fmap_max          = 512,        # Maximum number of feature maps in any layer.
                 use_styles        = True,       # Enable style inputs?
                 const_input_layer = True,       # First layer is a learned constant?
                 use_noise         = True,       # Enable noise inputs?
                 randomize_noise   = True,       # True = randomize noise inputs every time.
                 nonlinearity      = 'lrelu',    # Activation function: 'relu' or 'lrelu'.
                 use_wscale        = True,       # Enable equalized learning rate?
                 use_pixel_norm    = False,      # Enable pixelwise feature vector normalization?
                 use_instance_norm = True,       # Enable instance normalization?
                 dtype             = torch.float32,
                 blur_filter       = [1, 2, 1],  # Low-pass filter applied when resampling activations.
                 ):
        super().__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4

        # Activations must be nn.Modules so they can live inside nn.Sequential.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        blocks = []
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                blocks.append((name,
                               InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                          use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            else:
                blocks.append((name,
                               GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale,
                                               use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            last_channels = channels
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        # dlatents_in: [batch_size, num_layers, dlatent_size], with
        # num_layers = 2 * log2(resolution) - 2; each block consumes two w's.
        for i, m in enumerate(self.blocks.values()):
            if i == 0:
                x = m(dlatents_in[:, 2*i:2*i+2])
            else:
                x = m(x, dlatents_in[:, 2*i:2*i+2])
        rgb = self.torgb(x)
        return rgb


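# Hedged sketch of the feature-map schedule: nf(stage) halves the channel
# count as resolution grows, capped at fmap_max. With the defaults above the
# 4x4 block gets min(8192/2, 512) = 512 channels and the 1024x1024 block
# min(8192/512, 512) = 16. A quick shape check at a small resolution:
#
#   synth = G_synthesis(resolution=64)
#   w = torch.randn(1, 10, 512)            # 2*log2(64) - 2 = 10 style slots
#   assert synth(w).shape == (1, 3, 64, 64)

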
class StyleGAN_G(nn.Sequential):
    def __init__(self, resolution, truncation=1.0):
        self.resolution = resolution
        self.layers = OrderedDict([
            ('g_mapping', G_mapping()),
            # ('truncation', Truncation(avg_latent)),  # optional; needs an estimated avg_latent
            ('g_synthesis', G_synthesis(resolution=resolution)),
        ])
        super().__init__(self.layers)

    def forward(self, x, latent_is_w=False):
        if isinstance(x, list):
            assert len(x) == 18, 'Must provide 1 or 18 latents'
            if not latent_is_w:
                x = [self.layers['g_mapping'].forward(l) for l in x]
            x = torch.stack(x, dim=1)
        else:
            if not latent_is_w:
                x = self.layers['g_mapping'].forward(x)
            # Broadcast a single w to all 18 style layers.
            x = x.unsqueeze(1).expand(-1, 18, -1)
        x = self.layers['g_synthesis'].forward(x)
        return x

    def load_weights(self, checkpoint):
        self.load_state_dict(torch.load(checkpoint))

    def export_from_tf(self, pickle_path):
        # Convert the reference TensorFlow checkpoint into a PyTorch state
        # dict. Requires the official StyleGAN code to be importable from the
        # 'stylegan_tf' directory next to this file.
        module_path = Path(__file__).parent / 'stylegan_tf'
        sys.path.append(str(module_path.resolve()))

        import dnnlib, dnnlib.tflib
        dnnlib.tflib.init_tf()

        with open(pickle_path, 'rb') as f:
            weights = pickle.load(f)
        weights_pt = [OrderedDict([(k, torch.from_numpy(v.value().eval())) for k, v in w.trainables.items()])
                      for w in weights]
        # The pickle contains (G, D, Gs); only the averaged generator Gs is converted.
        state_G, state_D, state_Gs = weights_pt

        def key_translate(k):
            # Map TF variable paths onto the module names used in this file.
            k = k.lower().split('/')
            if k[0] == 'g_synthesis':
                if not k[1].startswith('torgb'):
                    k.insert(1, 'blocks')
                k = '.'.join(k)
                k = (k.replace('const.const', 'const').replace('const.bias', 'bias').replace('const.stylemod', 'epi1.style_mod.lin')
                     .replace('const.noise.weight', 'epi1.top_epi.noise.weight')
                     .replace('conv.noise.weight', 'epi2.top_epi.noise.weight')
                     .replace('conv.stylemod', 'epi2.style_mod.lin')
                     .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
                     .replace('conv0_up.stylemod', 'epi1.style_mod.lin')
                     .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
                     .replace('conv1.stylemod', 'epi2.style_mod.lin')
                     .replace('torgb_lod0', 'torgb'))
            else:
                k = '.'.join(k)
            return k

        def weight_translate(k, w):
            # TF stores dense weights as (in, out) and conv weights as
            # (kh, kw, in, out); PyTorch wants (out, in) and (out, in, kh, kw).
            k = key_translate(k)
            if k.endswith('.weight'):
                if w.dim() == 2:
                    w = w.t()
                elif w.dim() == 1:
                    pass
                else:
                    assert w.dim() == 4
                    w = w.permute(3, 2, 0, 1)
            return w

        param_dict = {key_translate(k): weight_translate(k, v) for k, v in state_Gs.items()
                      if 'torgb_lod' not in key_translate(k)}

        # Report key and shape mismatches between the converted parameters and
        # this module's state dict before loading.
        sd_shapes = {k: v.shape for k, v in self.state_dict().items()}
        param_shapes = {k: v.shape for k, v in param_dict.items()}
        for k in list(sd_shapes) + list(param_shapes):
            pds = param_shapes.get(k)
            sds = sd_shapes.get(k)
            if pds is None:
                print("sd only", k, sds)
            elif sds is None:
                print("pd only", k, pds)
            elif sds != pds:
                print("mismatch!", k, pds, sds)

        self.load_state_dict(param_dict, strict=False)
        torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt'))
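

if __name__ == '__main__':
    # Hedged end-to-end sketch: build the generator and push a random latent
    # through it. Loading real weights requires a converted checkpoint (see
    # export_from_tf); the file name below is purely illustrative.
    g = StyleGAN_G(resolution=1024)
    # g.load_weights('karras2019stylegan-ffhq-1024x1024.pt')  # hypothetical converted file
    g.eval()
    with torch.no_grad():
        z = torch.randn(1, 512)
        img = g(z)  # -> (1, 3, 1024, 1024); roughly in [-1, 1] once trained
    print(img.shape)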