# File: disentangled-image-editing-final-project/ContraCLIP/models/genforce/runners/controllers/checkpointer.py
# python3.7
"""Contains the running controller to handle checkpoints."""

import os.path

from .base_controller import BaseController

__all__ = ['Checkpointer']
class Checkpointer(BaseController):
    """Defines the running controller to save checkpoints.

    Each call to `execute_after_iteration()` writes one checkpoint file named
    `checkpoint_iter{iter:06d}.pth` through `runner.save()`. Loading
    checkpoints is not handled by this controller.

    NOTE: This controller is set to `LAST` priority by default and will only be
    executed on the master worker.

    Config keys read here (in addition to whatever `BaseController` reads):
        checkpoint_dir: Directory to write checkpoints into. When unset or
            empty, `runner.work_dir` is used instead. (default: None)
        save_running_metadata: Forwarded to `runner.save()`. (default: True)
        save_learning_rate: Forwarded to `runner.save()`. (default: True)
        save_optimizer: Forwarded to `runner.save()`. (default: True)
        save_running_stats: Forwarded to `runner.save()`. (default: False)
    """

    def __init__(self, config):
        """Initializes the controller from a config dict.

        Args:
            config: Controller configuration. NOTE: the `priority` ('LAST')
                and `master_only` (True) defaults are written into this dict
                in place before it is handed to `BaseController`.

        Raises:
            TypeError: If `config` is not a dict.
        """
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not isinstance(config, dict):
            raise TypeError(f'`config` should be a dict, but `{type(config)}` '
                            f'is received!')
        config.setdefault('priority', 'LAST')
        config.setdefault('master_only', True)
        super().__init__(config)

        self._save_dir = config.get('checkpoint_dir', None)
        self._save_running_metadata = config.get('save_running_metadata', True)
        self._save_learning_rate = config.get('save_learning_rate', True)
        self._save_optimizer = config.get('save_optimizer', True)
        self._save_running_stats = config.get('save_running_stats', False)

    def execute_after_iteration(self, runner):
        """Saves a checkpoint for the runner's current iteration.

        Args:
            runner: The runner being controlled. Its `work_dir` (fallback
                save directory) and `iter` (used in the file name) are read,
                and its `save()` method performs the actual write.
        """
        save_dir = self._save_dir or runner.work_dir
        # It is not visible here whether `runner.save()` creates missing
        # parent directories, so make sure the target directory exists
        # (relevant when a fresh `checkpoint_dir` is configured). The guard
        # avoids `os.makedirs('')`, which would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_filename = f'checkpoint_iter{runner.iter:06d}.pth'
        runner.save(filepath=os.path.join(save_dir, save_filename),
                    running_metadata=self._save_running_metadata,
                    learning_rate=self._save_learning_rate,
                    optimizer=self._save_optimizer,
                    running_stats=self._save_running_stats)