# python3.7
"""Contains the runner for StyleGAN."""

from copy import deepcopy

from .base_gan_runner import BaseGANRunner

__all__ = ['StyleGANRunner']


class StyleGANRunner(BaseGANRunner):
    """Defines the runner for StyleGAN."""

    def __init__(self, config, logger):
        super().__init__(config, logger)
        self.lod = getattr(self, 'lod', None)

    def build_models(self):
        super().build_models()
        self.g_smooth_img = self.config.modules['generator'].get(
            'g_smooth_img', 10000)
        # `generator_smooth` keeps a running (exponential moving) average of
        # the generator's weights, updated in `train_step()`.
        self.models['generator_smooth'] = deepcopy(self.models['generator'])

    def build_loss(self):
        super().build_loss()
        self.running_stats.add(
            'Gs_beta', log_format='.4f', log_strategy='CURRENT')

    def train_step(self, data, **train_kwargs):
        # Set level-of-details.
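        # `self.lod` must be a concrete number at this point: it is `None`
        # right after `__init__` and is either restored from a checkpoint in
        # `load()` or (presumably) set by the progressive-growing schedule.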
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        G.synthesis.lod.data.fill_(self.lod)
        D.lod.data.fill_(self.lod)
        Gs.synthesis.lod.data.fill_(self.lod)
        # Update discriminator.
        self.set_model_requires_grad('discriminator', True)
        self.set_model_requires_grad('generator', False)
        d_loss = self.loss.d_loss(self, data)
        self.optimizers['discriminator'].zero_grad()
        d_loss.backward()
        self.optimizers['discriminator'].step()
        # Life-long update for generator.
        beta = 0.5 ** (self.batch_size * self.world_size / self.g_smooth_img)
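        # `moving_average_model()` is assumed to follow the usual convention
        # avg = beta * avg + (1 - beta) * new, so the old average's weight is
        # halved every `g_smooth_img` images seen across all replicas (e.g.
        # with the default 10,000 images and a global batch of 32,
        # beta = 0.5 ** (32 / 10000), i.e. about 0.9978).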
        self.running_stats.update({'Gs_beta': beta})
        self.moving_average_model(model=self.models['generator'],
                                  avg_model=self.models['generator_smooth'],
                                  beta=beta)
        # Update generator (only on every `D_repeats`-th iteration).
        if self._iter % self.config.get('D_repeats', 1) == 0:
            self.set_model_requires_grad('discriminator', False)
            self.set_model_requires_grad('generator', True)
            g_loss = self.loss.g_loss(self, data)
            self.optimizers['generator'].zero_grad()
            g_loss.backward()
            self.optimizers['generator'].step()

    def load(self, **kwargs):
        super().load(**kwargs)
        G = self.get_module(self.models['generator'])
        D = self.get_module(self.models['discriminator'])
        Gs = self.get_module(self.models['generator_smooth'])
        if kwargs['running_metadata']:
            # Recover the level-of-detail stored in the loaded models so that
            # training resumes at the same `lod`.
            lod = G.synthesis.lod.cpu().tolist()
            assert lod == D.lod.cpu().tolist()
            assert lod == Gs.synthesis.lod.cpu().tolist()
            self.lod = lod
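
# A rough usage sketch, assuming the base runner takes care of building the
# models, losses and optimizers, and that `config`, `logger` and `data_loader`
# are provided by the surrounding training script (all placeholder names):
#
#     runner = StyleGANRunner(config, logger)
#     runner.lod = 0.0          # normally driven by a progressive-growing schedule
#     for data in data_loader:
#         runner.train_step(data)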