Spaces:
Running
Running
import math | |
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR | |
class ConstantLRScheduler(_LRScheduler):
    """Scheduler that assigns one fixed learning rate to every parameter group."""

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        This is an implementation of constant learning rate scheduler.

        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``.
                Only forwarded to the base class when ``True`` (see the note in the body).

            init_lr: Learning rate returned for every parameter group at every step
        """
        # Assigned before the base constructor runs, because the base class
        # may already call ``get_lr()`` while initialising.
        self.init_lr = init_lr

        # ``verbose`` is deprecated in the PyTorch base scheduler (since 2.2)
        # and newer releases may not accept it at all. Forwarding it only when
        # explicitly enabled keeps the default path warning-free and working
        # on every version, while behaving the same on older ones where the
        # base default is already ``False``.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self) -> dict:
        """Return every instance attribute except the ``optimizer`` reference."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self) -> list:
        """Return ``init_lr`` once per optimizer parameter group.

        Raises:
            RuntimeError: If called outside of ``step()``; use ``get_last_lr()``
                to read the most recently computed learning rate instead.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        return [self.init_lr for _ in self.optimizer.param_groups]
class CosineAnnealingLRScheduler(_LRScheduler):
    """Linear warmup to ``max_lr`` followed by cosine annealing down to ``final_lr``."""

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        This is an implementation of cosine annealing learning rate scheduler.

        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``.
                Only forwarded to the base class when ``True`` (see the note in the body).

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup. A value ``<= 0`` disables warmup,
                so the schedule starts at ``max_lr``.

            cosine_steps: Number of steps for cosine annealing. A value ``<= 0`` skips the
                annealing phase, so the rate drops to ``final_lr`` right after warmup.
        """
        # Hyper-parameters are assigned before the base constructor runs,
        # because the base class may already call ``get_lr()`` while initialising.
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps

        # ``verbose`` is deprecated in the PyTorch base scheduler (since 2.2)
        # and newer releases may not accept it at all. Forwarding it only when
        # explicitly enabled keeps the default path warning-free and working
        # on every version.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self) -> dict:
        """Return every instance attribute except the ``optimizer`` reference."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self) -> list:
        """Compute the learning rate for the current step (``self.last_epoch``).

        * ``step <= warmup_steps``: linear ramp from ``init_lr`` to ``max_lr``.
        * afterwards: half-cosine from ``max_lr`` down to ``final_lr`` over
          ``cosine_steps`` steps, then held at ``final_lr``.

        Returns:
            The same learning rate repeated once per optimizer parameter group.

        Raises:
            RuntimeError: If called outside of ``step()``; use ``get_last_lr()``
                to read the most recently computed learning rate instead.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        # The ``> 0`` guard avoids a 0/0 division when warmup is disabled.
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif self.cosine_steps <= 0:
            # No annealing window configured: go straight to the final rate
            # instead of dividing by zero.
            lr = self.final_lr

        else:
            # Clamp so the cosine never goes past its minimum; without this the
            # rate would climb back towards ``max_lr`` once the annealing
            # window is over.
            elapsed = min(step_no - self.warmup_steps, self.cosine_steps)
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                * (1 + math.cos(math.pi * elapsed / self.cosine_steps))

        return [lr for _ in self.optimizer.param_groups]
class Esm2LRScheduler(_LRScheduler):
    """Linear warmup, plateau at ``max_lr``, then linear decay to ``final_lr``."""

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        This is an implementation of ESM2's learning rate scheduler.

        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``.
                Only forwarded to the base class when ``True`` (see the note in the body).

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup. A value ``<= 0`` disables warmup,
                so the schedule starts at ``max_lr``.

            start_decay_after_n_steps: Start decay after this number of steps

            end_decay_after_n_steps: End decay after this number of steps

            on_use: Whether to use this scheduler. If ``False``, the schedule is bypassed and
                ``get_lr`` returns the scheduler's ``base_lrs`` unchanged. Default: ``True``
        """
        # Hyper-parameters are assigned before the base constructor runs,
        # because the base class may already call ``get_lr()`` while initialising.
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use

        # ``verbose`` is deprecated in the PyTorch base scheduler (since 2.2)
        # and newer releases may not accept it at all. Forwarding it only when
        # explicitly enabled keeps the default path warning-free and working
        # on every version.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self) -> dict:
        """Return every instance attribute except the ``optimizer`` reference."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self) -> list:
        """Compute the learning rate for the current step (``self.last_epoch``).

        With ``on_use`` disabled the scheduler's ``base_lrs`` are returned
        as-is. Otherwise the schedule has four phases:

        * ``step <= warmup_steps``: linear ramp from ``init_lr`` to ``max_lr``;
        * ``step <= start_decay_after_n_steps``: constant ``max_lr``;
        * ``step <= end_decay_after_n_steps``: linear decay towards ``final_lr``;
        * afterwards: constant ``final_lr``.

        Returns:
            One learning rate per optimizer parameter group.

        Raises:
            RuntimeError: If called outside of ``step()``; use ``get_last_lr()``
                to read the most recently computed learning rate instead.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        if not self.on_use:
            return list(self.base_lrs)

        step_no = self.last_epoch

        # The ``> 0`` guard avoids a 0/0 division when warmup is disabled.
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif step_no <= self.start_decay_after_n_steps:
            lr = self.max_lr

        elif step_no <= self.end_decay_after_n_steps:
            # Reaching this branch implies end > start, so the window is non-zero.
            decay_window = self.end_decay_after_n_steps - self.start_decay_after_n_steps
            portion = (step_no - self.start_decay_after_n_steps) / decay_window
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)

        else:
            lr = self.final_lr

        return [lr for _ in self.optimizer.param_groups]