from torch.optim.lr_scheduler import LRScheduler
class InterpolatingLogScheduler(LRScheduler):
    """Warm-up scheduler that grows the learning rate geometrically.

    Every parameter group receives the same learning rate; the optimizer's
    own base learning rates are ignored. The schedule, indexed by
    ``last_epoch``, is:

    - ``last_epoch <= 0``: ``min_lr``
    - ``0 < last_epoch < num_steps``: ``min_lr * q ** (last_epoch - 1)``
    - ``last_epoch >= num_steps``: ``max_lr``

    where ``q = (max_lr / min_lr) ** (1 / (num_steps - 1))``. Note that steps
    0 and 1 both yield ``min_lr`` (the exponent is ``last_epoch - 1``).
    """

    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        Args:
            - optimizer: pytorch optimizer
            - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
            - min_lr: float, the minimum learning rate (must be non-zero, it is used as a divisor)
            - max_lr: float, the maximum learning rate
            - last_epoch: int, index of the last step, forwarded to the base scheduler

        Usage:
            fc = nn.Linear(1,1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        # These attributes are assigned before delegating to the base
        # constructor because get_lr() reads them.
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr

        # Per-step growth factor, chosen so min_lr * q ** (num_steps - 1) == max_lr.
        # When num_steps == 1 the exponent 1 / (num_steps - 1) would divide by
        # zero; get_lr() never reads q in that configuration (it jumps straight
        # from min_lr to max_lr), so a neutral factor of 1.0 is stored instead.
        intervals = num_steps - 1
        ratio = max_lr / min_lr
        self.q = ratio ** (1 / intervals) if intervals != 0 else 1.0

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, one entry per param group."""
        step = self.last_epoch
        if step <= 0:
            lr = self.min_lr
        elif step < self.num_steps:
            lr = self.min_lr * (self.q ** (step - 1))
        else:
            lr = self.max_lr
        # The same value is applied to every parameter group.
        return [lr] * len(self.base_lrs)