dattarij's picture
adding ContraCLIP folder
8c212a5
raw
history blame
8.69 kB
# python3.7
"""Converts PGGAN model weights from TensorFlow to PyTorch.
The models can be trained through OR released by the repository:
https://github.com/tkarras/progressive_growing_of_gans
"""
import os
import sys
import pickle
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
# pylint: disable=wrong-import-position
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import torch
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from models import build_model
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
# pylint: enable=wrong-import-position
# Public API of this module: only the converter entry point is exported.
__all__ = ['convert_pggan_weight']
# Value passed as `gan_type` to `build_model()` by the converter.
# NOTE(review): the constant name is misspelled ("TPYE" instead of "TYPE");
# it is left unchanged here because the converter refers to it by this name.
GAN_TPYE = 'pggan'
# Name of the sub-directory (next to this file) that is joined into
# `CODE_PATH` below.
OFFICIAL_CODE_DIR = 'pggan_official'
# Directory containing this file, expressed relative to the current working
# directory at import time (because of `os.path.relpath`).
BASE_DIR = os.path.dirname(os.path.relpath(__file__))
# Path temporarily inserted into `sys.path` while the TensorFlow pickle is
# loaded by `convert_pggan_weight()`.
CODE_PATH = os.path.join(BASE_DIR, OFFICIAL_CODE_DIR)
def _convert_generator_dense(var, pth_model):
    """Re-lays-out a generator `Dense` weight for the PyTorch model.

    The 2-D weight is viewed as `(var.shape[0], -1, init_res, init_res)`,
    its first two axes are swapped and both spatial axes are flipped.
    """
    var = var.view(var.shape[0], -1, pth_model.init_res, pth_model.init_res)
    return var.permute(1, 0, 2, 3).flip(2, 3)


def _convert_discriminator_dense(var, pth_model):
    """Transposes a discriminator `Dense` weight (`pth_model` is unused)."""
    del pth_model  # Accepted only to share a signature with the generator.
    return var.permute(1, 0)


def _convert_tf_variables(tf_vars, pth_model, convert_dense, verbose):
    """Fills the state dict of `pth_model` with converted TF variables.

    Args:
        tf_vars: Dict mapping TensorFlow variable names to their values.
        pth_model: PyTorch model exposing `state_dict()` and the
            `pth_to_tf_var_mapping` dict (PyTorch name -> TensorFlow name).
        convert_dense: Callable `(tensor, pth_model) -> tensor` applied to
            weights whose TensorFlow name contains `Dense`.
        verbose: Whether to print every converted variable name.

    Returns:
        The state dict of `pth_model` with all mapped entries replaced by
        the converted TensorFlow values.

    Raises:
        AssertionError: If a mapped name is missing on either side.
    """
    state_dict = pth_model.state_dict()
    for pth_var_name, tf_var_name in pth_model.pth_to_tf_var_mapping.items():
        assert tf_var_name in tf_vars
        assert pth_var_name in state_dict
        if verbose:
            print(f' Converting `{tf_var_name}` to `{pth_var_name}`.')
        var = torch.from_numpy(np.array(tf_vars[tf_var_name]))
        # Only weights need a layout change; other variables are copied
        # as they are.
        if 'weight' in tf_var_name:
            if 'Dense' in tf_var_name:
                var = convert_dense(var, pth_model)
            else:
                # Non-dense weights are 4-D; the last two axes become the
                # first two (presumably HWIO -> OIHW -- TODO confirm).
                var = var.permute(3, 2, 0, 1)
        state_dict[pth_var_name] = var
    return state_dict


def _sample_test_inputs(z_space_dim, label_size):
    """Draws one random latent code and (optional) one-hot label.

    Returns:
        A tuple `(code, label, pth_code, pth_label)`: the NumPy inputs for
        the TensorFlow networks and the matching CUDA tensors for the
        PyTorch models. `pth_label` is `None` when `label_size` is 0.
    """
    code = np.random.randn(1, z_space_dim)
    pth_code = torch.from_numpy(code).type(torch.FloatTensor).cuda()
    label = np.zeros((1, label_size), np.float32)
    if label_size:
        label[0, np.random.randint(label_size)] = 1.0
        pth_label = torch.from_numpy(label).type(torch.FloatTensor).cuda()
    else:
        pth_label = None
    return code, label, pth_code, pth_label


def convert_pggan_weight(tf_weight_path,
                         pth_weight_path,
                         test_num=10,
                         save_test_image=False,
                         verbose=False):
    """Converts the pre-trained PGGAN weights.

    Loads the pickled `(G, D, Gs)` TensorFlow networks, converts their
    variables to PyTorch state dicts, saves them to `pth_weight_path` under
    the keys `generator`, `discriminator` and `generator_smooth`, and
    optionally compares the outputs of both implementations on random inputs.

    The TensorFlow session opened here is always closed, and the temporary
    `sys.path` entry is always removed, even when an error is raised.

    Args:
        tf_weight_path: Path to the TensorFlow model to load weights from.
        pth_weight_path: Path to the PyTorch model to save converted weights.
        test_num: Number of samples used to test the conversion. Testing is
            skipped (with a warning) if this is not positive, if TensorFlow
            is not built with CUDA, or if PyTorch cannot use CUDA.
            (default: 10)
        save_test_image: Whether to save the test images. (default: False)
        verbose: Whether to print verbose log message. (default: False)

    Raises:
        AssertionError: If a variable listed in a model's
            `pth_to_tf_var_mapping` is missing from either side.
    """
    sess = tf.compat.v1.InteractiveSession()
    try:
        print('========================================')
        print(f'Loading TensorFlow weights from `{tf_weight_path}` ...')
        # The pickle references modules of the official code base, so that
        # directory must be importable while loading.
        sys.path.insert(0, CODE_PATH)
        try:
            with open(tf_weight_path, 'rb') as f:
                # SECURITY: unpickling executes arbitrary code; only load
                # weight files obtained from a trusted source.
                G, D, Gs = pickle.load(f)
        finally:
            # Restore `sys.path` even if opening or unpickling fails.
            sys.path.pop(0)
        print('Successfully loaded!')
        print('--------------------')

        z_space_dim = G.input_shapes[0][1]
        label_size = G.input_shapes[1][1]
        image_channels = G.output_shape[1]
        resolution = G.output_shape[2]

        print('Converting TensorFlow weights (G) to PyTorch version ...')
        G_vars = dict(G.__getstate__()['variables'])
        G_pth = build_model(gan_type=GAN_TPYE,
                            module='generator',
                            resolution=resolution,
                            z_space_dim=z_space_dim,
                            label_size=label_size,
                            image_channels=image_channels)
        G_state_dict = _convert_tf_variables(
            G_vars, G_pth, _convert_generator_dense, verbose)
        print('Successfully converted!')
        print('--------------------')

        print('Converting TensorFlow weights (Gs) to PyTorch version ...')
        Gs_vars = dict(Gs.__getstate__()['variables'])
        Gs_pth = build_model(gan_type=GAN_TPYE,
                             module='generator',
                             resolution=resolution,
                             z_space_dim=z_space_dim,
                             label_size=label_size,
                             image_channels=image_channels)
        Gs_state_dict = _convert_tf_variables(
            Gs_vars, Gs_pth, _convert_generator_dense, verbose)
        print('Successfully converted!')
        print('--------------------')

        print('Converting TensorFlow weights (D) to PyTorch version ...')
        D_vars = dict(D.__getstate__()['variables'])
        D_pth = build_model(gan_type=GAN_TPYE,
                            module='discriminator',
                            resolution=resolution,
                            label_size=label_size,
                            image_channels=image_channels)
        D_state_dict = _convert_tf_variables(
            D_vars, D_pth, _convert_discriminator_dense, verbose)
        print('Successfully converted!')
        print('--------------------')

        print(f'Saving PyTorch weights to `{pth_weight_path}` ...')
        state_dict = {
            'generator': G_state_dict,
            'discriminator': D_state_dict,
            'generator_smooth': Gs_state_dict,
        }
        torch.save(state_dict, pth_weight_path)
        print('Successfully saved!')
        print('--------------------')

        # Start testing if needed. The test moves the PyTorch models to the
        # GPU, so it also requires a CUDA device usable from PyTorch.
        if (test_num <= 0 or not tf.test.is_built_with_cuda()
                or not torch.cuda.is_available()):
            warnings.warn('Skip testing the converted weights!')
            return

        if save_test_image:
            html = HtmlPageVisualizer(num_rows=test_num, num_cols=3)
            html.set_headers(
                ['Index', 'Before Conversion', 'After Conversion'])
            for i in range(test_num):
                html.set_cell(i, 0, text=f'{i}')

        print('Testing conversion results ...')
        G_pth.load_state_dict(G_state_dict)
        D_pth.load_state_dict(D_state_dict)
        Gs_pth.load_state_dict(Gs_state_dict)
        G_pth.eval().cuda()
        D_pth.eval().cuda()
        Gs_pth.eval().cuda()

        gs_distance = 0.0
        dg_distance = 0.0
        for i in tqdm(range(test_num)):
            # Test Gs(z).
            code, label, pth_code, pth_label = _sample_test_inputs(
                z_space_dim, label_size)
            tf_output = Gs.run(code, label)
            pth_output = Gs_pth(pth_code, label=pth_label)['image']
            pth_output = pth_output.detach().cpu().numpy()
            distance = np.average(np.abs(tf_output - pth_output))
            if verbose:
                print(f' Test {i:03d}: Gs distance {distance:.6e}.')
            gs_distance += distance

            if save_test_image:
                html.set_cell(i, 1, image=postprocess_image(tf_output)[0])
                html.set_cell(i, 2, image=postprocess_image(pth_output)[0])

            # Test D(G(z)).
            code, label, pth_code, pth_label = _sample_test_inputs(
                z_space_dim, label_size)
            tf_image = G.run(code, label)
            tf_output = D.run(tf_image)
            pth_image = G_pth(pth_code, label=pth_label)['image']
            pth_output = D_pth(pth_image)
            pth_output = pth_output.detach().cpu().numpy()
            # The first output column is compared with `tf_output[0]`; the
            # remaining columns with `tf_output[1]` for conditional models.
            distance = np.average(np.abs(tf_output[0] - pth_output[:, :1]))
            if label_size:
                distance += np.average(
                    np.abs(tf_output[1] - pth_output[:, 1:]))
            if verbose:
                print(f' Test {i:03d}: D(G) distance {distance:.6e}.')
            dg_distance += distance

        print(f'Average Gs distance is {gs_distance / test_num:.6e}.')
        print(f'Average D(G) distance is {dg_distance / test_num:.6e}.')
        print('========================================')

        if save_test_image:
            html.save(f'{pth_weight_path}.conversion_test.html')
    finally:
        # Release the TensorFlow session on every exit path.
        sess.close()