from torch.optim.lr_scheduler import _LRScheduler


class PolynomialLR(_LRScheduler):
    """Polynomial learning-rate decay with an optional linear warmup.

    For iteration ``t`` (``self.last_epoch``) and a base learning rate ``lr``
    the scheduled value is ``min_lr + (lr - min_lr) * coef`` where

    * ``coef = (t / iter_warmup) * (1 - iter_warmup / iter_max) ** power``
      while ``t < iter_warmup`` -- a linear ramp starting at ``min_lr`` that
      meets the decay curve at ``t == iter_warmup``;
    * ``coef = (1 - t / iter_max) ** power`` afterwards.

    The decay base ``1 - t / iter_max`` is clamped at zero so the result is
    always a real, non-negative coefficient. The learning rate is only
    recomputed on iterations that are multiples of ``step_size`` and not
    past ``iter_max`` (plus iteration 0 when warmup is enabled); on every
    other iteration the optimizer's current value is kept, so after
    ``iter_max`` it stays at the last computed value.

    Args:
        optimizer: Wrapped torch optimizer.
        step_size: Recompute the learning rate every ``step_size``
            iterations. Must be positive.
        iter_warmup: Number of warmup iterations (converted with ``int``);
            ``0`` disables warmup.
        iter_max: Iteration at which the decay reaches ``min_lr``
            (converted with ``int``).
        power: Exponent of the polynomial decay.
        min_lr: Value the schedule decays towards.
        last_epoch: Index of the last iteration; ``-1`` for a fresh start.

    Raises:
        ValueError: If ``step_size`` is not positive.
    """

    def __init__(
        self,
        optimizer,
        step_size,
        iter_warmup,
        iter_max,
        power,
        min_lr=0,
        last_epoch=-1,
    ):
        if step_size <= 0:
            # Fail at construction instead of with a ZeroDivisionError on
            # the first scheduler step (get_lr computes last_epoch % step_size).
            raise ValueError(f"step_size must be positive, got {step_size!r}")
        # Attributes must exist before the base __init__, which performs an
        # initial step and therefore calls get_lr().
        self.step_size = step_size
        self.iter_warmup = int(iter_warmup)
        self.iter_max = int(iter_max)
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def _decay_factor(self, iteration):
        """Return ``(1 - iteration / iter_max) ** power`` with the base clamped at 0."""
        if self.iter_max <= 0:
            # Degenerate schedule: treat it as fully decayed rather than
            # dividing by zero.
            remaining = 0.0
        else:
            # A negative base with a fractional power would yield a complex
            # number (e.g. when iteration > iter_max or iter_warmup > iter_max).
            remaining = max(0.0, 1.0 - iteration / self.iter_max)
        return remaining ** self.power

    def polynomial_decay(self, lr):
        """Return the scheduled learning rate for base rate ``lr`` at ``self.last_epoch``."""
        iter_cur = float(self.last_epoch)
        if iter_cur < self.iter_warmup:
            # Linear ramp from min_lr that reaches the decay curve exactly
            # at iter_warmup, so the schedule is continuous there.
            coef = iter_cur / self.iter_warmup * self._decay_factor(self.iter_warmup)
        else:
            coef = self._decay_factor(iter_cur)
        return (lr - self.min_lr) * coef + self.min_lr

    def get_lr(self):
        """Return the learning rate for each param group at ``self.last_epoch``."""
        epoch = self.last_epoch
        # With warmup enabled, the very first step must already follow the
        # ramp; keeping the initial (full) base lr there would defeat warmup.
        warmup_start = epoch == 0 and self.iter_warmup > 0
        on_schedule = (
            epoch != 0
            and epoch % self.step_size == 0
            and epoch <= self.iter_max
        )
        if warmup_start or on_schedule:
            return [self.polynomial_decay(lr) for lr in self.base_lrs]
        # Between updates (and after iter_max) keep the current values.
        return [group["lr"] for group in self.optimizer.param_groups]

    def step_update(self, num_updates):
        """Advance the schedule by exactly one step.

        ``num_updates`` is accepted so callers that report an update count can
        drive this scheduler, but its value is not used: every call performs a
        single ``self.step()``.
        """
        self.step()