sapiens-pose / external /det /mmdet /testing /_fast_stop_training_hook.py
rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
@HOOKS.register_module()
class FastStopTrainingHook(Hook):
"""Set runner's epoch information to the model."""
def __init__(self, by_epoch, save_ckpt=False, stop_iter_or_epoch=5):
self.by_epoch = by_epoch
self.save_ckpt = save_ckpt
self.stop_iter_or_epoch = stop_iter_or_epoch
def after_train_iter(self, runner, batch_idx: int, data_batch: None,
outputs: None) -> None:
if self.save_ckpt and self.by_epoch:
# If it is epoch-based and want to save weights,
# we must run at least 1 epoch.
return
if runner.iter >= self.stop_iter_or_epoch:
raise RuntimeError('quick exit')
def after_train_epoch(self, runner) -> None:
if runner.epoch >= self.stop_iter_or_epoch - 1:
raise RuntimeError('quick exit')