|
|
|
"""A simple tool to synthesize images with pre-trained models.""" |
|
|
|
import os |
|
import argparse |
|
import subprocess |
|
from tqdm import tqdm |
|
import numpy as np |
|
|
|
import torch |
|
|
|
from models import MODEL_ZOO |
|
from models import build_generator |
|
from utils.misc import bool_parser |
|
from utils.visualizer import HtmlPageVisualizer |
|
from utils.visualizer import postprocess_image |
|
from utils.visualizer import save_image |
|
|
|
|
|
def parse_args():
    """Builds the command-line interface and parses the arguments.

    The interface consists of one positional argument (`model_name`) plus a
    set of optional settings controlling where results are saved, how many
    samples are synthesized, the batch size, the output formats, the sampling
    seed, and the truncation / noise options forwarded to the generator.

    Returns:
        The namespace returned by `argparse.ArgumentParser.parse_args()`.
    """
    # Fragments shared by several help messages.
    default_hint = '(default: %(default)s)'
    stylegan_hint = 'This is particularly applicable to StyleGAN (v1/v2). '

    # Optional arguments, registered in exactly this order.
    options = (
        ('--save_dir', dict(
            type=str, default=None,
            help='Directory to save the results. If not specified, the '
                 'results will be saved to `work_dirs/synthesis/` by '
                 'default. ' + default_hint)),
        ('--num', dict(
            type=int, default=100,
            help='Number of samples to synthesize. ' + default_hint)),
        ('--batch_size', dict(
            type=int, default=1,
            help='Batch size. ' + default_hint)),
        ('--generate_html', dict(
            type=bool_parser, default=True,
            help='Whether to use HTML page to visualize the synthesized '
                 'results. ' + default_hint)),
        ('--save_raw_synthesis', dict(
            type=bool_parser, default=False,
            help='Whether to save raw synthesis. ' + default_hint)),
        ('--seed', dict(
            type=int, default=0,
            help='Seed for sampling. ' + default_hint)),
        ('--trunc_psi', dict(
            type=float, default=0.7,
            help='Psi factor used for truncation. '
                 + stylegan_hint + default_hint)),
        ('--trunc_layers', dict(
            type=int, default=8,
            help='Number of layers to perform truncation. '
                 + stylegan_hint + default_hint)),
        ('--randomize_noise', dict(
            type=bool_parser, default=False,
            help='Whether to randomize the layer-wise noise. '
                 + stylegan_hint + default_hint)),
    )

    cli = argparse.ArgumentParser(
        description='Synthesize images with pre-trained models.')
    cli.add_argument('model_name', type=str,
                     help='Name to the pre-trained model.')
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
def main():
    """Synthesizes images with a pre-trained generator.

    Workflow:
        1. Parse the command line and look the model up in `MODEL_ZOO`.
        2. Prepare the output directory.
        3. Build the generator and load its checkpoint, downloading the
           checkpoint with `wget` when it is not cached locally.
        4. Sample latent codes batch by batch, run the generator, and save
           the results as raw images and/or as an HTML page.

    The function returns without doing anything if `--num` is not positive
    or if both `--save_raw_synthesis` and `--generate_html` are disabled.

    Raises:
        SystemExit: If `--batch_size` is not positive, if the model is not
            registered in `MODEL_ZOO`, or if downloading the checkpoint
            fails.
    """
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return
    # Fail early: a zero step would make `range()` raise only after the model
    # has been built and loaded, and a negative step would silently yield no
    # batches at all.
    if args.batch_size <= 0:
        raise SystemExit(f'Batch size should be positive, but '
                         f'`{args.batch_size}` is received!')

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    # The URL is only needed for downloading, so it is removed before the
    # remaining entries are forwarded to `build_generator()`.
    url = model_config.pop('url')

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generation and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print('Finish building generator.')

    # Load pre-trained weights.
    # NOTE(review): machine-specific absolute path, kept unchanged to preserve
    # the existing cache location -- consider making it configurable.
    checkpoint_dir = '/import/nobackup_mmv_ioannisp/jo001/genforce_models'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f' Downloading checkpoint from `{url}` ...')
        # `wget -O` can leave an empty or truncated file behind when the
        # download fails. Download to a temporary name and only move it into
        # place on success, so that a broken file is never mistaken for a
        # cached checkpoint by a later run.
        partial_path = checkpoint_path + '.part'
        return_code = subprocess.call(
            ['wget', '--quiet', '-O', partial_path, url])
        if return_code != 0:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise SystemExit(f'Failed to download checkpoint from `{url}` '
                             f'(wget exit code: {return_code})!')
        os.replace(partial_path, checkpoint_path)
        print(' Finish downloading checkpoint.')
    # NOTE: `torch.load()` unpickles the file, hence only checkpoints from
    # trusted sources should be loaded.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Prefer the `generator_smooth` weights whenever the checkpoint has them.
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_start in tqdm(range(0, args.num, args.batch_size)):
        # The last batch may be smaller than `batch_size`.
        batch_end = min(batch_start + args.batch_size, args.num)
        sub_indices = range(batch_start, batch_end)
        # Codes are drawn on CPU and then moved to GPU, exactly as before, so
        # that a given seed keeps producing the same samples.
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
            images = postprocess_image(images.detach().cpu().numpy())
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                # Place the sample into the grid in row-major order.
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')
|
|
|
|
|
# Run the synthesis pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|