import math

from torch.optim.lr_scheduler import _LRScheduler


class ConstantLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        An implementation of a constant learning rate scheduler: the learning rate is held at ``init_lr`` for every step.

        Args:
            optimizer: Optimizer
            last_epoch: The index of the last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
            init_lr: Initial learning rate
        """
        self.init_lr = init_lr
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer itself so the state dict stays serializable
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        return [self.init_lr for group in self.optimizer.param_groups]
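

# Usage sketch for ConstantLRScheduler. The toy ``torch.nn.Linear`` model and
# ``torch.optim.AdamW`` optimizer below are illustrative assumptions; any model and
# optimizer work the same way.
def _demo_constant_lr():
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    scheduler = ConstantLRScheduler(optimizer, init_lr=4e-4)

    for _ in range(3):
        optimizer.step()
        scheduler.step()
        print(scheduler.get_last_lr())  # always [0.0004]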


class CosineAnnealingLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        An implementation of a cosine annealing learning rate scheduler with linear warmup.

        Args:
            optimizer: Optimizer
            last_epoch: The index of the last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
            init_lr: Initial learning rate
            max_lr: Maximum learning rate after warmup
            final_lr: Final learning rate after decay
            warmup_steps: Number of steps for warmup
            cosine_steps: Number of steps for cosine annealing
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer itself so the state dict stays serializable
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_steps:
            # Linear warmup from init_lr to max_lr
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        else:
            # Cosine annealing from max_lr down to final_lr over cosine_steps
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                 * (1 + math.cos(math.pi * (step_no - self.warmup_steps) / self.cosine_steps))

        return [lr for group in self.optimizer.param_groups]
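

# Usage sketch for CosineAnnealingLRScheduler. The toy model, AdamW optimizer and the
# shortened step counts are illustrative assumptions. The LR ramps linearly from
# ``init_lr`` to ``max_lr`` over ``warmup_steps``, then follows a half cosine from
# ``max_lr`` down to ``final_lr`` over ``cosine_steps``. The cosine phase is not
# clamped, so stepping past ``warmup_steps + cosine_steps`` lets the LR rise again.
def _demo_cosine_annealing_lr():
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    scheduler = CosineAnnealingLRScheduler(
        optimizer, init_lr=0., max_lr=4e-4, final_lr=4e-5,
        warmup_steps=10, cosine_steps=50,
    )

    lrs = []
    for _ in range(60):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    print(max(lrs))  # ~4e-4, reached at the end of warmup
    print(lrs[-1])   # ~4e-5, reached after warmup_steps + cosine_steps updates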


class Esm2LRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        An implementation of ESM2's learning rate scheduler.

        Args:
            optimizer: Optimizer
            last_epoch: The index of the last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
            init_lr: Initial learning rate
            max_lr: Maximum learning rate after warmup
            final_lr: Final learning rate after decay
            warmup_steps: Number of steps for warmup
            start_decay_after_n_steps: Start decay after this number of steps
            end_decay_after_n_steps: End decay after this number of steps
            on_use: Whether to use this scheduler. If ``False``, the scheduler does not change the learning rate
                and simply returns the optimizer's base learning rates. Default: ``True``
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer itself so the state dict stays serializable
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch
        if not self.on_use:
            # Scheduler disabled: keep the optimizer's base learning rates
            return [base_lr for base_lr in self.base_lrs]

        if step_no <= self.warmup_steps:
            # Linear warmup from init_lr to max_lr
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif step_no <= self.start_decay_after_n_steps:
            # Hold at max_lr until decay starts
            lr = self.max_lr

        elif step_no <= self.end_decay_after_n_steps:
            # Linear decay from max_lr to final_lr
            portion = (step_no - self.start_decay_after_n_steps) / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)

        else:
            # Hold at final_lr after decay ends
            lr = self.final_lr

        return [lr for group in self.optimizer.param_groups]
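

# Usage sketch for Esm2LRScheduler. The toy model, AdamW optimizer and the shortened
# step counts are illustrative assumptions. The schedule: linear warmup from ``init_lr``
# to ``max_lr`` over ``warmup_steps``; hold at ``max_lr`` until
# ``start_decay_after_n_steps``; linear decay to ``final_lr`` by
# ``end_decay_after_n_steps``; hold at ``final_lr`` afterwards. With ``on_use=False``
# the optimizer's base learning rate is returned unchanged.
def _demo_esm2_lr():
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    scheduler = Esm2LRScheduler(
        optimizer, init_lr=0., max_lr=4e-4, final_lr=4e-5,
        warmup_steps=10, start_decay_after_n_steps=20, end_decay_after_n_steps=40,
    )

    for step in range(50):
        optimizer.step()
        scheduler.step()
        if step + 1 in (10, 20, 30, 40, 50):
            print(step + 1, scheduler.get_last_lr()[0])
    # Expected: 10 -> 4e-4 (end of warmup), 20 -> 4e-4 (plateau),
    #           30 -> 2.2e-4 (halfway through decay), 40 and 50 -> ~4e-5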