import os
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
from PIL import Image


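class StructureStyleGridGenerationEvaluator(BaseEvaluator):
    """Generate structure/style swapping images and save them to disk.

    The results form a grid on an HTML page. The first row shows the style
    images, and each following row shows one structure image followed by its
    swaps with every style image.
    """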
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def create_webpage(self, nsteps):
        # Fall back to the iteration the model was resumed from.
        nsteps = self.opt.resume_iter if nsteps is None else nsteps
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "%s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        self.webpage = util.HTML(savedir, webpage_title)

    def add_to_webpage(self, images, filenames, tile=1):
        # Convert each tensor (or list of tensors) into a PIL image and
        # add the whole set to the HTML page as one row.
        converted_images = []
        for image in images:
            if isinstance(image, list):
                # Stack the list along a new axis, then merge it into the batch axis.
                image = torch.stack(image, dim=0).flatten(0, 1)
            image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile)))
            converted_images.append(image)

        self.webpage.add_images(converted_images, filenames)
        print("saved %s" % str(filenames))

    @torch.no_grad()  # inference only: do not build autograd graphs
    def evaluate(self, model, dataset, nsteps=None):
        self.create_webpage(nsteps)

        # Sort the input images by folder: images whose path contains
        # "/structure/" supply structure codes, all others supply style codes.
        # Both dicts are keyed by file name without extension.
        structure_images, style_images = {}, {}
        for data_i in dataset:
            bs = data_i["real_A"].size(0)

            for j in range(bs):
                image = data_i["real_A"][j:j+1]
                path = data_i["path_A"][j]
                imagename = os.path.splitext(os.path.basename(path))[0]
                if "/structure/" in path:
                    structure_images[imagename] = image
                else:
                    style_images[imagename] = image

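        # Encode each style image once and keep its global (style) code,
        # the second output of the "encode" command.
        gls = []
        style_paths = list(style_images.keys())
        for style_path in style_paths:
            style_image = style_images[style_path].cuda()
            gls.append(model(style_image, command="encode")[1])

        # Encode each structure image once and keep its spatial (structure)
        # code, the first output of the "encode" command.
        sps = []
        structure_paths = list(structure_images.keys())
        for structure_path in structure_paths:
            structure_image = structure_images[structure_path].cuda()
            sps.append(model(structure_image, command="encode")[0])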

        # Top row: a blank (all-ones, rendered white) corner cell followed by
        # the style images.
        blank_image = torch.ones_like(style_images[style_paths[0]])
        self.add_to_webpage([blank_image] + [style_images[style_path] for style_path in style_paths],
                            ["blank.png"] + [style_path + ".png" for style_path in style_paths],
                            tile=1)

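        # Swapping: one row per structure image, showing the structure image
        # itself and then its structure code decoded with each style code.
        for i, structure_path in enumerate(structure_paths):
            structure_image = structure_images[structure_path]
            swaps = []
            filenames = []
            for j, style_path in enumerate(style_paths):
                swaps.append(model(sps[i], gls[j], command="decode"))
                filenames.append(structure_path + "_" + style_path + ".png")
            self.add_to_webpage([structure_image] + swaps,
                                [structure_path + ".png"] + filenames,
                                tile=1)

            # Rewrite the HTML page after every row so progress is visible
            # while the grid is being generated.
            self.webpage.save()
        return {}  # this evaluator reports no scalar metrics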