File size: 795 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# python3.7
"""Contains the running controller to clean cache."""

import torch

from .base_controller import BaseController

__all__ = ['CacheCleaner']


class CacheCleaner(BaseController):
    """Defines the running controller to clean cache.

    This controller calls `torch.cuda.empty_cache()` in three places:
    `setup()`, `execute_after_iteration()` and `close()`.

    NOTE: The controller is set to `LAST` priority and `every_n_iters=1` by
    default; both can be overridden through `config`. The `every_n_iters`
    setting is presumably interpreted by `BaseController` to decide how often
    `execute_after_iteration()` is invoked (TODO confirm against the base
    class).
    """

    def __init__(self, config=None):
        """Initializes the controller.

        Args:
            config: Optional dictionary of controller settings. It is copied
                before the defaults above are filled in, so the dictionary
                passed by the caller is not modified here. (default: None)
        """
        # Copy first: calling `setdefault()` directly on the caller's dict
        # would leak these defaults back into the caller's configuration.
        config = dict(config or {})
        config.setdefault('priority', 'LAST')
        config.setdefault('every_n_iters', 1)
        super().__init__(config)

    def setup(self, runner):
        """Empties the GPU cache during setup.

        Args:
            runner: The runner being set up (unused).
        """
        torch.cuda.empty_cache()

    def close(self, runner):
        """Empties the GPU cache when the controller is closed.

        Args:
            runner: The runner being closed (unused).
        """
        torch.cuda.empty_cache()

    def execute_after_iteration(self, runner):
        """Empties the GPU cache after an iteration.

        Args:
            runner: The runner that just finished an iteration (unused).
        """
        torch.cuda.empty_cache()