# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Optimizer."""

import torch

import timesformer.utils.lr_policy as lr_policy

def construct_optimizer(model, cfg):
    """
    Construct an SGD, Adam, or AdamW optimizer from the solver config.
    Details can be found in:
    Herbert Robbins, and Sutton Monro. "A stochastic approximation method."
    and
    Diederik P. Kingma, and Jimmy Ba.
    "Adam: A Method for Stochastic Optimization."

    Args:
        model (model): model whose parameters will be optimized.
        cfg (config): configs of the solver hyper-parameters, including the
            optimizing method, base learning rate, momentum, weight decay,
            dampening, etc.
    """
    # Batchnorm parameters.
    bn_params = []
    # Non-batchnorm parameters.
    non_bn_parameters = []
    for name, p in model.named_parameters():
        if "bn" in name:
            bn_params.append(p)
        else:
            non_bn_parameters.append(p)
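    # NOTE: parameters are routed by a simple substring match on "bn" in the
    # parameter name, so normalization layers whose attribute names do not
    # contain "bn" (e.g., a LayerNorm stored as `self.norm`) land in the
    # non-BN group and receive SOLVER.WEIGHT_DECAY.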
    # Apply different weight decay to Batchnorm and non-batchnorm parameters.
    # In Caffe2 classification codebase the weight decay for batchnorm is 0.0.
    # Having a different weight decay on batchnorm might cause a performance
    # drop.
    optim_params = [
        {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY},
        {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY},
    ]
    # Check all parameters will be passed into optimizer.
    assert len(list(model.parameters())) == len(non_bn_parameters) + len(
        bn_params
    ), "parameter size does not match: {} + {} != {}".format(
        len(non_bn_parameters), len(bn_params), len(list(model.parameters()))
    )
    if cfg.SOLVER.OPTIMIZING_METHOD == "sgd":
        return torch.optim.SGD(
            optim_params,
            lr=cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
            dampening=cfg.SOLVER.DAMPENING,
            nesterov=cfg.SOLVER.NESTEROV,
        )
    elif cfg.SOLVER.OPTIMIZING_METHOD == "adam":
        return torch.optim.Adam(
            optim_params,
            lr=cfg.SOLVER.BASE_LR,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    elif cfg.SOLVER.OPTIMIZING_METHOD == "adamw":
        return torch.optim.AdamW(
            optim_params,
            lr=cfg.SOLVER.BASE_LR,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    else:
        raise NotImplementedError(
            "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD)
        )
def get_epoch_lr(cur_epoch, cfg):
    """
    Retrieves the lr for the given epoch (as specified by the lr policy).

    Args:
        cur_epoch (float): the current training epoch (may be fractional when
            the lr is updated per iteration).
        cfg (config): configs of hyper-parameters, including the base learning
            rate and the lr policy settings.
    """
    return lr_policy.get_lr_at_epoch(cfg, cur_epoch)

def set_lr(optimizer, new_lr):
    """
    Sets the optimizer lr to the specified value.

    Args:
        optimizer (optim): the optimizer used to optimize the current network.
        new_lr (float): the new learning rate to set.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr