|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
|
|
|
class PolyLr(_LRScheduler):
    """Polynomial learning-rate decay with an optional linear warmup.

    After warmup, each parameter group's learning rate follows::

        lr(t) = (base_lr - minimum_lr) * (1 - t / max_iteration) ** gamma + minimum_lr

    where ``t`` is the scheduler's step counter (``self.last_epoch``). Once
    ``t`` reaches ``max_iteration`` the rate is held at ``minimum_lr``.

    During the first ``warmup_iteration`` steps the rate ramps linearly from
    ``0.1 * base_lr`` towards ``base_lr``, and is never allowed to exceed the
    polynomial schedule above.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Exponent of the polynomial decay.
        max_iteration: Step count at which the decay reaches ``minimum_lr``;
            must be positive (it is used as a divisor).
        minimum_lr: Floor the learning rate decays to.
        warmup_iteration: Number of warmup steps (0 disables warmup).
        last_epoch: Index of the last step, forwarded unchanged to the base
            scheduler.
    """

    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration
        # The hyper-parameters above must exist before the base constructor
        # runs, since it may compute an initial learning rate via get_lr().
        # The base class itself sets last_epoch and base_lrs.
        super().__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        """Return the polynomially decayed learning rate for ``step``.

        ``step`` is clamped to ``max_iteration`` so that stepping past the end
        of the schedule keeps the rate at ``minimum_lr``. Without the clamp the
        base ``1 - step / max_iteration`` becomes negative, which produces a
        complex number for fractional ``gamma`` and a rate that rises again
        (or falls below ``minimum_lr``) for integer ``gamma``.
        """
        progress = min(step / self.max_iteration, 1.0)
        return (base_lr - self.minimum_lr) * ((1 - progress) ** self.gamma) + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        """Linearly interpolate from ``0.1 * base_lr`` (alpha=0) to ``base_lr`` (alpha=1)."""
        return base_lr * (1 / 10.0 * (1 - alpha) + alpha)

    def get_lr(self):
        """Compute the learning rate of every parameter group for the current step."""
        step = self.last_epoch
        # With warmup_iteration == 0 this branch is never taken for step >= 0,
        # so the division below cannot hit zero.
        if step < self.warmup_iteration:
            alpha = step / self.warmup_iteration
            # Cap the warmup ramp by the decay schedule so warmup never
            # pushes the rate above where the poly curve would be.
            return [
                min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, step))
                for base_lr in self.base_lrs
            ]
        return [self.poly_lr(base_lr, step) for base_lr in self.base_lrs]