File size: 6,112 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# python3.7
"""Contains the base class for Encoder (GAN Inversion) runner."""

import os
import shutil

import torch
import torch.distributed as dist

from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import get_grid_shape
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from .base_runner import BaseRunner

__all__ = ['BaseEncoderRunner']


class BaseEncoderRunner(BaseRunner):
    """Defines the base class for Encoder (GAN inversion) runners.

    On top of `BaseRunner`, this class checks that an encoder, a generator and
    a discriminator have been built, caches the per-mode keyword arguments of
    the generator and the discriminator, and provides `synthesize()`, which
    reconstructs validation images through the encoder and the generator.

    Derived classes must implement `train_step()`.
    """

    def __init__(self, config, logger):
        """Initializes the runner with the given config and logger."""
        super().__init__(config, logger)
        # Not used anywhere in this class; presumably a slot for derived
        # classes that evaluate metrics -- TODO confirm.
        self.inception_model = None

    def build_models(self):
        """Builds the models and caches generator/discriminator settings.

        After the base class has built the models, this records the generator
        resolution and the `kwargs_train` / `kwargs_val` entries (empty dict
        if absent) from the generator and discriminator module configs.

        Raises:
            AssertionError: If `encoder`, `generator` or `discriminator` is
                missing from `self.models`.
        """
        super().build_models()
        for name in ('encoder', 'generator', 'discriminator'):
            assert name in self.models, f'Model `{name}` is required!'

        self.resolution = self.models['generator'].resolution

        generator_config = self.config.modules['generator']
        self.G_kwargs_train = generator_config.get('kwargs_train', {})
        self.G_kwargs_val = generator_config.get('kwargs_val', {})

        discriminator_config = self.config.modules['discriminator']
        self.D_kwargs_train = discriminator_config.get('kwargs_train', {})
        self.D_kwargs_val = discriminator_config.get('kwargs_val', {})

    def train_step(self, data, **train_kwargs):
        """Runs one training step. Must be overridden by derived classes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('Should be implemented in derived class.')

    def val(self, **val_kwargs):
        """Runs validation by forwarding `val_kwargs` to `synthesize()`."""
        self.synthesize(**val_kwargs)

    def synthesize(self,
                   num,
                   html_name=None,
                   save_raw_synthesis=False):
        """Reconstructs validation images and optionally visualizes them.

        Images are drawn from the validation loader, encoded with the encoder
        and decoded with the generator (`generator_smooth` is preferred when
        it exists). For each sample index, the reconstruction and the input
        are saved as `{index:06d}_rec.jpg` and `{index:06d}_ori.jpg` under
        `<work_dir>/synthesize_results`. After all ranks have finished, rank 0
        optionally assembles an HTML page from those files.

        The function returns immediately, without touching the models, the
        dataset or the disk, if `num` is falsy or if neither `html_name` nor
        `save_raw_synthesis` is set.

        Args:
            num: Number of images to synthesize. It is rounded up to a
                multiple of `self.val_batch_size`.
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved. (default: None)
            save_raw_synthesis: Whether to keep the raw synthesis on the disk.
                If `False`, the `synthesize_results` directory is removed at
                the end. (default: False)

        Raises:
            NotImplementedError: If `self.config.space_of_latent` is not one
                of `z`, `wp` or `y`.
        """
        # Bail out before any side effect (mode switch, dataset building,
        # directory creation) when there is nothing to produce.
        if not num or not (html_name or save_raw_synthesis):
            return

        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        # Round `num` up to a whole number of validation batches.
        remainder = num % self.val_batch_size
        if remainder != 0:
            num += self.val_batch_size - remainder
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        # The models do not change across batches, so look them up once.
        E = self.models['encoder']
        if 'generator_smooth' in self.models:
            G = self.get_module(self.models['generator_smooth'])
        else:
            G = self.get_module(self.models['generator'])

        # Sample indices are interleaved across ranks: rank r handles
        # r, r + world_size, r + 2 * world_size, ...
        indices = list(range(self.rank, num, self.world_size))
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                # Trim to the number of indices in this (possibly short)
                # batch and move the data onto the current GPU.
                data[key] = data[key][:batch_size].cuda(
                    torch.cuda.current_device(), non_blocking=True)

            with torch.no_grad():
                real_images = data['image']
                latents = E(real_images)
                space_of_latent = self.config.space_of_latent
                if space_of_latent == 'z':
                    rec_images = G(
                        latents, **self.G_kwargs_val)['image']
                elif space_of_latent == 'wp':
                    rec_images = G.synthesis(
                        latents, **self.G_kwargs_val)['image']
                elif space_of_latent == 'y':
                    G.set_space_of_latent('y')
                    rec_images = G.synthesis(
                        latents, **self.G_kwargs_val)['image']
                else:
                    raise NotImplementedError(
                        f'Space of latent `{self.config.space_of_latent}` '
                        f'is not supported!')
                rec_images = postprocess_image(
                    rec_images.detach().cpu().numpy())
                real_images = postprocess_image(
                    real_images.detach().cpu().numpy())

            for sub_idx, rec_image, real_image in zip(
                    sub_indices, rec_images, real_images):
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}_rec.jpg'),
                           rec_image)
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}_ori.jpg'),
                           real_image)
            self.logger.update_pbar(task1, batch_size * self.world_size)

        # Wait until every rank has written its images; only rank 0 goes on
        # to build the page and clean up.
        # NOTE(review): rank 0 loads files written by all ranks, which
        # presumably requires `work_dir` to be on storage shared by all
        # ranks -- confirm for multi-node runs.
        dist.barrier()
        if self.rank != 0:
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            # Every sample occupies two vertically stacked cells (input on
            # top, reconstruction below), so the row count must be even.
            row, col = get_grid_shape(num * 2)
            if row % 2 != 0:
                row, col = col, row
            html = HtmlPageVisualizer(num_rows=row, num_cols=col)
            for image_idx in range(num):
                rec_image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}_rec.jpg'))
                real_image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}_ori.jpg'))
                row_idx, col_idx = divmod(image_idx, html.num_cols)
                html.set_cell(2 * row_idx, col_idx, image=real_image,
                              text=f'Sample {image_idx:06d}_ori')
                html.set_cell(2 * row_idx + 1, col_idx, image=rec_image,
                              text=f'Sample {image_idx:06d}_rec')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))
        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        # NOTE(review): only rank 0 reaches this call; other ranks return
        # right after the barrier -- confirm their logger needs no closing.
        self.logger.close_pbar()