dattarij's picture
adding ContraCLIP folder
8c212a5
raw
history blame
1.47 kB
# python3.7
"""Contains the running controller to handle checkpoints."""
import os.path
from .base_controller import BaseController
__all__ = ['Checkpointer']
class Checkpointer(BaseController):
    """Running controller that writes a checkpoint file via the runner.

    Each time `execute_after_iteration()` is invoked, the controller asks the
    runner to save its state to `checkpoint_iter<iter>.pth`.

    NOTE: Unless the config says otherwise, `priority` defaults to `LAST` and
    `master_only` defaults to `True`. Both values are handed to
    `BaseController`, which decides how they are applied.

    Recognized config keys (all optional):

    - `checkpoint_dir`: Directory to write checkpoints to. When missing or
        empty, `runner.work_dir` is used instead.
    - `save_running_metadata`: Forwarded to `runner.save()`. (default: True)
    - `save_learning_rate`: Forwarded to `runner.save()`. (default: True)
    - `save_optimizer`: Forwarded to `runner.save()`. (default: True)
    - `save_running_stats`: Forwarded to `runner.save()`. (default: False)
    """

    def __init__(self, config):
        """Initializes the controller from a config dictionary.

        Args:
            config: A `dict` of settings. It is updated in place with the
                default `priority` and `master_only` entries when absent.
        """
        assert isinstance(config, dict)
        # Defaults must be in place before the base class reads the config.
        for key, fallback in (('priority', 'LAST'), ('master_only', True)):
            config.setdefault(key, fallback)
        super().__init__(config)

        self._save_dir = config.get('checkpoint_dir')
        self._save_running_metadata = config.get('save_running_metadata', True)
        self._save_learning_rate = config.get('save_learning_rate', True)
        self._save_optimizer = config.get('save_optimizer', True)
        self._save_running_stats = config.get('save_running_stats', False)

    def execute_after_iteration(self, runner):
        """Saves a checkpoint named after the runner's current iteration.

        Args:
            runner: Object providing `work_dir`, `iter` and `save()`.
        """
        # An unset (or empty) checkpoint directory falls back to the work dir.
        target_dir = self._save_dir if self._save_dir else runner.work_dir
        checkpoint_name = f'checkpoint_iter{runner.iter:06d}.pth'
        checkpoint_path = os.path.join(target_dir, checkpoint_name)
        runner.save(
            filepath=checkpoint_path,
            running_metadata=self._save_running_metadata,
            learning_rate=self._save_learning_rate,
            optimizer=self._save_optimizer,
            running_stats=self._save_running_stats,
        )