import argparse

import torch
from torchvision import utils
from tqdm import tqdm

from model import Generator
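
# Sample images from a trained StyleGAN2 generator checkpoint.
# A sketch of the intended invocation (assumption: this script is saved as
# generate.py and the ./sample output directory already exists):
#
#   python generate.py --ckpt stylegan2-ffhq-config-f.pt --pics 20 --truncation 0.7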


def generate(args, g_ema, device, mean_latent):
    with torch.no_grad():
        g_ema.eval()

        for i in tqdm(range(args.pics)):
            # Draw a batch of latent codes from the standard normal prior
            sample_z = torch.randn(args.sample, args.latent, device=device)

            # Synthesize images with the EMA generator; when truncation < 1 the
            # truncation trick interpolates each latent toward mean_latent
            sample, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )

            # Rescale pixel values from [-1, 1] to [0, 1] and write one image
            # grid per iteration (newer torchvision releases rename the `range`
            # keyword to `value_range`)
            utils.save_image(
                sample,
                f'sample/{str(i).zfill(6)}.png',
                nrow=1,
                normalize=True,
                range=(-1, 1),
            )


if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    # Output resolution, samples per saved grid, number of grids, truncation
    # factor, latents used to estimate the mean latent, checkpoint path, and
    # generator channel-width multiplier
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=20)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)

    args = parser.parse_args()

    # Latent (style) dimensionality and mapping-network depth of the generator
    args.latent = 512
    args.n_mlp = 8

    # Build the generator and load the exponential-moving-average (EMA) weights
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint['g_ema'])

    # The truncation trick needs a mean latent vector; estimate it from
    # args.truncation_mean random samples only when truncation is enabled
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)