File size: 795 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# python3.7
"""Contains the running controller to clean cache."""
import torch
from .base_controller import BaseController
__all__ = ['CacheCleaner']
class CacheCleaner(BaseController):
    """Defines the running controller to clean cache.

    This controller is used to empty the GPU cache after each iteration.

    NOTE: The controller is set to `LAST` priority by default and uses
    `every_n_iters=1` by default; either can be overridden through `config`.
    """

    def __init__(self, config=None):
        """Initializes the controller.

        Args:
            config: Optional mapping of controller settings forwarded to
                `BaseController`. When `priority` is absent it defaults to
                `'LAST'`; when `every_n_iters` is absent it defaults to `1`.
                The caller's mapping is copied and never modified.
        """
        # Copy before filling in defaults so the caller's dict is not mutated
        # (the previous `config or dict()` + `setdefault` wrote the default
        # keys back into any non-empty dict the caller passed in).
        config = dict(config or {})
        config.setdefault('priority', 'LAST')
        config.setdefault('every_n_iters', 1)
        super().__init__(config)

    def setup(self, runner):
        """Empties the CUDA caching allocator's cache when set up.

        Args:
            runner: The running runner (unused here).
        """
        torch.cuda.empty_cache()

    def close(self, runner):
        """Empties the CUDA caching allocator's cache when closed.

        Args:
            runner: The running runner (unused here).
        """
        torch.cuda.empty_cache()

    def execute_after_iteration(self, runner):
        """Empties the CUDA caching allocator's cache after an iteration.

        Args:
            runner: The running runner (unused here).
        """
        torch.cuda.empty_cache()
|