Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Learning rate policy."""
import math | |
def get_lr_at_epoch(cfg, cur_epoch):
    """
    Compute the learning rate for the given epoch, applying a linear warm up
    while the epoch is still below cfg.SOLVER.WARMUP_EPOCHS.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        the learning rate to use at `cur_epoch`.
    """
    solver = cfg.SOLVER
    policy_fn = get_lr_func(solver.LR_POLICY)
    scheduled_lr = policy_fn(cfg, cur_epoch)
    if cur_epoch < solver.WARMUP_EPOCHS:
        # Warm up: interpolate linearly from WARMUP_START_LR (at epoch 0) to
        # the value the policy yields at the end of the warm-up period.
        warmup_start = solver.WARMUP_START_LR
        warmup_target = policy_fn(cfg, solver.WARMUP_EPOCHS)
        slope = (warmup_target - warmup_start) / solver.WARMUP_EPOCHS
        return cur_epoch * slope + warmup_start
    return scheduled_lr
def lr_func_cosine(cfg, cur_epoch):
    """
    Cosine learning rate schedule: decays from cfg.SOLVER.BASE_LR at epoch 0
    to cfg.SOLVER.COSINE_END_LR at cfg.SOLVER.MAX_EPOCH. Details can be
    found in:
    Ilya Loshchilov, and Frank Hutter
    SGDR: Stochastic Gradient Descent With Warm Restarts.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        the learning rate prescribed by the cosine schedule at `cur_epoch`.
    """
    solver = cfg.SOLVER
    end_lr = solver.COSINE_END_LR
    base_lr = solver.BASE_LR
    # The schedule only makes sense when it decays towards a smaller value.
    assert end_lr < base_lr
    # (cos(pi * t / T) + 1) sweeps from 2 down to 0 as t goes from 0 to T.
    cosine_term = math.cos(math.pi * cur_epoch / solver.MAX_EPOCH) + 1.0
    return end_lr + (base_lr - end_lr) * cosine_term * 0.5
def lr_func_steps_with_relative_lrs(cfg, cur_epoch):
    """
    Step learning rate schedule with relative learning rates: the entry of
    cfg.SOLVER.LRS selected for the current epoch is used as a multiplier on
    cfg.SOLVER.BASE_LR.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        the learning rate for the step that `cur_epoch` falls into.
    """
    step_idx = get_step_index(cfg, cur_epoch)
    relative_lr = cfg.SOLVER.LRS[step_idx]
    return relative_lr * cfg.SOLVER.BASE_LR
def get_step_index(cfg, cur_epoch):
    """
    Retrieves the lr step index for the given epoch.
    The boundaries are cfg.SOLVER.STEPS followed by cfg.SOLVER.MAX_EPOCH; the
    result is the index of the interval that precedes the first boundary
    strictly greater than `cur_epoch`.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        int: a non-negative step index. Epochs at or beyond MAX_EPOCH map to
            the last interval; epochs before the first boundary (or an empty
            STEPS list) map to index 0.
    """
    # Unpacking accepts both list- and tuple-valued STEPS.
    steps = [*cfg.SOLVER.STEPS, cfg.SOLVER.MAX_EPOCH]
    # Position of the first boundary that lies after cur_epoch; when there is
    # none (cur_epoch >= MAX_EPOCH) fall back to the final boundary.
    boundary = next(
        (ind for ind, step in enumerate(steps) if cur_epoch < step),
        len(steps) - 1,
    )
    # Clamp at 0: a result of -1 would make callers index LRS[-1] and silently
    # pick the *last* relative lr instead of the first one.
    return max(boundary - 1, 0)
def get_lr_func(lr_policy):
    """
    Look up the learning rate policy function for the given policy name.
    The function is expected to live at module level under the name
    "lr_func_<lr_policy>".
    Args:
        lr_policy (string): the learning rate policy to use for the job.
    Returns:
        the module-level object registered under "lr_func_<lr_policy>".
    Raises:
        NotImplementedError: if this module defines no such name.
    """
    module_names = globals()
    func_name = "lr_func_" + lr_policy
    if func_name in module_names:
        return module_names[func_name]
    raise NotImplementedError("Unknown LR policy: {}".format(lr_policy))