# Hosting-page residue (not part of the module): dattarij's picture /
# adding ContraCLIP folder / 8c212a5 / raw / history blame / 6.11 kB
# python3.7
"""Contains the base class for Encoder (GAN Inversion) runner."""
import os
import shutil
import torch
import torch.distributed as dist
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import get_grid_shape
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from .base_runner import BaseRunner
# Names exported by `from <this module> import *`.
__all__ = ['BaseEncoderRunner']
class BaseEncoderRunner(BaseRunner):
    """Defines the base class for Encoder runner.

    The runner requires three models to be present after building:
    `encoder`, `generator` and `discriminator`. Derived classes must
    implement `train_step()`. Validation (`val()`) reconstructs images drawn
    from the validation loader through the encoder and the generator, see
    `synthesize()`.
    """

    def __init__(self, config, logger):
        super().__init__(config, logger)
        # Not assigned anywhere else in this class.
        # NOTE(review): presumably filled in by metric evaluation code
        # elsewhere -- confirm against the base runner.
        self.inception_model = None

    def build_models(self):
        """Builds the models and caches per-module forward kwargs.

        After the base class has built the models, this method checks that
        the encoder, generator and discriminator exist, records the
        generator resolution, and reads the optional `kwargs_train` /
        `kwargs_val` dictionaries of the generator and the discriminator
        from the config (an empty dictionary is used when absent).

        Raises:
            AssertionError: If `encoder`, `generator` or `discriminator` is
                missing from `self.models`.
        """
        super().build_models()
        assert 'encoder' in self.models
        assert 'generator' in self.models
        assert 'discriminator' in self.models

        self.resolution = self.models['generator'].resolution

        generator_config = self.config.modules['generator']
        discriminator_config = self.config.modules['discriminator']
        self.G_kwargs_train = generator_config.get('kwargs_train', {})
        self.G_kwargs_val = generator_config.get('kwargs_val', {})
        self.D_kwargs_train = discriminator_config.get('kwargs_train', {})
        self.D_kwargs_val = discriminator_config.get('kwargs_val', {})

    def train_step(self, data, **train_kwargs):
        """Executes one training step. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('Should be implemented in derived class.')

    def val(self, **val_kwargs):
        """Runs validation by forwarding all arguments to `synthesize()`."""
        self.synthesize(**val_kwargs)

    def _reconstruct_images(self, real_images):
        """Encodes images and decodes the latent codes with the generator.

        The smoothed generator (`generator_smooth`) is preferred when it
        exists in `self.models`; otherwise the plain generator is used. How
        the latent codes are fed to the generator depends on
        `self.config.space_of_latent`:

        - `z`: the generator is called directly.
        - `wp`: only `generator.synthesis` is called.
        - `y`: the generator is switched with `set_space_of_latent('y')`
          and then `generator.synthesis` is called.

        Args:
            real_images: Batch of images passed to the encoder.

        Returns:
            The `image` entry of the generator output.

        Raises:
            NotImplementedError: If the configured latent space is not one
                of `z`, `wp` or `y`.
        """
        encoder = self.models['encoder']
        if 'generator_smooth' in self.models:
            generator = self.get_module(self.models['generator_smooth'])
        else:
            generator = self.get_module(self.models['generator'])

        latents = encoder(real_images)
        space = self.config.space_of_latent
        if space == 'z':
            return generator(latents, **self.G_kwargs_val)['image']
        if space == 'wp':
            return generator.synthesis(latents, **self.G_kwargs_val)['image']
        if space == 'y':
            generator.set_space_of_latent('y')
            return generator.synthesis(latents, **self.G_kwargs_val)['image']
        raise NotImplementedError(
            f'Space of latent `{space}` '
            f'is not supported!')

    def synthesize(self,
                   num,
                   html_name=None,
                   save_raw_synthesis=False):
        """Reconstructs validation images and saves the results.

        Real images are fetched from the validation loader, reconstructed
        through the encoder and the generator, and written as
        `{index:06d}_ori.jpg` / `{index:06d}_rec.jpg` pairs to
        `<work_dir>/synthesize_results`. The work is split across ranks by
        striding over the sample indices. After all ranks have finished,
        rank 0 optionally assembles an HTML page and, unless raw results
        are requested, removes the image directory again.

        Args:
            num: Number of images to reconstruct. It is rounded up to a
                multiple of `self.val_batch_size`. If falsy (e.g. `0` or
                `None`), nothing is synthesized and no directory is
                created.
            html_name: Name of the output html page for visualization,
                relative to `self.work_dir`. If not specified, no
                visualization page will be saved. (default: None)
            save_raw_synthesis: Whether to keep the raw images on the disk.
                (default: False)

        Raises:
            NotImplementedError: If `self.config.space_of_latent` is not
                supported (see `_reconstruct_images()`).
        """
        # Nothing would be kept on disk, so there is nothing to do.
        if not html_name and not save_raw_synthesis:
            return

        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')

        # Check before creating the output directory so that an empty
        # request does not leave an empty directory behind.
        if not num:
            return

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        # Round `num` up to the next multiple of the validation batch size.
        remainder = num % self.val_batch_size
        if remainder != 0:
            num += self.val_batch_size - remainder
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        # Each rank handles the sample indices `rank, rank + world_size, ...`.
        indices = list(range(self.rank, num, self.world_size))
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)

            data = next(self.val_loader)
            # Keep only as many samples as there are indices left and move
            # every entry of the batch to the current CUDA device.
            for key in data:
                data[key] = data[key][:batch_size].cuda(
                    torch.cuda.current_device(), non_blocking=True)

            with torch.no_grad():
                real_images = data['image']
                rec_images = self._reconstruct_images(real_images)
                rec_images = postprocess_image(
                    rec_images.detach().cpu().numpy())
                real_images = postprocess_image(
                    real_images.detach().cpu().numpy())

            for sub_idx, rec_image, real_image in zip(
                    sub_indices, rec_images, real_images):
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}_rec.jpg'),
                           rec_image)
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}_ori.jpg'),
                           real_image)

            # The task total is the global `num`, so advance by the number
            # of samples presumably processed by all ranks in this step.
            self.logger.update_pbar(task1, batch_size * self.world_size)

        # Wait until every rank has written its images.
        dist.barrier()
        # Only rank 0 builds the visualization and cleans up.
        if self.rank != 0:
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            row, col = get_grid_shape(num * 2)
            # Every sample occupies two vertically adjacent cells (original
            # on top, reconstruction below), so the row count must be even.
            if row % 2 != 0:
                row, col = col, row
            html = HtmlPageVisualizer(num_rows=row, num_cols=col)
            for image_idx in range(num):
                rec_image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}_rec.jpg'))
                real_image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}_ori.jpg'))
                row_idx, col_idx = divmod(image_idx, html.num_cols)
                html.set_cell(2 * row_idx, col_idx, image=real_image,
                              text=f'Sample {image_idx:06d}_ori')
                html.set_cell(2 * row_idx + 1, col_idx, image=rec_image,
                              text=f'Sample {image_idx:06d}_rec')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))

        # The images were only needed to build the html page.
        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        self.logger.close_pbar()