sadimanna's picture
Upload 20 files
d6def08
raw
history blame contribute delete
889 Bytes
from typing import List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
class WarmUpMultiStepLR(MultiStepLR):
    """Multi-step learning-rate decay preceded by a linear warm-up.

    While ``last_epoch < num_iters`` the multi-step learning rate is
    multiplied by a scale that grows linearly from ``factor`` (at step 0)
    towards 1. From step ``num_iters`` onwards the scale is exactly 1, so
    the schedule is the plain multi-step decay of the base learning rates.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        milestones: Step indices at which the rate is multiplied by ``gamma``.
        gamma: Multiplicative decay applied at each milestone.
        factor: Warm-up scale applied at step 0.
        num_iters: Number of warm-up steps.
        last_epoch: Index of the last step, forwarded to ``MultiStepLR``.
    """

    def __init__(self, optimizer: Optimizer, milestones: List[int], gamma: float = 0.1,
                 factor: float = 0.3333, num_iters: int = 500, last_epoch: int = -1):
        # These attributes are read by get_lr(), so they are assigned before
        # the base constructor runs (PyTorch schedulers take an initial step
        # during construction).
        self.factor = factor
        self.num_iters = num_iters
        super().__init__(optimizer, milestones, gamma, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the learning rate of every parameter group for the current step.

        The result is always derived from the closed-form multi-step schedule
        (base rate times ``gamma`` per milestone reached) and then scaled by
        the warm-up factor. Using the closed form in both phases matters:
        the chainable ``MultiStepLR.get_lr()`` reads the optimizer's current
        rate, which at the end of warm-up is still slightly below the base
        rate, so switching to it would leave the schedule permanently scaled
        down.
        """
        if self.last_epoch < self.num_iters:
            # Linear interpolation from ``self.factor`` (progress 0) to 1.
            progress = self.last_epoch / self.num_iters
            scale = (1 - self.factor) * progress + self.factor
        else:
            scale = 1.0
        return [lr * scale for lr in super()._get_closed_form_lr()]