lolcats / src /trainer /optim.py
ariG23498's picture
ariG23498 HF staff
chore: adding lolcats configs scrc and src
ae81e0f
raw
history blame
1.73 kB
"""
Optimizer and schedulers
"""
from typing import Any, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
def get_optimizer(optim: str, model: nn.Module, **kwargs: Any) -> Optimizer:
    """
    Return training optimizer

    Args:
        optim: Optimizer name. Supported values are 'sgd', 'adam',
            'adamw' / 'adamw_torch', 'adamw_torch_fused' and 'adafactor'.
        model: Module whose ``parameters()`` are handed to the optimizer.
        **kwargs: Forwarded unchanged to the optimizer constructor
            (e.g. ``lr``, ``weight_decay``), except for the per-option
            overrides noted below.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: If ``optim`` is not one of the supported names.
    """
    if optim == 'sgd':
        return torch.optim.SGD(model.parameters(), **kwargs)
    elif optim == 'adam':
        return torch.optim.Adam(model.parameters(), **kwargs)
    elif optim in ['adamw', 'adamw_torch']:
        return torch.optim.AdamW(model.parameters(), **kwargs)
    elif optim == 'adamw_torch_fused':
        # Merge rather than pass ``fused=True`` alongside ``**kwargs``: a
        # config that already contains a ``fused`` key would otherwise raise
        # "got multiple values for keyword argument 'fused'".
        return torch.optim.AdamW(model.parameters(), **{**kwargs, 'fused': True})
    elif optim == 'adafactor':
        # Imported lazily so transformers is only needed for this option.
        from transformers import Adafactor
        # relative_step is always disabled for now, overriding any value
        # supplied in kwargs.
        return Adafactor(model.parameters(), **{**kwargs, 'relative_step': False})
    else:
        raise NotImplementedError(f"{optim} optimizer not implemented sorry.")
def get_scheduler(lr_scheduler_type: str, optimizer: Optimizer,
                  **kwargs: Any) -> Optional[LRScheduler]:
    """
    Return learning rate scheduler

    Args:
        lr_scheduler_type: Scheduler name. Supported values are 'plateau' /
            'reduce_lr_on_plateau', 'cosine_warmup' and 'linear_warmup' /
            'linear'.
        optimizer: Optimizer whose learning rate the scheduler controls.
        **kwargs: Forwarded unchanged to the scheduler constructor or factory
            (for the warmup schedules, e.g. ``num_warmup_steps`` and
            ``num_training_steps``).

    Returns:
        The scheduler. Any unrecognized ``lr_scheduler_type`` returns
        ``None`` rather than raising, so callers must handle the
        "no scheduler" case.
    """
    if lr_scheduler_type in ['plateau', 'reduce_lr_on_plateau']:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        return ReduceLROnPlateau(optimizer=optimizer, **kwargs)
    elif lr_scheduler_type == 'cosine_warmup':
        # Imported lazily so transformers is only needed for warmup schedules.
        from transformers import get_cosine_schedule_with_warmup
        return get_cosine_schedule_with_warmup(optimizer=optimizer, **kwargs)
    elif lr_scheduler_type in ['linear_warmup', 'linear']:
        from transformers import get_linear_schedule_with_warmup
        return get_linear_schedule_with_warmup(optimizer=optimizer, **kwargs)
    else:
        # Intentionally not an error: unknown / "none" types mean no scheduler.
        return None