disentangled-image-editing-final-project
/
ContraCLIP
/models
/genforce
/converters
/pggan_converter.py
# python3.7 | |
"""Converts PGGAN model weights from TensorFlow to PyTorch. | |
The models can be trained through OR released by the repository: | |
https://github.com/tkarras/progressive_growing_of_gans | |
""" | |
import os | |
import sys | |
import pickle | |
import warnings | |
warnings.filterwarnings('ignore', category=FutureWarning) | |
# pylint: disable=wrong-import-position | |
from tqdm import tqdm | |
import numpy as np | |
import tensorflow as tf | |
import torch | |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) | |
from models import build_model | |
from utils.visualizer import HtmlPageVisualizer | |
from utils.visualizer import postprocess_image | |
# pylint: enable=wrong-import-position | |
__all__ = ['convert_pggan_weight'] | |
GAN_TPYE = 'pggan' | |
OFFICIAL_CODE_DIR = 'pggan_official' | |
BASE_DIR = os.path.dirname(os.path.relpath(__file__)) | |
CODE_PATH = os.path.join(BASE_DIR, OFFICIAL_CODE_DIR) | |
def convert_pggan_weight(tf_weight_path, | |
pth_weight_path, | |
test_num=10, | |
save_test_image=False, | |
verbose=False): | |
"""Converts the pre-trained PGGAN weights. | |
Args: | |
tf_weight_path: Path to the TensorFlow model to load weights from. | |
pth_weight_path: Path to the PyTorch model to save converted weights. | |
test_num: Number of samples used to test the conversion. (default: 10) | |
save_test_image: Whether to save the test images. (default: False) | |
verbose: Whether to print verbose log message. (default: False) | |
""" | |
sess = tf.compat.v1.InteractiveSession() | |
print(f'========================================') | |
print(f'Loading TensorFlow weights from `{tf_weight_path}` ...') | |
sys.path.insert(0, CODE_PATH) | |
with open(tf_weight_path, 'rb') as f: | |
G, D, Gs = pickle.load(f) | |
sys.path.pop(0) | |
print(f'Successfully loaded!') | |
print(f'--------------------') | |
z_space_dim = G.input_shapes[0][1] | |
label_size = G.input_shapes[1][1] | |
image_channels = G.output_shape[1] | |
resolution = G.output_shape[2] | |
print(f'Converting TensorFlow weights (G) to PyTorch version ...') | |
G_vars = dict(G.__getstate__()['variables']) | |
G_pth = build_model(gan_type=GAN_TPYE, | |
module='generator', | |
resolution=resolution, | |
z_space_dim=z_space_dim, | |
label_size=label_size, | |
image_channels=image_channels) | |
G_state_dict = G_pth.state_dict() | |
for pth_var_name, tf_var_name in G_pth.pth_to_tf_var_mapping.items(): | |
assert tf_var_name in G_vars | |
assert pth_var_name in G_state_dict | |
if verbose: | |
print(f' Converting `{tf_var_name}` to `{pth_var_name}`.') | |
var = torch.from_numpy(np.array(G_vars[tf_var_name])) | |
if 'weight' in tf_var_name: | |
if 'Dense' in tf_var_name: | |
var = var.view(var.shape[0], -1, G_pth.init_res, G_pth.init_res) | |
var = var.permute(1, 0, 2, 3).flip(2, 3) | |
else: | |
var = var.permute(3, 2, 0, 1) | |
G_state_dict[pth_var_name] = var | |
print(f'Successfully converted!') | |
print(f'--------------------') | |
print(f'Converting TensorFlow weights (Gs) to PyTorch version ...') | |
Gs_vars = dict(Gs.__getstate__()['variables']) | |
Gs_pth = build_model(gan_type=GAN_TPYE, | |
module='generator', | |
resolution=resolution, | |
z_space_dim=z_space_dim, | |
label_size=label_size, | |
image_channels=image_channels) | |
Gs_state_dict = Gs_pth.state_dict() | |
for pth_var_name, tf_var_name in Gs_pth.pth_to_tf_var_mapping.items(): | |
assert tf_var_name in Gs_vars | |
assert pth_var_name in Gs_state_dict | |
if verbose: | |
print(f' Converting `{tf_var_name}` to `{pth_var_name}`.') | |
var = torch.from_numpy(np.array(Gs_vars[tf_var_name])) | |
if 'weight' in tf_var_name: | |
if 'Dense' in tf_var_name: | |
var = var.view( | |
var.shape[0], -1, Gs_pth.init_res, Gs_pth.init_res) | |
var = var.permute(1, 0, 2, 3).flip(2, 3) | |
else: | |
var = var.permute(3, 2, 0, 1) | |
Gs_state_dict[pth_var_name] = var | |
print(f'Successfully converted!') | |
print(f'--------------------') | |
print(f'Converting TensorFlow weights (D) to PyTorch version ...') | |
D_vars = dict(D.__getstate__()['variables']) | |
D_pth = build_model(gan_type=GAN_TPYE, | |
module='discriminator', | |
resolution=resolution, | |
label_size=label_size, | |
image_channels=image_channels) | |
D_state_dict = D_pth.state_dict() | |
for pth_var_name, tf_var_name in D_pth.pth_to_tf_var_mapping.items(): | |
assert tf_var_name in D_vars | |
assert pth_var_name in D_state_dict | |
if verbose: | |
print(f' Converting `{tf_var_name}` to `{pth_var_name}`.') | |
var = torch.from_numpy(np.array(D_vars[tf_var_name])) | |
if 'weight' in tf_var_name: | |
if 'Dense' in tf_var_name: | |
var = var.permute(1, 0) | |
else: | |
var = var.permute(3, 2, 0, 1) | |
D_state_dict[pth_var_name] = var | |
print(f'Successfully converted!') | |
print(f'--------------------') | |
print(f'Saving PyTorch weights to `{pth_weight_path}` ...') | |
state_dict = { | |
'generator': G_state_dict, | |
'discriminator': D_state_dict, | |
'generator_smooth': Gs_state_dict, | |
} | |
torch.save(state_dict, pth_weight_path) | |
print(f'Successfully saved!') | |
print(f'--------------------') | |
# Start testing if needed. | |
if test_num <= 0 or not tf.test.is_built_with_cuda(): | |
warnings.warn(f'Skip testing the converted weights!') | |
sess.close() | |
return | |
if save_test_image: | |
html = HtmlPageVisualizer(num_rows=test_num, num_cols=3) | |
html.set_headers(['Index', 'Before Conversion', 'After Conversion']) | |
for i in range(test_num): | |
html.set_cell(i, 0, text=f'{i}') | |
print(f'Testing conversion results ...') | |
G_pth.load_state_dict(G_state_dict) | |
D_pth.load_state_dict(D_state_dict) | |
Gs_pth.load_state_dict(Gs_state_dict) | |
G_pth.eval().cuda() | |
D_pth.eval().cuda() | |
Gs_pth.eval().cuda() | |
gs_distance = 0.0 | |
dg_distance = 0.0 | |
for i in tqdm(range(test_num)): | |
# Test Gs(z). | |
code = np.random.randn(1, z_space_dim) | |
pth_code = torch.from_numpy(code).type(torch.FloatTensor).cuda() | |
label = np.zeros((1, label_size), np.float32) | |
if label_size: | |
label_id = np.random.randint(label_size) | |
label[0, label_id] = 1.0 | |
pth_label = torch.from_numpy(label).type(torch.FloatTensor).cuda() | |
else: | |
label_id = 0 | |
pth_label = None | |
tf_output = Gs.run(code, label) | |
pth_output = Gs_pth(pth_code, label=pth_label)['image'] | |
pth_output = pth_output.detach().cpu().numpy() | |
distance = np.average(np.abs(tf_output - pth_output)) | |
if verbose: | |
print(f' Test {i:03d}: Gs distance {distance:.6e}.') | |
gs_distance += distance | |
if save_test_image: | |
html.set_cell(i, 1, image=postprocess_image(tf_output)[0]) | |
html.set_cell(i, 2, image=postprocess_image(pth_output)[0]) | |
# Test D(G(z)). | |
code = np.random.randn(1, z_space_dim) | |
pth_code = torch.from_numpy(code).type(torch.FloatTensor).cuda() | |
label = np.zeros((1, label_size), np.float32) | |
if label_size: | |
label_id = np.random.randint(label_size) | |
label[0, label_id] = 1.0 | |
pth_label = torch.from_numpy(label).type(torch.FloatTensor).cuda() | |
else: | |
label_id = 0 | |
pth_label = None | |
tf_image = G.run(code, label) | |
tf_output = D.run(tf_image) | |
pth_image = G_pth(pth_code, label=pth_label)['image'] | |
pth_output = D_pth(pth_image) | |
pth_output = pth_output.detach().cpu().numpy() | |
distance = np.average(np.abs(tf_output[0] - pth_output[:, :1])) | |
if label_size: | |
distance += np.average(np.abs(tf_output[1] - pth_output[:, 1:])) | |
if verbose: | |
print(f' Test {i:03d}: D(G) distance {distance:.6e}.') | |
dg_distance += distance | |
print(f'Average Gs distance is {gs_distance / test_num:.6e}.') | |
print(f'Average D(G) distance is {dg_distance / test_num:.6e}.') | |
print(f'========================================') | |
if save_test_image: | |
html.save(f'{pth_weight_path}.conversion_test.html') | |
sess.close() | |