from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
|
|
@dataclass
class StepLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.min_lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be greater than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    lr_decay_period: int = field(
        default=25000, metadata={"help": "decay the learning rate every N updates"}
    )
    lr_decay: float = field(default=0.5, metadata={"help": "decay factor"})
|
|
@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
    """Decay the learning rate by a fixed factor every ``lr_decay_period`` updates.

    The rate is optionally warmed up linearly from ``warmup_init_lr`` to the
    maximum over the first ``warmup_updates`` updates; after warmup it is
    ``max_lr * lr_decay ** (updates_since_warmup // lr_decay_period)``, floored
    at ``min_lr``.
    """
|
    def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        self.min_lr = cfg.min_lr
        self.lr_decay_period = cfg.lr_decay_period
        self.lr_decay = cfg.lr_decay
        self.warmup_updates = cfg.warmup_updates
        # if no explicit warmup starting point is given, start warmup from min_lr
        self.warmup_init_lr = (
            cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
        )

        assert self.lr_decay_period > 0
        assert self.lr_decay <= 1
        assert self.min_lr >= 0
        assert self.max_lr > self.min_lr
|
        if cfg.warmup_updates > 0:
            # per-update increment that reaches max_lr after warmup_updates steps
            self.warmup_lr_step = (
                (self.max_lr - self.warmup_init_lr) / self.warmup_updates
            )
        else:
            self.warmup_lr_step = 1

        # start at the warmup (or minimum) learning rate
        self.lr = self.warmup_init_lr
        self.optimizer.set_lr(self.lr)
|
    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        return self.optimizer.get_lr()
|
    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            # linear warmup phase
            self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step
        else:
            # step decay phase: scale max_lr by lr_decay for every full
            # lr_decay_period elapsed since warmup, never going below min_lr
            curr_updates = num_updates - self.cfg.warmup_updates
            lr_mult = self.lr_decay ** (curr_updates // self.lr_decay_period)
            self.lr = max(self.max_lr * lr_mult, self.min_lr)

        self.optimizer.set_lr(self.lr)
        return self.lr
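

# Example invocation (a sketch, not part of the module; flag names are derived
# from the dataclass fields above and assume the rest of the training command is
# already configured):
#
#   fairseq-train ... --lr 0.001 --lr-scheduler step \
#       --warmup-updates 4000 --lr-decay 0.5 --lr-decay-period 25000
#
# With these hypothetical values the learning rate warms up linearly to 1e-3 over
# the first 4000 updates and is then halved every 25000 updates thereafter.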
|
|