File size: 1,598 Bytes
e6ac593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from ripe import utils

# Module-level logger, obtained from the project's `ripe.utils.get_pylogger` helper.
log = utils.get_pylogger(__name__)


class LinearWithPlateaus:
    """Piecewise-linear scheduler with constant plateaus at both ends.

    Over ``steps_total`` steps the schedule has three phases:

    1. ``step < plateau_start_steps``: returns ``start_val``.
    2. ``plateau_start_steps <= step < steps_total - plateau_end_steps``:
       moves linearly from ``start_val`` towards ``end_val`` (increasing or
       decreasing, whichever direction ``end_val`` lies in).
    3. Otherwise (including any step beyond ``steps_total``): returns ``end_val``.

    Plateau lengths are given as fractions of ``steps_total``; the resulting
    step counts may be fractional.

    Args:
        start_val: Value held during the start plateau and at the start of the ramp.
        end_val: Value reached at the end of the ramp and held afterwards.
        steps_total: Total number of steps covered by the schedule.
        rel_length_start_plateau: Fraction of ``steps_total`` spent at ``start_val``.
        rel_length_end_plateu: Fraction of ``steps_total`` spent at ``end_val``
            (the misspelled name is kept for backward compatibility).

    Raises:
        ValueError: If a plateau length is negative, or the two plateaus together
            are longer than ``steps_total``.
    """

    def __init__(
        self,
        start_val,
        end_val,
        steps_total,
        rel_length_start_plateau=0.0,
        rel_length_end_plateu=0.0,
    ):
        self.start_val = start_val
        self.end_val = end_val
        self.steps_total = steps_total
        self.plateau_start_steps = steps_total * rel_length_start_plateau
        self.plateau_end_steps = steps_total * rel_length_end_plateu

        # Explicit checks instead of `assert`, so validation is not stripped under `python -O`.
        if self.plateau_start_steps < 0:
            raise ValueError(f"plateau_start_steps must be >= 0, got {self.plateau_start_steps}")
        if self.plateau_end_steps < 0:
            raise ValueError(f"plateau_end_steps must be >= 0, got {self.plateau_end_steps}")
        if self.plateau_start_steps + self.plateau_end_steps > self.steps_total:
            raise ValueError(
                f"plateaus ({self.plateau_start_steps} + {self.plateau_end_steps} steps) "
                f"exceed steps_total={self.steps_total}"
            )

        ramp_steps = steps_total - self.plateau_start_steps - self.plateau_end_steps
        self.slope = self._ramp_slope(start_val, end_val, ramp_steps)

        log.info(
            f"LinearWithPlateaus: start_val={start_val}, end_val={end_val}, steps_total={steps_total}, "
            f"plateau_start_steps={self.plateau_start_steps}, plateau_end_steps={self.plateau_end_steps}"
        )

    @staticmethod
    def _ramp_slope(start_val, end_val, ramp_steps):
        """Return the per-step change over a ramp of ``ramp_steps`` steps (0.0 for an empty ramp)."""
        if ramp_steps <= 0:
            # The plateaus fill the whole schedule (or steps_total == 0). The linear branch of
            # __call__ is then unreachable, so the slope is unused; avoid dividing by zero.
            return 0.0
        return (end_val - start_val) / ramp_steps

    def __call__(self, step):
        """Return the scheduled value at ``step``."""
        if step < self.plateau_start_steps:
            return self.start_val
        if step < self.steps_total - self.plateau_end_steps:
            # Inside the ramp: offset is measured from the end of the start plateau.
            return self.start_val + self.slope * (step - self.plateau_start_steps)
        return self.end_val