disentangled-image-editing-final-project/ContraCLIP/models/genforce/runners/stylegan_runner.py
# python3.7
"""Contains the runner for StyleGAN."""

from copy import deepcopy

from .base_gan_runner import BaseGANRunner

__all__ = ['StyleGANRunner']
class StyleGANRunner(BaseGANRunner):
    """Defines the runner for StyleGAN.

    Extends the base GAN runner with progressive-growing level-of-details
    (LOD) handling and a moving-average ("smoothed") copy of the generator.
    """

    def __init__(self, config, logger):
        super().__init__(config, logger)
        # LOD for progressive growing. Restored by `load()` when loading a
        # checkpoint with running metadata; otherwise stays `None` until set.
        self.lod = getattr(self, 'lod', None)

    def build_models(self):
        """Builds models, plus a moving-average copy of the generator."""
        super().build_models()
        # Number of images over which the moving average decays by half
        # (used to derive `beta` in `train_step`).
        self.g_smooth_img = self.config.modules['generator'].get(
            'g_smooth_img', 10000)
        self.models['generator_smooth'] = deepcopy(self.models['generator'])

    def build_loss(self):
        """Builds the loss and registers the smoothing-beta statistic."""
        super().build_loss()
        # NOTE: was `f'Gs_beta'` — an f-string with no placeholders.
        self.running_stats.add(
            'Gs_beta', log_format='.4f', log_strategy='CURRENT')

    def train_step(self, data, **train_kwargs):
        """Runs one training step: D update, Gs moving average, G update.

        Args:
            data: A batch of training data, forwarded to the loss functions.
            **train_kwargs: Additional keyword arguments (unused here).
        """
        # Set level-of-details on all three networks so they stay in sync.
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        G.synthesis.lod.data.fill_(self.lod)
        D.lod.data.fill_(self.lod)
        Gs.synthesis.lod.data.fill_(self.lod)

        # Update discriminator (generator frozen).
        self.set_model_requires_grad('discriminator', True)
        self.set_model_requires_grad('generator', False)
        d_loss = self.loss.d_loss(self, data)
        self.optimizers['discriminator'].zero_grad()
        d_loss.backward()
        self.optimizers['discriminator'].step()

        # Life-long (exponential moving-average) update for the generator:
        # beta halves the smoothed weights' contribution every
        # `g_smooth_img` images seen across all replicas.
        beta = 0.5 ** (self.batch_size * self.world_size / self.g_smooth_img)
        self.running_stats.update({'Gs_beta': beta})
        self.moving_average_model(model=self.models['generator'],
                                  avg_model=self.models['generator_smooth'],
                                  beta=beta)

        # Update generator every `D_repeats` iterations (discriminator
        # frozen).
        if self._iter % self.config.get('D_repeats', 1) == 0:
            self.set_model_requires_grad('discriminator', False)
            self.set_model_requires_grad('generator', True)
            g_loss = self.loss.g_loss(self, data)
            self.optimizers['generator'].zero_grad()
            g_loss.backward()
            self.optimizers['generator'].step()

    def load(self, **kwargs):
        """Loads a checkpoint and recovers the LOD from running metadata.

        Raises:
            KeyError: If `running_metadata` is absent from `kwargs`.
            AssertionError: If the three networks disagree on the LOD.
        """
        super().load(**kwargs)
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        if kwargs['running_metadata']:
            # All three networks must carry the same LOD after loading.
            lod = G.synthesis.lod.cpu().tolist()
            assert lod == D.lod.cpu().tolist()
            assert lod == Gs.synthesis.lod.cpu().tolist()
            self.lod = lod