Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,598 Bytes
e6ac593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from ripe import utils
# Module-level logger for this file, obtained from the project's `utils.get_pylogger` helper.
log = utils.get_pylogger(__name__)
class LinearWithPlateaus:
    """Piecewise-linear schedule with an optional flat section at each end.

    The schedule is evaluated by calling the instance with a step index:

    * for the first ``steps_total * rel_length_start_plateau`` steps it returns
      ``start_val``;
    * for the last ``steps_total * rel_length_end_plateu`` steps (and for any
      step beyond them) it returns ``end_val``;
    * in between it interpolates linearly from ``start_val`` to ``end_val``.

    Args:
        start_val: Value returned during the start plateau.
        end_val: Value returned during the end plateau and afterwards.
        steps_total: Total number of steps the schedule spans.
        rel_length_start_plateau: Length of the start plateau as a fraction of
            ``steps_total``.
        rel_length_end_plateu: Length of the end plateau as a fraction of
            ``steps_total``. (The misspelled name is kept so existing keyword
            callers continue to work.)

    Raises:
        AssertionError: If a plateau length is negative or the two plateaus
            together are longer than ``steps_total``.
    """

    def __init__(
        self,
        start_val,
        end_val,
        steps_total,
        rel_length_start_plateau=0.0,
        rel_length_end_plateu=0.0,
    ):
        self.start_val = start_val
        self.end_val = end_val
        self.steps_total = steps_total
        # Plateau lengths are kept as (possibly fractional) step counts.
        self.plateau_start_steps = steps_total * rel_length_start_plateau
        self.plateau_end_steps = steps_total * rel_length_end_plateu

        assert self.plateau_start_steps >= 0, "start plateau length must be non-negative"
        assert self.plateau_end_steps >= 0, "end plateau length must be non-negative"
        assert (
            self.plateau_start_steps + self.plateau_end_steps <= self.steps_total
        ), "plateaus must not be longer than steps_total"

        ramp_steps = steps_total - self.plateau_start_steps - self.plateau_end_steps
        if ramp_steps > 0:
            self.slope = (end_val - start_val) / ramp_steps
        else:
            # The plateaus cover every step (or steps_total is 0), so there is
            # no ramp. __call__ never reaches the interpolation branch in this
            # case; a zero slope avoids the ZeroDivisionError the unguarded
            # division would raise.
            self.slope = 0.0

        log.info(
            f"LinearWithPlateaus: start_val={start_val}, end_val={end_val}, steps_total={steps_total}, "
            f"plateau_start_steps={self.plateau_start_steps}, plateau_end_steps={self.plateau_end_steps}"
        )

    def __call__(self, step):
        """Return the scheduled value at ``step``."""
        if step < self.plateau_start_steps:
            return self.start_val
        if step < self.steps_total - self.plateau_end_steps:
            # On the ramp: offset from the end of the start plateau.
            return self.start_val + self.slope * (step - self.plateau_start_steps)
        return self.end_val
|