import os
from PIL import Image
import numpy as np
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util


class SwapVisualizationEvaluator(BaseEvaluator):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument("--swap_num_columns", type=int, default=4,
                            help="number of images shown along each axis of a swap visualization grid. The default of 4 results in a 4x4 swapping grid, plus an additional row and column showing the original images.")
        parser.add_argument("--swap_num_images", type=int, default=16,
                            help="total number of images to perform swapping on. In the end, (swap_num_images / swap_num_columns) grids, rounded up, will be saved to disk.")
        return parser

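    def gather_images(self, dataset):
        # Pull single-image tensors (batch dimension 1) from the dataset iterator
        # until max(swap_num_columns, num_gpus) images are collected for one grid.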
        all_images = []
        num_images_to_gather = max(self.opt.swap_num_columns, self.opt.num_gpus)
        exhausted = False
        while len(all_images) < num_images_to_gather:
            try:
                data = next(dataset)
            except StopIteration:
                print("Exhausted the dataset at %s" % (self.opt.dataroot))
                exhausted = True
                break
            for i in range(data["real_A"].size(0)):
                all_images.append(data["real_A"][i:i+1])
                if "real_B" in data and len(all_images) < num_images_to_gather:
                    all_images.append(data["real_B"][i:i+1])
                if len(all_images) >= num_images_to_gather:
                    break
        if len(all_images) == 0:
            # Nothing left to visualize. evaluate() unpacks exactly two values.
            return None, True
        return all_images, exhausted

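    def generate_mix_grid(self, model, images):
        # Encode every image into a spatial code (sp) and a global code (gl).
        # Each image is expanded to num_gpus copies before encoding, and only
        # the codes of the first copy are kept.
        sps, gls = [], []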
        for image in images:
            assert image.size(0) == 1
            sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode")
            sp = sp[:1]
            gl = gl[:1]
            sps.append(sp)
            gls.append(gl)
        gl = torch.cat(gls, dim=0)

        def put_img(img, canvas, row, col):
            # Paste an h x w x 3 image, centered, into the load_size x load_size cell at (row, col).
            h, w = img.shape[0], img.shape[1]
            start_x = int(self.opt.load_size * col + (self.opt.load_size - w) * 0.5)
            start_y = int(self.opt.load_size * row + (self.opt.load_size - h) * 0.5)
            canvas[start_y:start_y + h, start_x:start_x + w] = img

        # One extra row and column hold the original images; the top-left cell stays empty.
        grid_w = self.opt.load_size * (gl.size(0) + 1)
        grid_h = self.opt.load_size * (gl.size(0) + 1)
        grid_img = np.ones((grid_h, grid_w, 3), dtype=np.uint8)
        # First row and first column: the original images.
        for i, image in enumerate(images):
            image_np = util.tensor2im(image, tile=False)[0]
            put_img(image_np, grid_img, 0, i + 1)
            put_img(image_np, grid_img, i + 1, 0)

        # Cell (i + 1, j + 1) decodes the spatial code of image i with the global code of image j.
        for i, sp in enumerate(sps):
            sp_for_current_row = sp.repeat(gl.size(0), 1, 1, 1)
            mix_row = model(sp_for_current_row, gl, command="decode")
            mix_row = util.tensor2im(mix_row, tile=False)
            for j, mix in enumerate(mix_row):
                put_img(mix, grid_img, i + 1, j + 1)

        final_grid = Image.fromarray(grid_img)
        return final_grid

    def evaluate(self, model, dataset, nsteps):
        nsteps = self.opt.resume_iter if nsteps is None else str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "Swap Visualization of %s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        webpage = util.HTML(savedir, webpage_title)
        # Number of grids to write: ceil(swap_num_images / images per grid).
        num_repeats = int(np.ceil(self.opt.swap_num_images / max(self.opt.swap_num_columns, self.opt.num_gpus)))
        for i in range(num_repeats):
            images, should_break = self.gather_images(dataset)
            if images is None:
                break
            mix_grid = self.generate_mix_grid(model, images)
            webpage.add_images([mix_grid], ["%04d.png" % i])
            if should_break:
                break
        webpage.save()
        return {}
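

# Illustrative sketch: hypothetical helpers, not part of swapae and not called
# by SwapVisualizationEvaluator. Assuming every image fits inside a
# load_size x load_size cell and the dataset is not exhausted early, they
# restate the layout arithmetic above: how many grids evaluate() writes, the
# canvas shape built in generate_mix_grid(), and the offset put_img() uses to
# center an image in its cell.
def _sketch_num_grids(swap_num_images, swap_num_columns, num_gpus):
    # Mirrors num_repeats in evaluate(): ceil(total images / images per grid),
    # using integer ceiling division instead of np.ceil on a float.
    images_per_grid = max(swap_num_columns, num_gpus)
    return -(-swap_num_images // images_per_grid)


def _sketch_canvas_shape(num_images, load_size):
    # One extra row and one extra column hold the original images.
    side = load_size * (num_images + 1)
    return (side, side, 3)


def _sketch_cell_offset(row, col, h, w, load_size):
    # Top-left corner (y, x) that centers an h x w image in cell (row, col),
    # with the same int() truncation as put_img().
    start_x = int(load_size * col + (load_size - w) * 0.5)
    start_y = int(load_size * row + (load_size - h) * 0.5)
    return start_y, start_x


# Example: with the defaults (swap_num_images=16, swap_num_columns=4) on one
# GPU, _sketch_num_grids(16, 4, 1) == 4, and with a hypothetical load_size of
# 256 each canvas has shape _sketch_canvas_shape(4, 256) == (1280, 1280, 3).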