import math
from functools import partial

from torch.optim.lr_scheduler import LambdaLR
def get_scheduler(
    optimizer,
    start_lr,
    max_lr,
    min_lr,
    warmup_epochs,
    sustain_epochs,
    total_epochs,
    decay,
    mode="cosine",
):
    """Build an epoch-based ``LambdaLR`` scheduler with warmup, sustain, and decay phases.

    The learning rate ramps linearly from ``start_lr`` to ``max_lr`` over
    ``warmup_epochs``, holds ``max_lr`` for ``sustain_epochs``, then decays
    towards ``min_lr`` according to ``mode`` ("exponential", "step", or "cosine").
    """

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Linear warmup from start_lr to max_lr.
            return (max_lr - start_lr) / warmup_epochs * epoch + start_lr

        elif epoch < warmup_epochs + sustain_epochs:
            # Hold the peak learning rate.
            return max_lr

        elif mode == "exponential":
            # Exponential decay from max_lr towards min_lr.
            return (max_lr - min_lr) * decay ** (
                epoch - warmup_epochs - sustain_epochs
            ) + min_lr

        elif mode == "step":
            # Multiply by `decay` every two epochs after the sustain phase.
            return max_lr * decay ** ((epoch - warmup_epochs - sustain_epochs) // 2)

        elif mode == "cosine":
            # Cosine decay from max_lr towards min_lr; the "+ 3" stretches the
            # half-cosine so the schedule ends slightly above min_lr.
            decay_total_epochs = total_epochs - warmup_epochs - sustain_epochs + 3
            decay_epoch_index = epoch - warmup_epochs - sustain_epochs
            phase = math.pi * decay_epoch_index / decay_total_epochs
            cosine_decay = 0.5 * (1 + math.cos(phase))
            return (max_lr - min_lr) * cosine_decay + min_lr

        else:
            raise ValueError(
                f"Unsupported mode '{mode}'. Supported modes are 'exponential', 'step', 'cosine'."
            )

    return LambdaLR(optimizer, lr_lambda)
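# Usage sketch (illustrative, not part of the original module): the lambda in
# get_scheduler returns absolute learning-rate values, and LambdaLR multiplies
# them by the optimizer's initial lr, so the optimizer would typically be
# created with lr=1.0 for those values to be applied unchanged, e.g.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
#     scheduler = get_scheduler(
#         optimizer,
#         start_lr=1e-5,
#         max_lr=1e-3,
#         min_lr=1e-6,
#         warmup_epochs=5,
#         sustain_epochs=0,
#         total_epochs=50,
#         decay=0.9,
#         mode="cosine",
#     )
#     for epoch in range(50):
#         ...  # train for one epoch
#         scheduler.step()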
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
):
    if current_step < num_warmup_steps:
        # Linear warmup of the multiplier from 0 to 1.
        return float(current_step) / float(max(1, num_warmup_steps))
    # Fraction of the post-warmup phase completed so far.
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    # Cosine decay over the remaining steps; num_cycles=0.5 gives a single
    # half-cosine from 1 down to 0 (clamped at 0.0).
    return max(
        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    )
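# Worked example (illustrative values, not from the original source): unlike
# get_scheduler above, this lambda returns a multiplier of the optimizer's
# base lr rather than an absolute learning rate. With num_warmup_steps=10,
# num_training_steps=100, and num_cycles=0.5:
#   step 5   -> 5 / 10                     = 0.5  (halfway through warmup)
#   step 55  -> 0.5 * (1 + cos(0.5 * pi))  = 0.5  (halfway through the decay)
#   step 100 -> 0.5 * (1 + cos(pi))        = 0.0  (end of training)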
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
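# Minimal smoke test (an illustrative sketch, not part of the original module):
# wraps get_cosine_schedule_with_warmup around a dummy optimizer and prints the
# learning rate over a few steps.
if __name__ == "__main__":
    import torch

    # A single dummy parameter is enough to construct an optimizer.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-3)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=3, num_training_steps=10
    )

    for step in range(10):
        optimizer.step()
        scheduler.step()
        print(f"step {step}: lr = {scheduler.get_last_lr()[0]:.6f}")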