Last commit not found
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
"""Generate images using pretrained network pickle.""" | |
import argparse | |
import os | |
import pickle | |
import re | |
import numpy as np | |
import PIL.Image | |
import dnnlib | |
import dnnlib.tflib as tflib | |
#---------------------------------------------------------------------------- | |
def generate_images(network_pkl, seeds, truncation_psi, outdir, class_idx, dlatents_npz): | |
tflib.init_tf() | |
print('Loading networks from "%s"...' % network_pkl) | |
with dnnlib.util.open_url(network_pkl) as fp: | |
_G, _D, Gs = pickle.load(fp) | |
os.makedirs(outdir, exist_ok=True) | |
# Render images for a given dlatent vector. | |
if dlatents_npz is not None: | |
print(f'Generating images from dlatents file "{dlatents_npz}"') | |
dlatents = np.load(dlatents_npz)['dlatents'] | |
assert dlatents.shape[1:] == (18, 512) # [N, 18, 512] | |
imgs = Gs.components.synthesis.run(dlatents, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)) | |
for i, img in enumerate(imgs): | |
fname = f'{outdir}/dlatent{i:02d}.png' | |
print (f'Saved {fname}') | |
PIL.Image.fromarray(img, 'RGB').save(fname) | |
return | |
# Render images for dlatents initialized from random seeds. | |
Gs_kwargs = { | |
'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), | |
'randomize_noise': False | |
} | |
if truncation_psi is not None: | |
Gs_kwargs['truncation_psi'] = truncation_psi | |
noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')] | |
label = np.zeros([1] + Gs.input_shapes[1][1:]) | |
if class_idx is not None: | |
label[:, class_idx] = 1 | |
for seed_idx, seed in enumerate(seeds): | |
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) | |
rnd = np.random.RandomState(seed) | |
z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component] | |
tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width] | |
images = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel] | |
PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png') | |
#---------------------------------------------------------------------------- | |
def _parse_num_range(s): | |
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' | |
range_re = re.compile(r'^(\d+)-(\d+)$') | |
m = range_re.match(s) | |
if m: | |
return list(range(int(m.group(1)), int(m.group(2))+1)) | |
vals = s.split(',') | |
return [int(x) for x in vals] | |
#---------------------------------------------------------------------------- | |
_examples = '''examples: | |
# Generate curated MetFaces images without truncation (Fig.10 left) | |
python %(prog)s --outdir=out --trunc=1 --seeds=85,265,297,849 \\ | |
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl | |
# Generate uncurated MetFaces images with truncation (Fig.12 upper left) | |
python %(prog)s --outdir=out --trunc=0.7 --seeds=600-605 \\ | |
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl | |
# Generate class conditional CIFAR-10 images (Fig.17 left, Car) | |
python %(prog)s --outdir=out --trunc=1 --seeds=0-35 --class=1 \\ | |
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl | |
# Render image from projected latent vector | |
python %(prog)s --outdir=out --dlatents=out/dlatents.npz \\ | |
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl | |
''' | |
#---------------------------------------------------------------------------- | |
def main(): | |
parser = argparse.ArgumentParser( | |
description='Generate images using pretrained network pickle.', | |
epilog=_examples, | |
formatter_class=argparse.RawDescriptionHelpFormatter | |
) | |
parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) | |
g = parser.add_mutually_exclusive_group(required=True) | |
g.add_argument('--seeds', type=_parse_num_range, help='List of random seeds') | |
g.add_argument('--dlatents', dest='dlatents_npz', help='Generate images for saved dlatents') | |
parser.add_argument('--trunc', dest='truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5) | |
parser.add_argument('--class', dest='class_idx', type=int, help='Class label (default: unconditional)') | |
parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR') | |
args = parser.parse_args() | |
generate_images(**vars(args)) | |
#---------------------------------------------------------------------------- | |
if __name__ == "__main__": | |
main() | |
#---------------------------------------------------------------------------- | |