# python3.7
"""Converts a StyleGAN2-ADA-PyTorch model to match this repository.

The models can be trained with, or downloaded from, the official repository:

https://github.com/NVlabs/stylegan2-ada-pytorch
"""
import os
import sys
import re
import pickle
import warnings

from tqdm import tqdm
import numpy as np
import torch

from models import build_model
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
__all__ = ['convert_stylegan2ada_pth_weight']

GAN_TYPE = 'stylegan2'
OFFICIAL_CODE_DIR = 'stylegan2ada_pth_official'
BASE_DIR = os.path.dirname(os.path.relpath(__file__))
CODE_PATH = os.path.join(BASE_DIR, OFFICIAL_CODE_DIR)

# Settings used to test that the converted weights reproduce the official
# outputs.
TRUNC_PSI = 0.5
TRUNC_LAYERS = 18
RANDOMIZE_NOISE = False
NOISE_MODE = 'random' if RANDOMIZE_NOISE else 'const'
# The following two dictionaries of mapping patterns are modified from
# https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/legacy.py
G_PTH_TO_TF_VAR_MAPPING_PATTERN = {
    r'mapping\.w_avg':
        lambda: 'dlatent_avg',
    r'mapping\.embed\.weight':
        lambda: 'LabelEmbed/weight',
    r'mapping\.embed\.bias':
        lambda: 'LabelEmbed/bias',
    r'mapping\.fc(\d+)\.weight':
        lambda i: f'Dense{i}/weight',
    r'mapping\.fc(\d+)\.bias':
        lambda i: f'Dense{i}/bias',
    r'synthesis\.b4\.const':
        lambda: '4x4/Const/const',
    r'synthesis\.b4\.conv1\.weight':
        lambda: '4x4/Conv/weight',
    r'synthesis\.b4\.conv1\.bias':
        lambda: '4x4/Conv/bias',
    r'synthesis\.b4\.conv1\.noise_const':
        lambda: 'noise0',
    r'synthesis\.b4\.conv1\.noise_strength':
        lambda: '4x4/Conv/noise_strength',
    r'synthesis\.b4\.conv1\.affine\.weight':
        lambda: '4x4/Conv/mod_weight',
    r'synthesis\.b4\.conv1\.affine\.bias':
        lambda: '4x4/Conv/mod_bias',
    r'synthesis\.b(\d+)\.conv0\.weight':
        lambda r: f'{r}x{r}/Conv0_up/weight',
    r'synthesis\.b(\d+)\.conv0\.bias':
        lambda r: f'{r}x{r}/Conv0_up/bias',
    r'synthesis\.b(\d+)\.conv0\.noise_const':
        lambda r: f'noise{int(np.log2(int(r))) * 2 - 5}',
    r'synthesis\.b(\d+)\.conv0\.noise_strength':
        lambda r: f'{r}x{r}/Conv0_up/noise_strength',
    r'synthesis\.b(\d+)\.conv0\.affine\.weight':
        lambda r: f'{r}x{r}/Conv0_up/mod_weight',
    r'synthesis\.b(\d+)\.conv0\.affine\.bias':
        lambda r: f'{r}x{r}/Conv0_up/mod_bias',
    r'synthesis\.b(\d+)\.conv1\.weight':
        lambda r: f'{r}x{r}/Conv1/weight',
    r'synthesis\.b(\d+)\.conv1\.bias':
        lambda r: f'{r}x{r}/Conv1/bias',
    r'synthesis\.b(\d+)\.conv1\.noise_const':
        lambda r: f'noise{int(np.log2(int(r))) * 2 - 4}',
    r'synthesis\.b(\d+)\.conv1\.noise_strength':
        lambda r: f'{r}x{r}/Conv1/noise_strength',
    r'synthesis\.b(\d+)\.conv1\.affine\.weight':
        lambda r: f'{r}x{r}/Conv1/mod_weight',
    r'synthesis\.b(\d+)\.conv1\.affine\.bias':
        lambda r: f'{r}x{r}/Conv1/mod_bias',
    r'synthesis\.b(\d+)\.torgb\.weight':
        lambda r: f'{r}x{r}/ToRGB/weight',
    r'synthesis\.b(\d+)\.torgb\.bias':
        lambda r: f'{r}x{r}/ToRGB/bias',
    r'synthesis\.b(\d+)\.torgb\.affine\.weight':
        lambda r: f'{r}x{r}/ToRGB/mod_weight',
    r'synthesis\.b(\d+)\.torgb\.affine\.bias':
        lambda r: f'{r}x{r}/ToRGB/mod_bias',
    r'synthesis\.b(\d+)\.skip\.weight':
        lambda r: f'{r}x{r}/Skip/weight',
    r'.*\.resample_filter':
        None,
}
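
# A hypothetical illustration of how the generator table above is applied:
# for the source name 'synthesis.b8.conv0.weight',
# re.fullmatch(r'synthesis\.b(\d+)\.conv0\.weight', name) captures r = '8'
# and the lambda yields the TF-style name '8x8/Conv0_up/weight'. The noise
# buffer indices follow the two-layers-per-resolution layout: at resolution
# r, conv0 maps to noise layer 2 * log2(r) - 5 and conv1 to 2 * log2(r) - 4
# (e.g. r = 8 gives noise1 and noise2, after noise0 at 4x4).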

D_PTH_TO_TF_VAR_MAPPING_PATTERN = {
    r'b(\d+)\.fromrgb\.weight':
        lambda r: f'{r}x{r}/FromRGB/weight',
    r'b(\d+)\.fromrgb\.bias':
        lambda r: f'{r}x{r}/FromRGB/bias',
    r'b(\d+)\.conv(\d+)\.weight':
        lambda r, i: f'{r}x{r}/Conv{i}{["", "_down"][int(i)]}/weight',
    r'b(\d+)\.conv(\d+)\.bias':
        lambda r, i: f'{r}x{r}/Conv{i}{["", "_down"][int(i)]}/bias',
    r'b(\d+)\.skip\.weight':
        lambda r: f'{r}x{r}/Skip/weight',
    r'mapping\.embed\.weight':
        lambda: 'LabelEmbed/weight',
    r'mapping\.embed\.bias':
        lambda: 'LabelEmbed/bias',
    r'mapping\.fc(\d+)\.weight':
        lambda i: f'Mapping{i}/weight',
    r'mapping\.fc(\d+)\.bias':
        lambda i: f'Mapping{i}/bias',
    r'b4\.conv\.weight':
        lambda: '4x4/Conv/weight',
    r'b4\.conv\.bias':
        lambda: '4x4/Conv/bias',
    r'b4\.fc\.weight':
        lambda: '4x4/Dense0/weight',
    r'b4\.fc\.bias':
        lambda: '4x4/Dense0/bias',
    r'b4\.out\.weight':
        lambda: 'Output/weight',
    r'b4\.out\.bias':
        lambda: 'Output/bias',
    r'.*\.resample_filter':
        None,
}
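
# Note on the discriminator table above: the index trick
# `["", "_down"][int(i)]` keeps `conv0` at the block resolution
# (e.g. '8x8/Conv0') while `conv1`, which downsamples, gets the '_down'
# suffix (e.g. '8x8/Conv1_down').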


def convert_stylegan2ada_pth_weight(src_weight_path,
                                    dst_weight_path,
                                    test_num=10,
                                    save_test_image=False,
                                    verbose=False):
    """Converts the pre-trained StyleGAN2-ADA-PyTorch weights.

    Args:
        src_weight_path: Path to the source model to load weights from.
        dst_weight_path: Path to the target model to save converted weights.
        test_num: Number of samples used to test the conversion. (default: 10)
        save_test_image: Whether to save the test images. (default: False)
        verbose: Whether to print verbose log messages. (default: False)
    """
    print('========================================')
    print(f'Loading source weights from `{src_weight_path}` ...')
    # The official code must be importable so that `pickle` can rebuild the
    # network objects stored in the checkpoint.
    sys.path.insert(0, CODE_PATH)
    with open(src_weight_path, 'rb') as f:
        model = pickle.load(f)
    sys.path.pop(0)
    print('Successfully loaded!')
    print('--------------------')

    z_space_dim = model['G'].z_dim
    label_size = model['G'].c_dim
    w_space_dim = model['G'].w_dim
    image_channels = model['G'].img_channels
    resolution = model['G'].img_resolution
    repeat_w = True

    print('Converting source weights (G) to target ...')
    G_vars = dict(model['G'].named_parameters())
    G_vars.update(dict(model['G'].named_buffers()))
    G = build_model(gan_type=GAN_TYPE,
                    module='generator',
                    resolution=resolution,
                    z_space_dim=z_space_dim,
                    w_space_dim=w_space_dim,
                    label_size=label_size,
                    repeat_w=repeat_w,
                    image_channels=image_channels)
    G_state_dict = G.state_dict()
    # Invert the pattern table: map each TF-style name to the official
    # PyTorch variable name.
    official_tf_to_pth_var_mapping = {}
    for name in G_vars.keys():
        for pattern, fn in G_PTH_TO_TF_VAR_MAPPING_PATTERN.items():
            match = re.fullmatch(pattern, name)
            if match:
                if fn is not None:
                    official_tf_to_pth_var_mapping[fn(*match.groups())] = name
                break
    for dst_var_name, tf_var_name in G.pth_to_tf_var_mapping.items():
        assert tf_var_name in official_tf_to_pth_var_mapping
        assert dst_var_name in G_state_dict
        src_var_name = official_tf_to_pth_var_mapping[tf_var_name]
        assert src_var_name in G_vars
        if verbose:
            print(f'  Converting `{src_var_name}` to `{dst_var_name}`.')
        var = G_vars[src_var_name].data
        if 'weight' in tf_var_name:
            if 'Conv0_up/weight' in tf_var_name:
                # Flip the spatial kernel dimensions to match this
                # repository's up-convolution convention.
                var = var.flip(2, 3)
            elif 'Skip' in tf_var_name:
                var = var.flip(2, 3)
        if 'bias' in tf_var_name:
            if 'mod_bias' in tf_var_name:
                # The official affine bias bakes in the +1 style offset,
                # which this repository applies at runtime.
                var = var - 1
        if 'Const' in tf_var_name:
            var = var.unsqueeze(0)  # Add the batch dimension.
        if 'noise' in tf_var_name and 'noise_' not in tf_var_name:
            var = var.unsqueeze(0).unsqueeze(0)  # Reshape to [1, 1, H, W].
        G_state_dict[dst_var_name] = var
    print('Successfully converted!')
    print('--------------------')

    print('Converting source weights (Gs) to target ...')
    # Repeat the same conversion for the moving-average (EMA) generator.
    Gs_vars = dict(model['G_ema'].named_parameters())
    Gs_vars.update(dict(model['G_ema'].named_buffers()))
    Gs = build_model(gan_type=GAN_TYPE,
                     module='generator',
                     resolution=resolution,
                     z_space_dim=z_space_dim,
                     w_space_dim=w_space_dim,
                     label_size=label_size,
                     repeat_w=repeat_w,
                     image_channels=image_channels)
    Gs_state_dict = Gs.state_dict()
    official_tf_to_pth_var_mapping = {}
    for name in Gs_vars.keys():
        for pattern, fn in G_PTH_TO_TF_VAR_MAPPING_PATTERN.items():
            match = re.fullmatch(pattern, name)
            if match:
                if fn is not None:
                    official_tf_to_pth_var_mapping[fn(*match.groups())] = name
                break
    for dst_var_name, tf_var_name in Gs.pth_to_tf_var_mapping.items():
        assert tf_var_name in official_tf_to_pth_var_mapping
        assert dst_var_name in Gs_state_dict
        src_var_name = official_tf_to_pth_var_mapping[tf_var_name]
        assert src_var_name in Gs_vars
        if verbose:
            print(f'  Converting `{src_var_name}` to `{dst_var_name}`.')
        var = Gs_vars[src_var_name].data
        if 'weight' in tf_var_name:
            if 'Conv0_up/weight' in tf_var_name:
                var = var.flip(2, 3)
            elif 'Skip' in tf_var_name:
                var = var.flip(2, 3)
        if 'bias' in tf_var_name:
            if 'mod_bias' in tf_var_name:
                var = var - 1
        if 'Const' in tf_var_name:
            var = var.unsqueeze(0)
        if 'noise' in tf_var_name and 'noise_' not in tf_var_name:
            var = var.unsqueeze(0).unsqueeze(0)
        Gs_state_dict[dst_var_name] = var
    print('Successfully converted!')
    print('--------------------')

    print('Converting source weights (D) to target ...')
    D_vars = dict(model['D'].named_parameters())
    D_vars.update(dict(model['D'].named_buffers()))
    D = build_model(gan_type=GAN_TYPE,
                    module='discriminator',
                    resolution=resolution,
                    label_size=label_size,
                    image_channels=image_channels)
    D_state_dict = D.state_dict()
    official_tf_to_pth_var_mapping = {}
    for name in D_vars.keys():
        for pattern, fn in D_PTH_TO_TF_VAR_MAPPING_PATTERN.items():
            match = re.fullmatch(pattern, name)
            if match:
                if fn is not None:
                    official_tf_to_pth_var_mapping[fn(*match.groups())] = name
                break
    for dst_var_name, tf_var_name in D.pth_to_tf_var_mapping.items():
        assert tf_var_name in official_tf_to_pth_var_mapping
        assert dst_var_name in D_state_dict
        src_var_name = official_tf_to_pth_var_mapping[tf_var_name]
        assert src_var_name in D_vars
        if verbose:
            print(f'  Converting `{src_var_name}` to `{dst_var_name}`.')
        var = D_vars[src_var_name].data
        D_state_dict[dst_var_name] = var
    print('Successfully converted!')
    print('--------------------')

    print(f'Saving target weights to `{dst_weight_path}` ...')
    state_dict = {
        'generator': G_state_dict,
        'discriminator': D_state_dict,
        'generator_smooth': Gs_state_dict,
    }
    torch.save(state_dict, dst_weight_path)
    print('Successfully saved!')
    print('--------------------')

    # Start testing if needed.
    if test_num <= 0:
        warnings.warn('Skipping the test on the converted weights!')
        return
    if save_test_image:
        html = HtmlPageVisualizer(num_rows=test_num, num_cols=3)
        html.set_headers(['Index', 'Before Conversion', 'After Conversion'])
        for i in range(test_num):
            html.set_cell(i, 0, text=f'{i}')
    print('Testing conversion results ...')
    G.load_state_dict(G_state_dict)
    D.load_state_dict(D_state_dict)
    Gs.load_state_dict(Gs_state_dict)
    G.eval().cuda()
    D.eval().cuda()
    Gs.eval().cuda()
    model['G'].eval().cuda()
    model['D'].eval().cuda()
    model['G_ema'].eval().cuda()
    gs_distance = 0.0
    dg_distance = 0.0
    for i in tqdm(range(test_num)):
        # Test Gs(z): compare the official EMA generator with the converted
        # one on the same latent code (and label, if conditional).
        code = np.random.randn(1, z_space_dim)
        code = torch.from_numpy(code).type(torch.FloatTensor).cuda()
        if label_size:
            label_id = np.random.randint(label_size)
            label = np.zeros((1, label_size), np.float32)
            label[0, label_id] = 1.0
            label = torch.from_numpy(label).type(torch.FloatTensor).cuda()
        else:
            label_id = 0
            label = None
        src_output = model['G_ema'](code,
                                    label,
                                    truncation_psi=TRUNC_PSI,
                                    truncation_cutoff=TRUNC_LAYERS,
                                    noise_mode=NOISE_MODE)
        src_output = src_output.detach().cpu().numpy()
        dst_output = Gs(code,
                        label=label,
                        trunc_psi=TRUNC_PSI,
                        trunc_layers=TRUNC_LAYERS,
                        randomize_noise=RANDOMIZE_NOISE)['image']
        dst_output = dst_output.detach().cpu().numpy()
        distance = np.average(np.abs(src_output - dst_output))
        if verbose:
            print(f'  Test {i:03d}: Gs distance {distance:.6e}.')
        gs_distance += distance
        if save_test_image:
            html.set_cell(i, 1, image=postprocess_image(src_output)[0])
            html.set_cell(i, 2, image=postprocess_image(dst_output)[0])

        # Test D(G(z)): compare the discriminator scores on images produced
        # by the official and converted generators.
        code = np.random.randn(1, z_space_dim)
        code = torch.from_numpy(code).type(torch.FloatTensor).cuda()
        if label_size:
            label_id = np.random.randint(label_size)
            label = np.zeros((1, label_size), np.float32)
            label[0, label_id] = 1.0
            label = torch.from_numpy(label).type(torch.FloatTensor).cuda()
        else:
            label_id = 0
            label = None
        src_image = model['G'](code,
                               label,
                               truncation_psi=TRUNC_PSI,
                               truncation_cutoff=TRUNC_LAYERS,
                               noise_mode=NOISE_MODE)
        src_output = model['D'](src_image, label)
        src_output = src_output.detach().cpu().numpy()
        dst_image = G(code,
                      label=label,
                      trunc_psi=TRUNC_PSI,
                      trunc_layers=TRUNC_LAYERS,
                      randomize_noise=RANDOMIZE_NOISE)['image']
        dst_output = D(dst_image, label)
        dst_output = dst_output.detach().cpu().numpy()
        distance = np.average(np.abs(src_output - dst_output))
        if verbose:
            print(f'  Test {i:03d}: D(G) distance {distance:.6e}.')
        dg_distance += distance

    print(f'Average Gs distance is {gs_distance / test_num:.6e}.')
    print(f'Average D(G) distance is {dg_distance / test_num:.6e}.')
    print('========================================')

    if save_test_image:
        html.save(f'{dst_weight_path}.conversion_test.html')
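

# A minimal command-line entry point, added here as a sketch rather than as
# part of the original converter. It assumes the script is run from the
# repository root so that the `models` and `utils` packages are importable,
# and it requires a CUDA device for the consistency test.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert StyleGAN2-ADA-PyTorch weights to this '
                    "repository's format.")
    parser.add_argument('src_weight_path',
                        help='Path to the official `.pkl` checkpoint.')
    parser.add_argument('dst_weight_path',
                        help='Path to save the converted `.pth` weights.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of samples for the conversion test.')
    parser.add_argument('--save_test_image', action='store_true',
                        help='Save side-by-side test images as an HTML page.')
    parser.add_argument('--verbose', action='store_true',
                        help='Print per-variable and per-test logs.')
    args = parser.parse_args()
    convert_stylegan2ada_pth_weight(args.src_weight_path,
                                    args.dst_weight_path,
                                    test_num=args.test_num,
                                    save_test_image=args.save_test_image,
                                    verbose=args.verbose)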