File size: 889 Bytes
d6def08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from typing import List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
class WarmUpMultiStepLR(MultiStepLR):
    """MultiStepLR preceded by a linear learning-rate warm-up.

    For the first ``num_iters`` scheduler steps, the MultiStepLR learning rate
    is scaled by a multiplier that grows linearly from ``factor`` (at step 0)
    towards 1. From step ``num_iters`` on, the plain MultiStepLR schedule
    applies: the base learning rate is multiplied by ``gamma`` once for every
    milestone already passed.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Step indices at which the learning rate is multiplied by
            ``gamma``.
        gamma: Multiplicative decay applied at each milestone.
        factor: Warm-up multiplier used at step 0.
        num_iters: Number of warm-up steps.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer: Optimizer, milestones: List[int], gamma: float = 0.1,
                 factor: float = 0.3333, num_iters: int = 500, last_epoch: int = -1):
        # These must be set before super().__init__: torch's scheduler
        # constructor performs an initial step(), which already calls get_lr().
        self.factor = factor
        self.num_iters = num_iters
        super().__init__(optimizer, milestones, gamma, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the learning rate of every param group for ``self.last_epoch``.

        Both phases are computed in closed form from the base learning rates.
        Once warm-up ends, the rate is therefore exactly the MultiStepLR value,
        instead of inheriting the last (still < 1) warm-up multiplier from
        ``group['lr']``, which is what the chainable ``MultiStepLR.get_lr``
        would do.
        """
        # base_lr * gamma ** (number of milestones passed), per param group.
        decayed = super()._get_closed_form_lr()
        if self.last_epoch >= self.num_iters:
            return decayed
        # Linear ramp: multiplier is `factor` at step 0 and approaches 1 at num_iters.
        alpha = self.last_epoch / self.num_iters
        warmup = (1 - self.factor) * alpha + self.factor
        return [lr * warmup for lr in decayed]
|