Spaces:
Sleeping
Sleeping
from typing import Type | |
from torch import optim | |
from .functional_adadelta import _FunctionalAdadelta | |
from .functional_adagrad import _FunctionalAdagrad | |
from .functional_adam import _FunctionalAdam | |
from .functional_adamax import _FunctionalAdamax | |
from .functional_adamw import _FunctionalAdamW | |
from .functional_rmsprop import _FunctionalRMSprop | |
from .functional_rprop import _FunctionalRprop | |
from .functional_sgd import _FunctionalSGD | |
# Maps a user-supplied optimizer class to the functional optimizer class
# already defined inside the distributed.optim package. The indirection keeps
# the functional optimizers hidden from users while still exposing the same
# API. Additional entries can be added via ``register_functional_optim``.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
def register_functional_optim(key, optim):
    """
    Insert a new functional optimizer into ``functional_optim_map``.

    Both ``key`` and ``optim`` are user defined and need not be subclasses of
    :class:`torch.optim.Optimizer` (e.g. for custom optimizers). If ``key`` is
    already present in the map, the existing entry is kept and the call is a
    silent no-op.

    Args:
        key: hashable lookup key under which the functional optimizer is
            stored (it is used as a dictionary key).
        optim: functional optimizer to store for ``key``. Note that inside
            this function the name shadows the module-level ``optim`` import.

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # First registration wins: never overwrite an existing mapping.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim
def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """
    Build the functional counterpart of ``optim_cls``.

    ``optim_cls`` is looked up in ``functional_optim_map``; the mapped class
    is then instantiated through ``_create_functional_optim`` with ``*args``
    and ``**kwargs`` forwarded unchanged.

    Args:
        optim_cls: key to look up in ``functional_optim_map``.
        *args: positional arguments forwarded to the functional optimizer.
        **kwargs: keyword arguments forwarded to the functional optimizer.

    Returns:
        Whatever ``_create_functional_optim`` returns for the mapped class.

    Raises:
        ValueError: if ``optim_cls`` has no entry in ``functional_optim_map``
            (chained from the underlying ``KeyError``).
    """
    # Keep the try body to the lookup only, so a KeyError raised while
    # constructing the optimizer is never mistaken for a missing mapping.
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs): | |
return functional_optim_cls( | |
[], | |
*args, | |
**kwargs, | |
_allow_empty_param_list=True, | |
) | |