# my-cool-model / lib/optim/scheduler.py
# Uploaded by crapthings using huggingface_hub (commit f7f604d).
from torch.optim.lr_scheduler import _LRScheduler
class PolyLr(_LRScheduler):
    """Polynomial-decay learning-rate scheduler with an optional warmup ramp.

    For every parameter group the learning rate at step ``t`` is::

        lr = (base_lr - minimum_lr) * (1 - t / max_iteration) ** gamma + minimum_lr

    ``t`` is clamped to ``[0, max_iteration]``, so the LR is held at
    ``minimum_lr`` once the schedule has finished. Without the clamp it would
    go negative, or complex for fractional ``gamma``.

    While ``t < warmup_iteration`` the LR is the smaller of the polynomial
    value above and a linear ramp from ``0.1 * base_lr`` (at ``t == 0``)
    towards ``base_lr``.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Exponent of the polynomial decay.
        max_iteration: Step at which the LR reaches ``minimum_lr``; must be > 0.
        minimum_lr: Floor learning rate reached at ``max_iteration``.
        warmup_iteration: Number of warmup steps; 0 disables warmup.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.

    Raises:
        ValueError: If ``max_iteration`` is not positive.
    """

    # Fraction of base_lr the warmup ramp starts from (alpha == 0).
    _WARMUP_START_FACTOR = 0.1

    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
        # A non-positive horizon would divide by zero or make the LR grow.
        if max_iteration <= 0:
            raise ValueError(f"max_iteration must be positive, got {max_iteration}")
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration
        # Schedule parameters are assigned before the base-class constructor,
        # which sets last_epoch/base_lrs and may already call get_lr().
        super().__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        """Return the polynomially decayed LR for ``base_lr`` at ``step``.

        ``step`` is clamped to ``[0, max_iteration]``, so any step past the end
        of the schedule yields exactly ``minimum_lr``.
        """
        step = min(max(step, 0), self.max_iteration)
        decay = (1 - step / self.max_iteration) ** self.gamma
        return (base_lr - self.minimum_lr) * decay + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        """Return the warmup LR, ramping linearly from 0.1*base_lr (alpha=0) to base_lr (alpha=1)."""
        start = self._WARMUP_START_FACTOR
        return base_lr * (start * (1 - alpha) + alpha)

    def get_lr(self):
        """Compute the LR of every parameter group for the current ``last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_iteration:
            alpha = step / self.warmup_iteration
            # During warmup, never exceed the value the decay curve would give.
            return [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, step)) for base_lr in self.base_lrs]
        return [self.poly_lr(base_lr, step) for base_lr in self.base_lrs]