# NeuralBody — lib/train/optimizer.py
# Author: pengsida (initial commit, rev 1ba539f)
import torch
from lib.utils.optimizer.radam import RAdam
# Maps the optimizer name from cfg.train.optim to its constructor.
_optimizer_factory = {
    'adam': torch.optim.Adam,
    'radam': RAdam,
    'sgd': torch.optim.SGD,
}
def make_optimizer(cfg, net, lr=None, weight_decay=None):
    """Build the training optimizer for ``net`` from the config.

    Args:
        cfg: config object providing ``cfg.train.lr``,
            ``cfg.train.weight_decay`` and ``cfg.train.optim``
            (one of 'adam', 'radam', 'sgd').
        net: module whose trainable parameters are optimized.
        lr: optional learning-rate override; defaults to ``cfg.train.lr``.
        weight_decay: optional weight-decay override; defaults to
            ``cfg.train.weight_decay``.

    Returns:
        A torch optimizer over every parameter with ``requires_grad``.
    """
    if lr is None:
        lr = cfg.train.lr
    if weight_decay is None:
        weight_decay = cfg.train.weight_decay

    # One parameter group per trainable tensor, each carrying its own
    # lr / weight_decay settings.
    params = [
        {"params": [p], "lr": lr, "weight_decay": weight_decay}
        for _, p in net.named_parameters()
        if p.requires_grad
    ]

    optim_name = cfg.train.optim
    build = _optimizer_factory[optim_name]
    # NOTE: the substring test routes both 'adam' and 'radam' to the
    # weight_decay branch; SGD instead gets a fixed momentum of 0.9.
    if 'adam' in optim_name:
        return build(params, lr, weight_decay=weight_decay)
    return build(params, lr, momentum=0.9)