File size: 3,875 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
'''
A simple tool to generate a sample of output of a GAN.
'''

import torch, numpy, os, argparse, numbers, sys, shutil
from PIL import Image
from torch.utils.data import TensorDataset
from .zdataset import standard_z_sample
from . import pbar
from .autoeval import autoimport_eval
from .workerpool import WorkerBase, WorkerPool

def main():
    '''
    Command-line entry point: instantiates a generator model and saves
    a sample of its output images.

    Parses sys.argv, builds the model named by --model (optionally
    loading weights from --pthfile), requests --size z vectors from
    standard_z_sample using --seed, passes them to save_znum_images to
    be written under --outdir, and finally copies the lightbox page
    into that directory.  The model is moved to the GPU with .cuda(),
    so a CUDA device is needed.

    Exits with status 1 after printing usage when run with no
    arguments, and with an argparse error when --model is missing.

    Raises:
        IndexError: if the model contains no Conv2d, ConvTranspose2d
            or Linear module from which to infer the z size.
    '''
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    # --model is always needed to build the generator, so let argparse
    # reject its absence instead of failing later on a None constructor.
    parser.add_argument('--model', type=str, required=True,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    pbar.quiet(args.quiet)

    # Instantiate the model
    # NOTE(review): autoimport_eval presumably evaluates the --model
    # string as Python code; only pass trusted input - confirm.
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        # torch.load unpickles the file: only load trusted checkpoints.
        data = torch.load(args.pthfile)
        # Accept either a bare state dict or one nested under 'state_dict'.
        if 'state_dict' in data:
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # Examine first conv in model to determine input feature size.
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    first_layer = next(
            (c for c in model.modules()
                if isinstance(c, conv_types + (torch.nn.Linear,))),
            None)
    if first_layer is None:
        raise IndexError(
                'model has no Conv2d, ConvTranspose2d or Linear layer; '
                'cannot determine the size of z')
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, conv_types):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    model.cuda()

    # Get the sample of z vectors
    indexes = torch.arange(args.size)
    z_sample = standard_z_sample(args.size, z_channels, seed=args.seed)
    # Append trailing 1x1 spatial dims when the model starts with a conv.
    z_sample = z_sample.view(tuple(z_sample.shape) + spatialdims)

    save_znum_images(args.outdir, model, z_sample, indexes)
    copy_lightbox_to(args.outdir)



def save_znum_images(dirname, model, z_sample, indexes,
        name_template="image_{}.png", lightbox=False, batch_size=100):
    '''
    Runs model over z_sample in batches on the GPU and queues every
    output image to be saved as a file inside dirname.

    dirname is created if it does not exist.  Each file is named by
    formatting name_template with the image number: the position of
    the z vector within z_sample, or indexes[position] when indexes is
    not None.  Files are written through a WorkerPool of
    SaveImageWorker, and the call returns after the pool is joined.

    The lightbox argument is accepted but not used by this function.
    '''
    os.makedirs(dirname, exist_ok=True)
    with torch.no_grad():
        # Now generate images
        loader = torch.utils.data.DataLoader(
                TensorDataset(z_sample), batch_size=batch_size,
                num_workers=2, pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        batches = pbar(loader, desc='Saving images')
        for batch_num, [z] in enumerate(batches):
            offset = batch_num * batch_size
            output = model(z.cuda())
            # Affine map sending -1 to 0 and +1 to 255, clamped to byte
            # range; dim 1 is then moved last before copying to the cpu.
            scaled = (output + 1) / 2 * 255
            im = scaled.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu()
            for i, pixels in enumerate(im):
                number = offset + i
                if indexes is not None:
                    number = indexes[number].item()
                filename = os.path.join(
                        dirname, name_template.format(number))
                saver.add(pixels.numpy(), filename)
    saver.join()

def copy_lightbox_to(dirname):
    '''
    Copies the lightbox.html file located in this module's directory
    into dirname, naming the copy +lightbox.html.
    '''
    module_dir = os.path.dirname(__file__)
    # Joining with the cwd resolves a relative __file__ before realpath.
    srcdir = os.path.realpath(os.path.join(os.getcwd(), module_dir))
    source = os.path.join(srcdir, 'lightbox.html')
    target = os.path.join(dirname, '+lightbox.html')
    shutil.copy(source, target)

class SaveImageWorker(WorkerBase):
    '''
    Worker whose unit of work is saving one image array to a file.
    '''
    def work(self, data, filename):
        '''
        Converts the array data to a PIL image and saves it at filename
        with optimize=True and quality=100.
        '''
        picture = Image.fromarray(data)
        picture.save(filename, optimize=True, quality=100)

# Script entry point.  This module uses relative imports, so it has to
# be run as a module inside its package (python -m <package>.<module>)
# rather than as a bare file.
if __name__ == '__main__':
    main()