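"""Convert TensorFlow StyleGAN2 weights (.pkl) to a PyTorch checkpoint (.pt).

Usage sketch (illustrative; assumes this file is saved as convert_weight.py and
that --repo points at a checkout of the TF StyleGAN2 repository providing dnnlib):

    python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
"""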
import argparse
import os
import sys
import pickle
import math

import torch
import numpy as np
from torchvision import utils

from model import Generator, Discriminator


def convert_modconv(vars, source_name, target_name, flip=False):
    """Convert one TF modulated-convolution layer to PyTorch state-dict entries."""
    weight = vars[source_name + '/weight'].value().eval()
    mod_weight = vars[source_name + '/mod_weight'].value().eval()
    mod_bias = vars[source_name + '/mod_bias'].value().eval()
    noise = vars[source_name + '/noise_strength'].value().eval()
    bias = vars[source_name + '/bias'].value().eval()

    dic = {
        'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        'conv.modulation.weight': mod_weight.transpose((1, 0)),
        'conv.modulation.bias': mod_bias + 1,
        'noise.weight': np.array([noise]),
        'activate.bias': bias,
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + '.' + k] = torch.from_numpy(v)

    if flip:
        # Upsampling convs need spatially flipped kernels, evidently because the
        # two frameworks' transposed convolutions index kernels differently.
        dic_torch[target_name + '.conv.weight'] = torch.flip(
            dic_torch[target_name + '.conv.weight'], [3, 4]
        )

    return dic_torch
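

# Layout note (illustrative shapes): TF stores conv kernels as
# (kh, kw, in_ch, out_ch), while the grouped modulated conv in model.py expects
# (1, out_ch, in_ch, kh, kw), hence the transpose plus expand_dims above. E.g. a
# 3x3, 512 -> 512 layer maps (3, 3, 512, 512) -> (1, 512, 512, 3, 3). The '+ 1'
# on mod_bias evidently folds TF's additive style offset into the PyTorch
# modulation bias.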


def convert_conv(vars, source_name, target_name, bias=True, start=0):
    """Convert a plain TF convolution layer to PyTorch state-dict entries."""
    weight = vars[source_name + '/weight'].value().eval()

    dic = {'weight': weight.transpose((3, 2, 0, 1))}

    if bias:
        dic['bias'] = vars[source_name + '/bias'].value().eval()

    dic_torch = {}

    dic_torch[target_name + f'.{start}.weight'] = torch.from_numpy(dic['weight'])

    if bias:
        dic_torch[target_name + f'.{start + 1}.bias'] = torch.from_numpy(dic['bias'])

    return dic_torch
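

# Indexing note: the target is an nn.Sequential-style block, so f'.{start}.weight'
# addresses the conv module itself and f'.{start + 1}.bias' the activation that
# owns the bias. start=1 is used below for convs preceded by a blur/downsample
# module (inferred from the key layout, not stated in this file).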


def convert_torgb(vars, source_name, target_name):
    """Convert a TF ToRGB layer (a 1x1 modulated conv without noise injection)."""
    weight = vars[source_name + '/weight'].value().eval()
    mod_weight = vars[source_name + '/mod_weight'].value().eval()
    mod_bias = vars[source_name + '/mod_bias'].value().eval()
    bias = vars[source_name + '/bias'].value().eval()

    dic = {
        'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        'conv.modulation.weight': mod_weight.transpose((1, 0)),
        'conv.modulation.bias': mod_bias + 1,
        # The RGB bias is added as a broadcast tensor, hence the (1, 3, 1, 1) shape.
        'bias': bias.reshape((1, 3, 1, 1)),
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + '.' + k] = torch.from_numpy(v)

    return dic_torch


def convert_dense(vars, source_name, target_name):
    """Convert a TF dense (fully connected) layer to PyTorch state-dict entries."""
    weight = vars[source_name + '/weight'].value().eval()
    bias = vars[source_name + '/bias'].value().eval()

    dic = {'weight': weight.transpose((1, 0)), 'bias': bias}

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + '.' + k] = torch.from_numpy(v)

    return dic_torch
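

# TF dense kernels are stored as (in_features, out_features); PyTorch linear
# layers expect (out_features, in_features), hence the transpose. E.g. a
# 512 -> 1 output layer converts from shape (512, 1) to (1, 512).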


def update(state_dict, new):
    """Copy entries from `new` into `state_dict`, validating keys and shapes."""
    for k, v in new.items():
        if k not in state_dict:
            raise KeyError(k + ' is not found')

        if v.shape != state_dict[k].shape:
            raise ValueError(f'Shape mismatch: {v.shape} vs {state_dict[k].shape}')

        state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
    """Fill a PyTorch Discriminator state dict from the TF variable collection."""
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0'))

    conv_i = 1

    # Residual blocks run from the full resolution down to 8x8; the 4x4 head
    # is handled separately below.
    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'),
        )
        update(
            statedict,
            convert_conv(
                vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1
            ),
        )
        update(
            statedict,
            convert_conv(
                vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False
            ),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, '4x4/Conv', 'final_conv'))
    update(statedict, convert_dense(vars, '4x4/Dense0', 'final_linear.0'))
    update(statedict, convert_dense(vars, 'Output', 'final_linear.1'))

    return statedict
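

# Resolution bookkeeping, for example at size=1024: log_size = 10, so the
# residual loop above visits i = 8..1, i.e. reso = 1024, 512, ..., 8, and the
# 4x4 head handles the final block.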


def fill_statedict(state_dict, vars, size):
    """Fill a PyTorch Generator state dict from the TF variable collection."""
    log_size = int(math.log(size, 2))

    # Mapping network: Dense0..Dense7 land at style.1..style.8 (index 0 is
    # evidently a parameter-free normalization module, hence the offset).
    for i in range(8):
        update(state_dict, convert_dense(vars, f'G_mapping/Dense{i}', f'style.{i + 1}'))

    # Learned constant input of the synthesis network.
    update(
        state_dict,
        {
            'input.input': torch.from_numpy(
                vars['G_synthesis/4x4/Const/const'].value().eval()
            )
        },
    )

    update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 'to_rgb1'))

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, f'G_synthesis/{reso}x{reso}/ToRGB', f'to_rgbs.{i}'),
        )

    update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1'))

    conv_i = 0

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                f'G_synthesis/{reso}x{reso}/Conv0_up',
                f'convs.{conv_i}',
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}'
            ),
        )
        conv_i += 2

    # One noise buffer for the 4x4 layer, then two per higher resolution.
    for i in range(0, (log_size - 2) * 2 + 1):
        update(
            state_dict,
            {
                f'noises.noise_{i}': torch.from_numpy(
                    vars[f'G_synthesis/noise{i}'].value().eval()
                )
            },
        )

    return state_dict
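

# Noise buffer count, for example at size=1024: (10 - 2) * 2 + 1 = 17 buffers,
# noise_0 for the 4x4 layer and two per resolution above that, up to noise_16.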


if __name__ == '__main__':
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
print('Using PyTorch device', device) |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--repo', type=str, required=True) |
|
parser.add_argument('--gen', action='store_true') |
|
parser.add_argument('--disc', action='store_true') |
|
parser.add_argument('--channel_multiplier', type=int, default=2) |
|
parser.add_argument('path', metavar='PATH') |
|
|
|
args = parser.parse_args() |
|
|
|
sys.path.append(args.repo) |
|
|
|
import dnnlib |
|
from dnnlib import tflib |
|
|
|
tflib.init_tf() |
|
|
|
with open(args.path, 'rb') as f: |
|
generator, discriminator, g_ema = pickle.load(f) |
|
|
|
size = g_ema.output_shape[2] |
|
|
|
g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) |
|
state_dict = g.state_dict() |
|
state_dict = fill_statedict(state_dict, g_ema.vars, size) |
|
|
|
g.load_state_dict(state_dict) |
    latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())

    ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}

    if args.gen:
        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size)
        ckpt['g'] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt['d'] = d_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    outpath = os.path.join(os.getcwd(), f'{name}.pt')
    print('Saving', outpath)
    try:
        # Legacy serialization keeps the file loadable by PyTorch < 1.6; versions
        # that predate the keyword raise TypeError, hence the fallback.
        torch.save(ckpt, outpath, _use_new_zipfile_serialization=False)
    except TypeError:
        torch.save(ckpt, outpath)
    print('Generating TF-Torch comparison images')
    batch_size = {256: 8, 512: 4, 1024: 2}
    n_sample = batch_size.get(size, 4)

    g = g.to(device)

    z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            # Use the converted noise buffers so both frameworks see identical
            # noise; otherwise the comparison below is meaningless.
            randomize_noise=False,
        )

    img_tf = g_ema.run(z, None, randomize_noise=False)
    img_tf = torch.from_numpy(img_tf).to(device)

    # Per-pixel difference of the two outputs, each mapped from [-1, 1] to [0, 1].
    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf + 1) / 2).clamp(0.0, 1.0)

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
    # NOTE: newer torchvision versions rename save_image's 'range' keyword to
    # 'value_range'.
    utils.save_image(
        img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)
    )
    print('Done')
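

# Loading sketch for the converted checkpoint (illustrative; the file name and
# constructor arguments mirror the ones used above):
#
#   ckpt = torch.load('stylegan2-ffhq-config-f.pt')
#   g = Generator(1024, 512, 8, channel_multiplier=2)
#   g.load_state_dict(ckpt['g_ema'])
#   latent_avg = ckpt['latent_avg']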