# python3.7
"""Contains the running controller to control progressive training.

This controller is applicable to the models that need to progressively change
the batch size, learning rate, etc.
"""

import numpy as np

from .base_controller import BaseController

__all__ = ['ProgressScheduler']

_BATCH_SIZE_SCHEDULE_DICT = {
    4: 16, 8: 8, 16: 4, 32: 2, 64: 1, 128: 1, 256: 1, 512: 1, 1024: 1,
}
_MAX_BATCH_SIZE = 64

_LEARNING_RATE_SCHEDULE_DICT = {
    4: 1, 8: 1, 16: 1, 32: 1, 64: 1, 128: 1.5, 256: 2, 512: 3, 1024: 3,
}
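
# Illustrative reading of the two default schedules above (the base batch
# size of 8 per replica is an assumption for the example, not a default):
#   resolution   4: batch size = min(64, 8 * 16) = 64, lr scale = 1
#   resolution  32: batch size = min(64, 8 * 2)  = 16, lr scale = 1
#   resolution 256: batch size = min(64, 8 * 1)  = 8,  lr scale = 2
# Low resolutions get large batches; high resolutions shrink the batch and
# scale up the learning rate instead.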


class ProgressScheduler(BaseController):
    """Defines the running controller to control progressive training.

    NOTE: The controller is set to `HIGH` priority by default.
    """

    def __init__(self, config):
        assert isinstance(config, dict)
        config.setdefault('priority', 'HIGH')
        config.setdefault('every_n_iters', 1)
        super().__init__(config)

        self.base_batch_size = 0
        self.base_lrs = dict()
        self.total_img = 0

        self.init_res = config.get('init_res', 4)
        self.final_res = self.init_res
        self.init_lod = 0
        self.batch_size_schedule = config.get('batch_size_schedule', dict())
        self.lr_schedule = config.get('lr_schedule', dict())
        self.minibatch_repeats = config.get('minibatch_repeats', 4)
        self.lod_training_img = config.get('lod_training_img', 600_000)
        self.lod_transition_img = config.get('lod_transition_img', 600_000)
        self.lod_duration = self.lod_training_img + self.lod_transition_img
        # Whether to reset the optimizer state at the beginning of each phase.
        self.reset_optimizer = config.get('reset_optimizer', True)
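
        # Sketch of a config entry for this controller (values are
        # hypothetical, and the exact schema of the surrounding runner
        # config is an assumption; only the keys read above matter):
        #     ProgressScheduler=dict(
        #         init_res=4,
        #         minibatch_repeats=4,
        #         lod_training_img=600_000,
        #         lod_transition_img=600_000,
        #         batch_size_schedule=dict(res4=64, res8=32, res16=16),
        #         reset_optimizer=True,
        #     )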

    def get_batch_size(self, resolution):
        """Gets batch size for a particular resolution."""
        if self.batch_size_schedule:
            return self.batch_size_schedule.get(
                f'res{resolution}', self.base_batch_size)
        batch_size_scale = _BATCH_SIZE_SCHEDULE_DICT[resolution]
        return min(_MAX_BATCH_SIZE, self.base_batch_size * batch_size_scale)

    def get_lr_scale(self, resolution):
        """Gets learning rate scale for a particular resolution."""
        if self.lr_schedule:
            return self.lr_schedule.get(f'res{resolution}', 1)
        return _LEARNING_RATE_SCHEDULE_DICT[resolution]
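
    # Example for the two getters above (numbers are hypothetical): with the
    # default schedules and `base_batch_size = 8`, `get_batch_size(16)`
    # returns min(64, 8 * 4) = 32 and `get_lr_scale(16)` returns 1; with a
    # user-provided `batch_size_schedule=dict(res16=24)`, `get_batch_size(16)`
    # returns 24 instead.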

    def setup(self, runner):
        # Set level of detail (lod).
        self.final_res = runner.resolution
        self.init_lod = np.log2(self.final_res // self.init_res)
        runner.lod = -1.0
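        # `-1.0` serves as a sentinel: it can never match a real lod, so the
        # first iteration is always treated as a phase boundary in
        # `execute_before_iteration()`.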

        # Save default batch size and learning rate.
        self.base_batch_size = runner.batch_size
        for lr_name, lr_scheduler in runner.lr_schedulers.items():
            self.base_lrs[lr_name] = lr_scheduler.base_lrs

        # Add running stats for logging.
        runner.running_stats.add(
            'kimg', log_format='7.1f', log_name='kimg', log_strategy='CURRENT')
        runner.running_stats.add(
            'lod', log_format='4.2f', log_name='lod', log_strategy='CURRENT')
        runner.running_stats.add(
            'minibatch', log_format='4d', log_name='minibatch',
            log_strategy='CURRENT')

        # Log progressive schedule.
        runner.logger.info('Progressive Schedule:')
        res = self.init_res
        lod = int(self.init_lod)
        while res <= self.final_res:
            batch_size = self.get_batch_size(res)
            lr_scale = self.get_lr_scale(res)
            runner.logger.info(f'  Resolution {res:4d} (lod {lod}): '
                               f'batch size '
                               f'{batch_size:3d} * {runner.world_size:2d}, '
                               f'learning rate scale {lr_scale:.1f}')
            res *= 2
            lod -= 1
        assert lod == -1 and res == self.final_res * 2

        # Compute total running iterations.
        assert hasattr(runner.config, 'total_img')
        self.total_img = runner.config.total_img
        current_img = 0
        num_iters = 0
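        # The loop below replays the whole schedule to count iterations:
        # `phase` increments every `lod_duration` images but is shifted by
        # `lod_transition_img`, so the resolution (and thus the batch size)
        # bumps up as soon as a transition starts, matching the boundary used
        # in `execute_before_iteration()`. The resolution is re-evaluated only
        # every `minibatch_repeats` iterations, mirroring the cadence at
        # training time.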
        while current_img < self.total_img:
            phase = (current_img + self.lod_transition_img) // self.lod_duration
            phase = np.clip(phase, 0, self.init_lod)
            if num_iters % self.minibatch_repeats == 0:
                resolution = self.init_res * (2 ** int(phase))
            current_img += self.get_batch_size(resolution) * runner.world_size
            num_iters += 1
        runner.total_iters = num_iters

    def execute_before_iteration(self, runner):
        is_first_iter = (runner.iter - runner.start_iter == 1)
        # Adjust the hyper-parameters only on the first of every
        # `minibatch_repeats` consecutive iterations.
        if (not is_first_iter) and (runner.iter % self.minibatch_repeats != 1):
            return

        # Compute level of detail (lod).
        phase, subphase = divmod(runner.seen_img, self.lod_duration)
        lod = self.init_lod - phase
        if self.lod_transition_img:
            transition_img = max(subphase - self.lod_training_img, 0)
            lod = lod - transition_img / self.lod_transition_img
        lod = max(lod, 0.0)
        resolution = self.init_res * (2 ** int(np.ceil(self.init_lod - lod)))
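        # Worked example (hypothetical setup: init_res=4, final_res=128, i.e.
        # init_lod=5, with the default 600k + 600k schedule):
        #   seen_img=1.5M -> phase=1, subphase=300k (still in training), so
        #       lod=4.0 and resolution = 4 * 2 ** ceil(5 - 4) = 8;
        #   seen_img=2.0M -> subphase=800k, transition_img=200k, so
        #       lod = 4 - 200k / 600k ~= 3.67 and resolution = 4 * 2 ** 2 = 16,
        #       i.e. training already runs at the next resolution while the
        #       fractional lod fades it in.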
        batch_size = self.get_batch_size(resolution)
        lr_scale = self.get_lr_scale(resolution)
        pre_lod = runner.lod
        pre_resolution = runner.train_loader.dataset.resolution
        runner.lod = lod

        # Reset optimizer state if needed.
        if self.reset_optimizer:
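            # `int(lod)` drops the moment a transition begins (lod turns
            # fractional), while `np.ceil(lod)` drops when the next training
            # phase begins (lod reaches the next integer), so the state is
            # cleared at both kinds of boundary.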
            if int(lod) != int(pre_lod) or np.ceil(lod) != np.ceil(pre_lod):
                runner.logger.info(f'Reset the optimizer state at '
                                   f'iter {runner.iter:06d} (lod {lod:.6f}).')
                for name in runner.optimizers:
                    runner.optimizers[name].state.clear()

        # Rebuild the dataset and adjust the learning rate if needed.
        if is_first_iter or resolution != pre_resolution:
            runner.logger.info(f'Rebuild the dataset at '
                               f'iter {runner.iter:06d} (lod {lod:.6f}).')
            runner.train_loader.overwrite_param(
                batch_size=batch_size, resolution=resolution)
            runner.batch_size = batch_size
            for lr_name, base_lrs in self.base_lrs.items():
                runner.lr_schedulers[lr_name].base_lrs = [
                    lr * lr_scale for lr in base_lrs]

    def execute_after_iteration(self, runner):
        minibatch = runner.batch_size * runner.world_size
        runner.running_stats.update({'kimg': runner.seen_img / 1000})
        runner.running_stats.update({'lod': runner.lod})
        runner.running_stats.update({'minibatch': minibatch})