from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR


class WarmUpMultiStepLR(MultiStepLR):
    """MultiStepLR with a linear warmup phase: the learning rate ramps from
    `factor * base_lr` up to `base_lr` over the first `num_iters` steps,
    then follows the usual milestone/gamma decay schedule."""

    def __init__(self, optimizer: Optimizer, milestones: List[int], gamma: float = 0.1,
                 factor: float = 0.3333, num_iters: int = 500, last_epoch: int = -1):
        self.factor = factor
        self.num_iters = num_iters
        super().__init__(optimizer, milestones, gamma, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.num_iters:
            # Linear warmup: the scale grows from `factor` to 1 as
            # `last_epoch` goes from 0 to `num_iters`.
            alpha = self.last_epoch / self.num_iters
            factor = (1 - self.factor) * alpha + self.factor
        else:
            factor = 1
        # Scale the closed-form MultiStepLR values. Chaining on `super().get_lr()`
        # here would compound the warmup factor across consecutive steps, since
        # it derives each value from the optimizer's current (already scaled) lr.
        return [lr * factor for lr in super()._get_closed_form_lr()]
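
# Usage sketch (not part of the original file; the model, optimizer, and
# milestone values below are illustrative assumptions). The scheduler is
# stepped once per training iteration, since the warmup is measured in
# iterations rather than epochs.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = WarmUpMultiStepLR(optimizer, milestones=[60000, 80000],
                                  gamma=0.1, factor=0.3333, num_iters=500)
    for step in range(3):
        optimizer.step()
        scheduler.step()
        # The lr starts near factor * base_lr and climbs toward 0.01 during warmup.
        print(step, scheduler.get_last_lr())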