File size: 4,898 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import glob
import torchvision.transforms as transforms
import os
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
from PIL import Image


class InputDataset(torch.utils.data.Dataset):
    """Paired structure/style images used by the swap evaluator.

    Reads ``<dataroot>/input_structure/*.png`` and ``<dataroot>/input_style/*.png``.
    After sorting, the i-th structure image is paired with the i-th style
    image. A pair is valid when the style filename equals the structure
    filename with every "structure" substring replaced by "style".

    Each item is a dict with:
        'structure', 'style': RGB images passed through ToTensor and
            Normalize(mean=0.5, std=0.5) per channel.
        'path': the structure image's path.
    """

    def __init__(self, dataroot):
        structure_images = sorted(glob.glob(os.path.join(dataroot, "input_structure", "*.png")))
        style_images = sorted(glob.glob(os.path.join(dataroot, "input_style", "*.png")))

        # Compare counts before pairing: zip() stops at the shorter list, so a
        # count mismatch would otherwise show up as a misleading name mismatch.
        assert len(structure_images) == len(style_images), \
            "found %d structure images but %d style images at %s" % (
                len(structure_images), len(style_images), dataroot)
        for structure_path, style_path in zip(structure_images, style_images):
            assert self._paired_style_path(dataroot, structure_path) == style_path, \
                "%s and %s do not match" % (structure_path, style_path)

        print("found %d images at %s" % (len(structure_images), dataroot))

        self.structure_images = structure_images
        self.style_images = style_images
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )

    @staticmethod
    def _paired_style_path(dataroot, structure_path):
        """Return the style-image path expected to pair with ``structure_path``.

        The "structure" -> "style" substitution is applied only to the
        filename. ``dataroot`` is left untouched, so a root directory whose
        name contains "structure" does not break the pairing check.
        """
        filename = os.path.basename(structure_path).replace("structure", "style")
        return os.path.join(dataroot, "input_style", filename)

    def _load_image(self, path):
        """Open ``path``, convert it to RGB, and apply ``self.transform``.

        The file handle is closed before returning.
        """
        with Image.open(path) as image:
            return self.transform(image.convert('RGB'))

    def __len__(self):
        return len(self.structure_images)

    def __getitem__(self, idx):
        return {'structure': self._load_image(self.structure_images[idx]),
                'style': self._load_image(self.style_images[idx]),
                'path': self.structure_images[idx]}


class SwapGenerationFromArrangedResultEvaluator(BaseEvaluator):
    """ Given two directories containing input structure and style (texture)
    images, respectively, generate reconstructed and swapped images.
    The inputs are read from ``<opt.dataroot>/input_structure`` and
    ``<opt.dataroot>/input_style``. The two directories should contain
    matching image filenames (InputDataset defines the exact pairing rule).
    It differs from StructureStyleGridGenerationEvaluator, which creates
    N^2 outputs (i.e. swapping of all possible pairs between the structure and
    style images).
    """
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add no options; return ``parser`` unchanged."""
        return parser

    def image_save_dir(self, nsteps):
        """Return ``<output_dir()>/<target_phase>_<nsteps>/images``."""
        return os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps), "images")

    def create_webpage(self, nsteps):
        """Create ``self.webpage`` in ``<output_dir()>/<target_phase>_<nsteps>``.

        Args:
            nsteps: None means use ``self.opt.resume_iter``. An int is shown
                rounded to thousands (e.g. 25000 -> "25k"). Any other value
                is used as-is in the directory name and page title.
        """
        if nsteps is None:
            nsteps = self.opt.resume_iter
        elif isinstance(nsteps, int):
            nsteps = str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "%s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        self.webpage = util.HTML(savedir, webpage_title)

    def add_to_webpage(self, images, filenames, tile=1):
        """Convert ``images`` to PIL images and append them to ``self.webpage``.

        The page is not written to disk here; call ``self.webpage.save()``.

        Args:
            images: tensors, or lists of tensors. A list is stacked along a
                new first dim and then flattened into one batch; presumably
                each tensor is (B, C, H, W) — TODO confirm.
            filenames: one name per entry of ``images``.
            tile: passed to ``util.tensor2im``, capped at the batch size.
        """
        converted_images = []
        for image in images:
            if isinstance(image, list):
                image = torch.stack(image, dim=0).flatten(0, 1)
            image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile)))
            converted_images.append(image)

        self.webpage.add_images(converted_images,
                                filenames)
        print("saved %s" % str(filenames))

    def set_num_test_images(self, num_images):
        """Limit ``evaluate`` to the first ``num_images`` pairs (None = no limit)."""
        self.num_test_images = num_images

    def evaluate(self, model, dataset, nsteps=None):
        """Reconstruct and swap every structure/style pair and save a webpage.

        Args:
            model: called as ``model(image, command="encode")`` -> ``(sp, gl)``
                and as ``model(sp, gl, command="decode")`` -> image tensor.
            dataset: unused; the inputs come from ``self.opt.dataroot``.
            nsteps: training iteration used to name the output folder
                (see ``create_webpage``).

        Returns:
            An empty dict; this evaluator reports no metrics.
        """
        loader = torch.utils.data.DataLoader(
            InputDataset(self.opt.dataroot),
            batch_size=1,
            shuffle=False, drop_last=False, num_workers=0
        )

        # Read the limit rather than resetting it. Resetting it here would
        # discard any value set via set_num_test_images() before this call.
        limit = getattr(self, "num_test_images", None)
        self.create_webpage(nsteps)
        try:
            # Inference only: skip building autograd graphs.
            with torch.no_grad():
                for image_num, data_i in enumerate(loader):
                    if limit is not None and image_num >= limit:
                        break
                    structure = data_i["structure"].cuda()
                    style = data_i["style"].cuda()
                    path = os.path.basename(data_i["path"][0])

                    # Reconstruction: decode the structure image's own codes.
                    sp, gl = model(structure, command="encode")
                    rec = model(sp, gl, command="decode")

                    # Swap: sp from the structure image, gl from the style image.
                    _, gl = model(style, command="encode")
                    swapped = model(sp, gl, command="decode")

                    self.add_to_webpage([structure, style, rec, swapped],
                                        ["%s_structure.png" % (path),
                                         "%s_style.png" % (path),
                                         "%s_rec.png" % (path),
                                         "%s_swap.png" % (path)],
                                        tile=1)
        finally:
            # Write the page once, rather than rewriting the whole HTML after
            # every pair. This also runs for an empty input set and keeps the
            # partial results if an error occurs part-way through.
            self.webpage.save()
        return {}