File size: 6,137 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# python3.7
"""A simple tool to synthesize images with pre-trained models."""
import os
import argparse
import subprocess
from tqdm import tqdm
import numpy as np
import torch
from models import MODEL_ZOO
from models import build_generator
from utils.misc import bool_parser
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
def parse_args():
    """Parses the command-line arguments of the synthesis tool.

    The options are declared in a table and registered in order, so the
    generated help message lists them in the same order as the table.

    Returns:
        The `argparse.Namespace` obtained from parsing the command line, with
            attributes `model_name`, `save_dir`, `num`, `batch_size`,
            `generate_html`, `save_raw_synthesis`, `seed`, `trunc_psi`,
            `trunc_layers` and `randomize_noise`.
    """
    # Each entry is `(name_or_flag, keyword arguments of add_argument)`.
    option_specs = [
        ('model_name', dict(
            type=str,
            help='Name to the pre-trained model.')),
        ('--save_dir', dict(
            type=str, default=None,
            help='Directory to save the results. If not specified, '
                 'the results will be saved to '
                 '`work_dirs/synthesis/` by default. '
                 '(default: %(default)s)')),
        ('--num', dict(
            type=int, default=100,
            help='Number of samples to synthesize. '
                 '(default: %(default)s)')),
        ('--batch_size', dict(
            type=int, default=1,
            help='Batch size. (default: %(default)s)')),
        ('--generate_html', dict(
            type=bool_parser, default=True,
            help='Whether to use HTML page to visualize the '
                 'synthesized results. (default: %(default)s)')),
        ('--save_raw_synthesis', dict(
            type=bool_parser, default=False,
            help='Whether to save raw synthesis. '
                 '(default: %(default)s)')),
        ('--seed', dict(
            type=int, default=0,
            help='Seed for sampling. (default: %(default)s)')),
        ('--trunc_psi', dict(
            type=float, default=0.7,
            help='Psi factor used for truncation. This is '
                 'particularly applicable to StyleGAN (v1/v2). '
                 '(default: %(default)s)')),
        ('--trunc_layers', dict(
            type=int, default=8,
            help='Number of layers to perform truncation. This is '
                 'particularly applicable to StyleGAN (v1/v2). '
                 '(default: %(default)s)')),
        ('--randomize_noise', dict(
            type=bool_parser, default=False,
            help='Whether to randomize the layer-wise noise. This '
                 'is particularly applicable to StyleGAN (v1/v2). '
                 '(default: %(default)s)')),
    ]

    cli = argparse.ArgumentParser(
        description='Synthesize images with pre-trained models.')
    for name_or_flag, options in option_specs:
        cli.add_argument(name_or_flag, **options)
    return cli.parse_args()
def _download_checkpoint(url, checkpoint_path):
    """Downloads a checkpoint with `wget` without leaving a partial file.

    Args:
        url: URL to download the checkpoint from.
        checkpoint_path: Path the checkpoint is saved to on success.

    Raises:
        SystemExit: If `wget` exits with a non-zero code.
    """
    # `wget -O` creates the output file even when the download fails. Fetch
    # into a temporary sibling file first, so that a failed download never
    # leaves a truncated file at `checkpoint_path`, which a later run would
    # otherwise mistake for a valid checkpoint.
    temp_path = checkpoint_path + '.tmp'
    return_code = subprocess.call(['wget', '--quiet', '-O', temp_path, url])
    if return_code != 0:
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise SystemExit(f'Failed to download checkpoint from `{url}` '
                         f'(wget exit code: {return_code})!')
    os.replace(temp_path, checkpoint_path)


def main():
    """Main function.

    Parses the command line, builds the requested generator, loads (and
    downloads if missing) its checkpoint, then synthesizes `args.num` images
    batch by batch. Depending on the flags, images are saved individually
    under `<work_dir>/<job_name>/` and/or collected into
    `<work_dir>/<job_name>.html`.

    Returns silently when there is nothing to do (`--num` is not positive, or
    both `--save_raw_synthesis` and `--generate_html` are disabled).

    Raises:
        SystemExit: If `--batch_size` is not positive, if the model is not
            registered in `MODEL_ZOO`, or if the checkpoint download fails.
    """
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return
    # A zero batch size makes `range()` raise an obscure error and a negative
    # one makes the synthesis loop silently empty, so reject both up front
    # (before spending time on building the model).
    if args.batch_size <= 0:
        raise SystemExit(f'Batch size should be positive, but '
                         f'`{args.batch_size}` is received!')

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generator and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print('Finish building generator.')

    # Load pre-trained weights.
    # NOTE(review): the checkpoint directory is a machine-specific absolute
    # path; confirm whether it should become configurable.
    checkpoint_dir = '/import/nobackup_mmv_ioannisp/jo001/genforce_models'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f' Downloading checkpoint from `{url}` ...')
        _download_checkpoint(url, checkpoint_path)
        print(' Finish downloading checkpoint.')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Use the `generator_smooth` weights whenever the checkpoint has them,
    # and fall back to the `generator` weights otherwise.
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_idx in tqdm(range(0, args.num, args.batch_size)):
        # Global sample indices of this batch; the last batch may be smaller.
        sub_indices = range(batch_idx,
                            min(batch_idx + args.batch_size, args.num))
        # Latent codes are drawn on CPU and then moved to GPU.
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
        images = postprocess_image(images.detach().cpu().numpy())
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                # Place the sample in the grid in row-major order.
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')
# Run the synthesis only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
|