Spaces:
Configuration error
Configuration error
import torch | |
from lib.utils.optimizer.radam import RAdam | |
# Registry of optimizer constructors, keyed by the name expected in
# ``cfg.train.optim``. ``make_optimizer`` looks the class up here.
_optimizer_factory = dict(
    adam=torch.optim.Adam,
    radam=RAdam,
    sgd=torch.optim.SGD,
)
def make_optimizer(cfg, net, lr=None, weight_decay=None):
    """Build the optimizer selected by ``cfg.train.optim`` for ``net``.

    Every parameter of ``net`` with ``requires_grad`` set is placed in its
    own parameter group carrying the resolved learning rate and weight
    decay; parameters that do not require gradients are left out.

    Args:
        cfg: Configuration object; ``cfg.train.optim`` names the optimizer,
            and ``cfg.train.lr`` / ``cfg.train.weight_decay`` supply the
            defaults when the matching argument is ``None``.
        net: Object exposing ``named_parameters()``.
        lr: Learning rate override; ``None`` falls back to ``cfg.train.lr``.
        weight_decay: Weight decay override; ``None`` falls back to
            ``cfg.train.weight_decay``.

    Returns:
        The optimizer instance created from ``_optimizer_factory``.

    Raises:
        KeyError: If ``cfg.train.optim`` is not a key of
            ``_optimizer_factory``.
    """
    base_lr = lr if lr is not None else cfg.train.lr
    decay = weight_decay if weight_decay is not None else cfg.train.weight_decay

    # One group per trainable tensor so each can carry its own options.
    param_groups = [
        {"params": [tensor], "lr": base_lr, "weight_decay": decay}
        for _, tensor in net.named_parameters()
        if tensor.requires_grad
    ]

    optim_name = cfg.train.optim
    # Substring test: any name containing 'adam' (e.g. 'radam') takes the
    # weight-decay branch; every other name is constructed with momentum.
    if 'adam' in optim_name:
        extra_kwargs = {"weight_decay": decay}
    else:
        # Weight decay is not passed here, but each param group above
        # already carries its own "weight_decay" entry.
        extra_kwargs = {"momentum": 0.9}

    return _optimizer_factory[optim_name](param_groups, base_lr, **extra_kwargs)