dattarij's picture
adding ContraCLIP folder
8c212a5
raw
history blame
11 kB
# python3.7
"""Contains the base class for GAN runner."""
import os
import shutil
import numpy as np
import torch
import torch.distributed as dist
from metrics.inception import build_inception_model
from metrics.fid import extract_feature
from metrics.fid import compute_fid
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from .base_runner import BaseRunner
# Public names exported by `from <this module> import *`.
__all__ = ['BaseGANRunner']
class BaseGANRunner(BaseRunner):
    """Defines the base class for GAN runner.

    `build_models()` requires that a `generator` and a `discriminator` are
    registered in `self.models`; derived classes must implement
    `train_step()`. For evaluation (`synthesize()` / `fid()`), a model named
    `generator_smooth` is preferred over `generator` when present.
    """

    def __init__(self, config, logger):
        super().__init__(config, logger)
        # Built lazily on the first call to `fid()`.
        self.inception_model = None

    def moving_average_model(self, model, avg_model, beta=0.999):
        """Moving average model weights.

        This trick is commonly used in GAN training, where the weight of the
        generator is life-long averaged:

            avg_param = avg_param * beta + param * (1 - beta)

        Only parameters returned by `named_parameters()` are averaged.

        Args:
            model: The latest model used to update the averaged weights.
            avg_model: The averaged model weights.
            beta: Hyper-parameter used for moving average.
        """
        model_params = dict(self.get_module(model).named_parameters())
        avg_params = dict(self.get_module(avg_model).named_parameters())
        assert len(model_params) == len(avg_params)
        for param_name, avg_param in avg_params.items():
            assert param_name in model_params
            avg_param.data = (avg_param.data * beta +
                              model_params[param_name].data * (1 - beta))

    def build_models(self):
        """Builds models and caches generator/discriminator settings.

        Besides what the base class builds, this records the generator's
        latent dimension (`z_space_dim`) and `resolution`, and the optional
        per-mode keyword arguments (`kwargs_train` / `kwargs_val`) configured
        for the generator and discriminator modules (empty dicts if absent).
        """
        super().build_models()
        assert 'generator' in self.models
        assert 'discriminator' in self.models
        self.z_space_dim = self.models['generator'].z_space_dim
        self.resolution = self.models['generator'].resolution
        self.G_kwargs_train = self.config.modules['generator'].get(
            'kwargs_train', dict())
        self.G_kwargs_val = self.config.modules['generator'].get(
            'kwargs_val', dict())
        self.D_kwargs_train = self.config.modules['discriminator'].get(
            'kwargs_train', dict())
        self.D_kwargs_val = self.config.modules['discriminator'].get(
            'kwargs_val', dict())

    def train_step(self, data, **train_kwargs):
        raise NotImplementedError('Should be implemented in derived class.')

    def val(self, **val_kwargs):
        """Runs validation, which here is just `synthesize(**val_kwargs)`."""
        self.synthesize(**val_kwargs)

    def _get_eval_generator(self):
        """Returns `generator_smooth` if registered, else `generator`."""
        # NOTE(review): `generator_smooth` is presumably the copy maintained
        # with `moving_average_model()` -- confirm in derived runners.
        if 'generator_smooth' in self.models:
            return self.models['generator_smooth']
        return self.models['generator']

    def _prepare_latent_codes(self, z, num):
        """Validates optional latent codes and clips `num` to their count.

        Args:
            z: `None`, or a 2-D numpy array of shape `[N, z_space_dim]`.
            num: Requested number of samples.

        Returns:
            A tuple `(z, num)` where `z` is `None` or a float32 CPU tensor, and
            `num` is `min(num, N)` when codes were given.
        """
        if z is None:
            return None, num
        assert isinstance(z, np.ndarray)
        assert z.ndim == 2 and z.shape[1] == self.z_space_dim
        num = min(num, z.shape[0])
        return torch.from_numpy(z).type(torch.FloatTensor), num

    def _sample_batch_codes(self, z, sub_indices):
        """Returns CUDA latent codes for the given global sample indices.

        Fresh Gaussian codes are drawn when `z` is `None`; otherwise the rows
        of `z` at `sub_indices` are used.
        """
        if z is None:
            return torch.randn(len(sub_indices), self.z_space_dim).cuda()
        return z[sub_indices].cuda()

    @staticmethod
    def _stack_features(feature_list):
        """Concatenates per-batch features, tolerating an empty list.

        A rank owns no samples when the sample count is smaller than the world
        size; `np.concatenate([])` would raise, so an empty `(0, 0)` array is
        returned instead.
        """
        if not feature_list:
            return np.zeros((0, 0), dtype=np.float32)
        return np.concatenate(feature_list, axis=0)

    def _gather_rank_features(self, prefix, num):
        """Collects per-rank feature files into one array in global order.

        Rank `r` processes the global sample indices
        `range(r, num, world_size)` and saves their features, in that order,
        to `{work_dir}/{prefix}_{r}.npy`. Each rank's rows are scattered back
        to their global indices, so row `i` of the result belongs to sample
        `i`. (Padding the concatenated array only at its end, as done
        previously, misplaces rows whenever more than one rank is short of a
        full round.) The per-rank files are deleted after loading.

        Args:
            prefix: File-name prefix, e.g. `fake_fid_features`.
            num: Total number of samples across all ranks (must be > 0).

        Returns:
            A `(num, feature_dim)` numpy array.
        """
        gathered = None
        for rank in range(self.world_size):
            path = f'{self.work_dir}/{prefix}_{rank}.npy'
            rank_features = np.load(path)
            os.remove(path)
            expected = len(range(rank, num, self.world_size))
            assert rank_features.ndim == 2
            assert rank_features.shape[0] == expected
            if not expected:
                continue
            if gathered is None:
                # Rank 0 always owns index 0, so this runs on the first rank.
                gathered = np.zeros((num, rank_features.shape[1]),
                                    dtype=rank_features.dtype)
            gathered[rank::self.world_size] = rank_features
        return gathered

    def synthesize(self,
                   num,
                   z=None,
                   html_name=None,
                   save_raw_synthesis=False):
        """Synthesizes images.

        Each rank generates the images with global indices
        `range(rank, num, world_size)` and writes them to
        `{work_dir}/synthesize_results/{index:06d}.jpg`. After all ranks
        synchronize, rank 0 optionally builds an HTML page from the images
        and deletes the temporary directory unless raw results are kept.

        Args:
            num: Number of images to synthesize.
            z: Latent codes used for generation. If not specified, this function
                will sample latent codes randomly. (default: None)
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved. (default: None)
            save_raw_synthesis: Whether to save raw synthesis on the disk.
                (default: False)
        """
        if not html_name and not save_raw_synthesis:
            return

        self.set_mode('val')

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        z, num = self._prepare_latent_codes(z, num)
        if not num:
            return
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        G = self._get_eval_generator()
        indices = list(range(self.rank, num, self.world_size))
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            code = self._sample_batch_codes(z, sub_indices)
            with torch.no_grad():
                images = G(code, **self.G_kwargs_val)['image']
                images = postprocess_image(images.detach().cpu().numpy())
            for sub_idx, image in zip(sub_indices, images):
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}.jpg'), image)
            self.logger.update_pbar(task1, batch_size * self.world_size)

        # All images must be on disk before rank 0 reads them back.
        dist.barrier()
        if self.rank != 0:
            # Every rank opened a progress bar above, so close it here too.
            self.logger.close_pbar()
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            html = HtmlPageVisualizer(grid_size=num)
            for image_idx in range(num):
                image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}.jpg'))
                row_idx, col_idx = divmod(image_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {image_idx:06d}')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))
        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        self.logger.close_pbar()

    def fid(self,
            fid_num,
            z=None,
            ignore_cache=False,
            align_tf=True):
        """Computes the FID metric.

        Inception features are extracted for `fid_num` generated images and
        `fid_num` real validation images, with samples split across ranks.
        Rank 0 gathers the features and computes FID. Real features are cached
        to `{work_dir}/real_fid{fid_num}.npy` and reused unless `ignore_cache`.

        Args:
            fid_num: Number of samples to use. Clipped to the size of the
                validation set and, if given, to the number of codes in `z`.
            z: Optional latent codes, a numpy array of shape
                `[N, z_space_dim]`. Random codes are sampled when `None`.
            ignore_cache: Whether to re-extract real features even when the
                cache file exists.
            align_tf: Whether to build the Inception model aligned with
                TensorFlow (otherwise torchvision). Only used the first time
                the model is built.

        Returns:
            The FID value on rank 0; -1 on other ranks or when `fid_num` is 0.
        """
        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))

        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info('Finish building inception model.')

        z, fid_num = self._prepare_latent_codes(z, fid_num)
        if not fid_num:
            return -1

        indices = list(range(self.rank, fid_num, self.world_size))

        self.logger.init_pbar()

        # Extract features from fake images.
        G = self._get_eval_generator()
        fake_feature_list = []
        task1 = self.logger.add_pbar_task('Fake', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            code = self._sample_batch_codes(z, sub_indices)
            with torch.no_grad():
                # NOTE(review): unlike `synthesize()`, `G_kwargs_val` is not
                # passed here -- confirm this is intended.
                fake_images = G(code)['image']
                fake_feature_list.append(
                    extract_feature(self.inception_model, fake_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                self._stack_features(fake_feature_list))

        # Extract features from real images if needed.
        cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            real_feature_list = []
            task2 = self.logger.add_pbar_task("Real", total=fid_num)
            for batch_idx in range(0, len(indices), self.val_batch_size):
                sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(self.val_loader)
                for key in data:
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_images = data['image']
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    self._stack_features(real_feature_list))

        # All per-rank feature files must exist before rank 0 collects them.
        dist.barrier()
        # Every rank opened a progress bar above, so close it on all ranks.
        self.logger.close_pbar()
        if self.rank != 0:
            return -1

        # Collect fake features.
        fake_features = self._gather_rank_features('fake_fid_features',
                                                   fid_num)
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]

        # Collect (or load) real features.
        if do_real_test:
            real_features = self._gather_rank_features('real_fid_features',
                                                       fid_num)
            assert real_features.shape == (fid_num, feature_dim)
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)

        fid_value = compute_fid(fake_features, real_features)
        return fid_value