# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad")
class Adagrad(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return False
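

# Usage sketch (an assumption-labeled example, not part of the upstream file):
# shows how the argparse namespace consumed above might be assembled by hand,
# outside of fairseq's normal training entry point. The attribute names simply
# mirror what optimizer_config reads (args.lr is a list of floats,
# args.weight_decay a float); the toy model and values are purely illustrative.
# Because of the relative import above, this block only runs in package
# context, e.g. `python -m fairseq.optim.adagrad`.
if __name__ == "__main__":
    import argparse

    import torch.nn as nn

    parser = argparse.ArgumentParser()
    Adagrad.add_args(parser)  # registers --weight-decay / --wd
    args = parser.parse_args(["--weight-decay", "0.01"])
    args.lr = [0.1]  # fairseq passes learning rates as a list

    model = nn.Linear(8, 2)
    optimizer = Adagrad(args, model.parameters())
    print(optimizer.optimizer_config)  # expected: {'lr': 0.1, 'weight_decay': 0.01}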