dattarij's picture
adding ContraCLIP folder
8c212a5
raw
history blame
2.68 kB
# python3.7
"""Contains the runner for StyleGAN."""
from copy import deepcopy
from .base_gan_runner import BaseGANRunner
__all__ = ['StyleGANRunner']
class StyleGANRunner(BaseGANRunner):
    """Defines the runner for StyleGAN.

    On top of `BaseGANRunner`, this runner:
      * keeps a smoothed copy of the generator (`generator_smooth`) that is
        updated by `moving_average_model` after every discriminator step;
      * copies the runner-level level-of-details value (`self.lod`) into the
        `lod` buffers of the generator, the smoothed generator and the
        discriminator at the start of every training step;
      * restores `self.lod` from those buffers when loading a checkpoint with
        running metadata.
    """

    def __init__(self, config, logger):
        super().__init__(config, logger)
        # Keep a `lod` the base class may already have set; otherwise leave it
        # unset. NOTE(review): `train_step` needs a numeric `self.lod`, so it
        # is presumably assigned elsewhere before training starts — confirm.
        self.lod = getattr(self, 'lod', None)

    def build_models(self):
        """Builds the models plus a moving-average copy of the generator.

        `g_smooth_img` is read from the generator's module config (default
        10000) and sets the smoothing speed used in `train_step`.
        """
        super().build_models()
        self.g_smooth_img = self.config.modules['generator'].get(
            'g_smooth_img', 10000)
        # The smoothed generator starts out as an exact copy of the generator.
        self.models['generator_smooth'] = deepcopy(self.models['generator'])

    def build_loss(self):
        """Builds the loss and registers the `Gs_beta` running statistic."""
        super().build_loss()
        self.running_stats.add(
            'Gs_beta', log_format='.4f', log_strategy='CURRENT')

    def train_step(self, data, **train_kwargs):
        """Runs one training iteration.

        Order of operations:
          1. Copy `self.lod` into the `lod` buffers of G, D and Gs.
          2. Update the discriminator once.
          3. Update `generator_smooth` as a moving average of `generator`.
          4. Update the generator, but only on iterations where
             `self._iter % config.get('D_repeats', 1) == 0`.

        Args:
            data: Batch forwarded unchanged to `self.loss.d_loss` and
                `self.loss.g_loss`.
            **train_kwargs: Accepted for interface compatibility; not used here.
        """
        # Set level-of-details so every network runs at the same resolution
        # stage.
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        G.synthesis.lod.data.fill_(self.lod)
        D.lod.data.fill_(self.lod)
        Gs.synthesis.lod.data.fill_(self.lod)

        # Update discriminator (generator frozen).
        self.set_model_requires_grad('discriminator', True)
        self.set_model_requires_grad('generator', False)
        d_loss = self.loss.d_loss(self, data)
        self.optimizers['discriminator'].zero_grad()
        d_loss.backward()
        self.optimizers['discriminator'].step()

        # Life-long update for generator.
        # beta ** (g_smooth_img / (batch_size * world_size)) == 0.5. So if
        # `moving_average_model` keeps a `beta` fraction of the old average and
        # batch_size * world_size is the number of images per step, the old
        # weights' contribution halves every `g_smooth_img` images.
        # NOTE(review): both conventions are assumptions — confirm in base.
        beta = 0.5 ** (self.batch_size * self.world_size / self.g_smooth_img)
        self.running_stats.update({'Gs_beta': beta})
        self.moving_average_model(model=self.models['generator'],
                                  avg_model=self.models['generator_smooth'],
                                  beta=beta)

        # Update generator (discriminator frozen) once every `D_repeats`
        # iterations.
        if self._iter % self.config.get('D_repeats', 1) == 0:
            self.set_model_requires_grad('discriminator', False)
            self.set_model_requires_grad('generator', True)
            g_loss = self.loss.g_loss(self, data)
            self.optimizers['generator'].zero_grad()
            g_loss.backward()
            self.optimizers['generator'].step()

    def load(self, **kwargs):
        """Loads a checkpoint and, if requested, restores `self.lod` from it.

        All keyword arguments are forwarded to the base-class `load`. When
        `running_metadata` is truthy, the value in the generator's `lod` buffer
        (as left by the base-class load) becomes `self.lod`, after checking it
        matches the discriminator and the smoothed generator.

        Raises:
            RuntimeError: If the `lod` buffers of G, D and Gs disagree.
        """
        super().load(**kwargs)
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        # `.get` so omitting the flag does not raise KeyError.
        # NOTE(review): the True default is assumed to match the base-class
        # `load` default for `running_metadata` — confirm.
        if kwargs.get('running_metadata', True):
            lod = G.synthesis.lod.cpu().tolist()
            d_lod = D.lod.cpu().tolist()
            gs_lod = Gs.synthesis.lod.cpu().tolist()
            # Explicit check instead of `assert` so it still runs under -O.
            if lod != d_lod or lod != gs_lod:
                raise RuntimeError(
                    'Inconsistent level-of-details after loading: '
                    f'generator={lod}, discriminator={d_lod}, '
                    f'generator_smooth={gs_lod}.')
            self.lod = lod