"""
Optimizers and learning rate schedulers.
"""
from typing import Any, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler


def get_optimizer(optim: str, model: nn.Module, **kwargs: Any) -> Optimizer:
    """
    Return the training optimizer named by `optim` over the model's parameters.
    """
    if optim == 'sgd':
        return torch.optim.SGD(model.parameters(), **kwargs)
    elif optim == 'adam':
        return torch.optim.Adam(model.parameters(), **kwargs)
    elif optim in ['adamw', 'adamw_torch']:
        return torch.optim.AdamW(model.parameters(), **kwargs)
    elif optim == 'adamw_torch_fused':
        return torch.optim.AdamW(model.parameters(), **kwargs, fused=True)
    elif optim == 'adafactor':
        from transformers import Adafactor
        # Use the externally supplied learning rate instead of Adafactor's
        # relative step size schedule
        kwargs['relative_step'] = False
        return Adafactor(model.parameters(), **kwargs)
    else:
        raise NotImplementedError(f"{optim} optimizer not implemented.")
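# Example usage (a minimal sketch; the learning rate and weight decay values
# below are illustrative assumptions, not defaults defined in this module):
#
#     model = nn.Linear(16, 16)
#     optimizer = get_optimizer('adamw', model, lr=1e-3, weight_decay=0.01)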


def get_scheduler(lr_scheduler_type: str, optimizer: Optimizer,
                  **kwargs: Any) -> Optional[LRScheduler]:
    """
    Return the learning rate scheduler named by `lr_scheduler_type`,
    or None if the type is not recognized.
    """
    if lr_scheduler_type in ['plateau', 'reduce_lr_on_plateau']:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        return ReduceLROnPlateau(optimizer=optimizer, **kwargs)

    elif lr_scheduler_type == 'cosine_warmup':
        from transformers import get_cosine_schedule_with_warmup
        return get_cosine_schedule_with_warmup(optimizer=optimizer, **kwargs)

    elif lr_scheduler_type in ['linear_warmup', 'linear']:
        from transformers import get_linear_schedule_with_warmup
        return get_linear_schedule_with_warmup(optimizer=optimizer, **kwargs)

    else:
        return None
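# Example usage (a minimal sketch; the warmup and total step counts are
# illustrative assumptions, not values defined in this module):
#
#     scheduler = get_scheduler('cosine_warmup', optimizer,
#                               num_warmup_steps=100, num_training_steps=1000)
#     for step in range(1000):
#         ...
#         optimizer.step()
#         scheduler.step()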