# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import functools
import inspect
import os
import os.path as osp
import time
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch.nn as nn
from torch.distributed.fsdp import (FullStateDictConfig,
                                    FullyShardedDataParallel,
                                    LocalStateDictConfig, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig,
    StateDictConfig)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

import mmengine
from legion.common.worker_pool import WorkerPool
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import get_rank, is_main_process
from mmengine.model import BaseDataPreprocessor, is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                            OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
                               PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist

from .distributed import DDPStrategy
from .utils import MetaTensorContext

FSDP = FullyShardedDataParallel
FSDP_CONFIGS = Registry('fsdp configs')
FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=FullStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalStateDictConfig)


def _save_checkpoint(params):
    # Helper executed by the checkpoint worker pool; ``params`` is a
    # ``(checkpoint, filename)`` tuple.
    from mmengine.runner.checkpoint import save_checkpoint
    checkpoint, filename = params
    save_checkpoint(checkpoint, filename)


@STRATEGIES.register_module()
class FSDPStrategy(DDPStrategy):
    """Support training model with FullyShardedDataParallel (FSDP).

    Keyword Args:
        model_wrapper (dict, optional): Config dict for model wrapper. The
            default configuration is:

            Examples:
                >>> model_wrapper = dict(
                >>>     type='MMFullyShardedDataParallel',
                >>>     use_orig_params=True,
                >>> )

            See more configurable arguments in
            :class:`MMFullyShardedDataParallel`. Defaults to None.
        skip_init_weights (bool, optional): Whether to skip initialization of
            weights. Defaults to False. This is useful when the parameters of
            a large model are loaded from a checkpoint, since skipping the
            initialization of weights can save a lot of time.
        state_dict_cfg (str or dict): Configuration for how to save and load
            the state dict of the model, optimizer, and scheduler.

            - "local": save and load the sharded state dict in all ranks.
            - "full": save and load the full state dict in rank 0.
            - `dict` object: save and load the state dict more flexibly. For
              example, you can first offload the state dict to the 'cpu' and
              then save it to the disk. This can help you to load the
              checkpoint in a non-gpu environment:

              Examples:
                  >>> state_dict_cfg=dict(
                  >>>     state_dict_type='FULL_STATE_DICT',
                  >>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
                  >>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
                  >>> )

              See more configurable arguments for ``state_dict_cfg``,
              ``state_dict_config``, and ``optim_state_dict_config`` in the
              `FSDP official api documents`_.
        train_micro_batch_size_per_gpu (int, optional): Micro train batch
            size per GPU. Defaults to None.
        kwargs (dict): Additional arguments passed to :class:`DDPStrategy`:

            - work_dir (str): The working directory to save checkpoints.
              The logs will be saved in the subdirectory of `work_dir` named
              :attr:`timestamp`. Defaults to 'work_dirs'.
            - experiment_name (str, optional): Name of current experiment. If
              not specified, timestamp will be used as :attr:`experiment_name`.
              Defaults to None.
            - env_kwargs (dict, optional): Environment config passed in
              :meth:`setup_env`. Defaults to None.
            - log_kwargs (dict, optional): Logger config passed in
              :meth:`build_logger`. Defaults to None.
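
    A minimal construction sketch (the config values below are illustrative):

    Examples:
        >>> strategy = FSDPStrategy(
        >>>     model_wrapper=dict(type='MMFullyShardedDataParallel',
        >>>                        use_orig_params=True),
        >>>     state_dict_cfg='full',
        >>>     skip_init_weights=True)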

    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
    """  # noqa: E501

    def __init__(self,
                 *,
                 model_wrapper: Optional[dict] = None,
                 skip_init_weights=False,
                 state_dict_cfg: Union[str, dict] = 'local',
                 train_micro_batch_size_per_gpu: Optional[int] = None,
                 **kwargs):
        super().__init__(model_wrapper=model_wrapper, **kwargs)
        self._init_state_dict_cfg(state_dict_cfg)
        if not isinstance(skip_init_weights, bool):
            raise TypeError('skip_init_weights must be a boolean, but got '
                            f'{type(skip_init_weights)}')
        self.skip_init_weights = skip_init_weights
        self.train_micro_batch_size_per_gpu = train_micro_batch_size_per_gpu

    def _wrap_model(self, model: nn.Module) -> None:
        """Wrap the model to :obj:`MMFullyShardedDataParallel` or other
        custom fully sharded data parallel module wrappers.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            FullyShardedDataParallel: ``MMFullyShardedDataParallel``
            or subclass of ``FullyShardedDataParallel``.
        """
        for module in model.modules():
            if isinstance(module, BaseDataPreprocessor):
                module.to(get_device())

        if is_model_wrapper(model):
            return

        if self.model_wrapper is None:
            self.model_wrapper = dict(type='MMFullyShardedDataParallel')

        default_args = dict(
            module=model,
            device_id=int(os.environ['LOCAL_RANK']),
            type='MMFullyShardedDataParallel')
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        model.set_state_dict_type(model, self.state_dict_type,
                                  self.state_dict_config,
                                  self.optim_state_dict_config)
        return model

    def _is_full_state_dict(self):
        """Whether to save and load the full state_dict in rank 0."""
        return self.state_dict_type == StateDictType.FULL_STATE_DICT

    # Lazily initialized so that each replica creates its own worker pool
    # only if it needs one.
    @functools.cached_property
    def worker_pool(self):
        worker_pool = WorkerPool(1, _save_checkpoint)
        worker_pool.start()
        return worker_pool

    def build_model(self, model: Union[nn.Module, dict]) -> nn.Module:
        """Build model.

        If ``skip_init_weights`` is True, the model will be built with empty
        weights. This means that :meth:`load_checkpoint` must be called to
        fill the weights before training.

        Args:
            model (nn.Module or dict): A ``nn.Module`` object or a dict to
                build an ``nn.Module`` object. If ``model`` is a ``nn.Module``
                object, it is returned as-is.

        Returns:
            nn.Module: Model built from ``model``.
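
        Examples:
            >>> # Illustrative sketch; ``MyLargeModel`` is a placeholder for
            >>> # any registered model config.
            >>> strategy = FSDPStrategy(skip_init_weights=True,
            >>>                         state_dict_cfg='full')
            >>> model = strategy.build_model(dict(type='MyLargeModel'))
            >>> # The weights are still empty here; ``load_checkpoint`` must
            >>> # be called to fill them before training.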
""" | |
if self.skip_init_weights: | |
if isinstance(model, dict): | |
# Accelerate initialization by skipping init weights | |
with MetaTensorContext(): | |
model = super().build_model(model) | |
model.to_empty(device='cpu') | |
else: | |
model = super().build_model(model) | |
# `id_to_name` will be used to convert the `optim_state_dict` of the | |
# raw optimizer to the `optim_state_dict` | |
# returned by `FSDP.optim_state_dict` in | |
# `StateDictType.FULL_STATE_DICT` mode. | |
self.id_to_name = dict() | |
for name, param in model.named_parameters(): | |
self.id_to_name[id(param)] = name | |
return model | |

    def save_checkpoint(self,
                        filename: str,
                        *,
                        save_optimizer: bool = True,
                        save_param_scheduler: bool = True,
                        extra_ckpt: Optional[dict] = None,
                        callback: Optional[Callable] = None) -> None:
        """Save checkpoint to the given ``filename``.

        If ``state_dict_type`` is `full`, the checkpoint will only be saved
        in rank 0. The structure of the saved checkpoint is the same as the
        one saved by ``DDPStrategy``.

        If ``state_dict_type`` is `local`, each rank will save its sharded
        state dict to a directory, which means the saved structure will look
        like this:

        .. code-block:: bash

            ── epoch_0.pth
              ├── rank0.pth
              ├── rank1.pth
              ├── ...
              └── rank8.pth

        Args:
            filename (str): Filename to save checkpoint.

        Keyword Args:
            save_optimizer (bool): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            save_param_scheduler (bool): Whether to save the param_scheduler
                to the checkpoint. Defaults to True.
            extra_ckpt (dict, optional): Extra checkpoint to save.
                Defaults to None.
            callback (callable, optional): Callback function to modify the
                checkpoint before saving it. Defaults to None.
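
        Examples:
            >>> # Illustrative sketch; the callback drops a hypothetical
            >>> # auxiliary head from the checkpoint before it is written.
            >>> def drop_aux_head(checkpoint):
            >>>     checkpoint['state_dict'].pop('aux_head.weight', None)
            >>> strategy.save_checkpoint(
            >>>     'work_dirs/epoch_1.pth',
            >>>     extra_ckpt=dict(info='baseline run'),
            >>>     callback=drop_aux_head)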
""" | |
from mmengine.runner.checkpoint import save_checkpoint | |
state_dict: dict = dict() | |
state_dict['state_dict'] = self.model_state_dict() | |
# save optimizer state dict | |
if save_optimizer and hasattr(self, 'optim_wrapper'): | |
state_dict['optimizer'] = self.optim_state_dict() | |
# save param scheduler state dict | |
if save_param_scheduler and hasattr(self, 'param_schedulers'): | |
state_dict['param_schedulers'] = self.scheduler_state_dict() | |
# save extra checkpoint passed by users | |
if extra_ckpt is None: | |
extra_ckpt = dict() | |
if 'meta' not in extra_ckpt: | |
extra_ckpt['meta'] = dict() | |
extra_ckpt['meta'].update( | |
seed=self.seed, | |
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), | |
mmengine=mmengine.__version__ + get_git_hash(), | |
) | |
state_dict.update(extra_ckpt) | |
# users can do some modification before saving checkpoint | |
if callback is not None: | |
callback(state_dict) | |
# In non-FULL_STATE_DICT model, FSDPStrategy will save checkpoint | |
# of different ranks in different files. | |
if not self._is_full_state_dict(): | |
rank = get_rank() | |
mkdir_or_exist(filename) | |
ckpt_name = f'rank{rank}.pth' | |
filename = osp.join(filename, ckpt_name) | |
# Don't use worker_pool due to use of ShardedTensor | |
_save_checkpoint((state_dict, filename)) | |
if is_main_process(): | |
if self._is_full_state_dict(): | |
self.worker_pool.put((state_dict, filename)) | |
else: | |
# Don't use worker_pool due to use of ShardedTensor | |
_save_checkpoint((state_dict, filename)) | |

    def model_state_dict(self) -> dict:
        """Get the model state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is `full`, the model state dict will be the
        same as that of the original unsharded model.

        If ``state_dict_type`` is `local` and ``use_orig_params`` is ``True``
        in ``model_wrapper``, the keys of the state dict will be the same as
        those of the original unsharded model, but the values will be the
        sharded ones.

        If ``state_dict_type`` is `local` and ``use_orig_params`` is
        ``False`` in ``model_wrapper``, the flattened and sharded state dict
        will be returned.
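
        Examples:
            >>> # Illustrative sketch: in ``full`` mode only rank 0 receives
            >>> # the gathered state dict, keyed like the unsharded model.
            >>> state_dict = strategy.model_state_dict()
            >>> if is_main_process():
            >>>     print(list(state_dict)[:3])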

        See more details in the `official api documents`_.

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """  # noqa: E501
        # The state dict type has already been set by
        # `FSDP.set_state_dict_type`, so the model state dict can be obtained
        # directly from `FSDP.state_dict`.
        return self.model.state_dict()

    def optim_state_dict(self) -> dict:
        """Get the optimizer state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is ``full``, the optimizer state dict can be
        loaded by the original unsharded optimizer.

        Otherwise, the optimizer state dict can only be loaded by an
        optimizer with sharded parameters.

        Note:
            Even in ``full`` mode the optimizer state dict is not identical
            to that of the original optimizer, although both can be loaded
            correctly.
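
        Examples:
            >>> # Illustrative sketch: on rank 0 the returned dict follows
            >>> # the usual optimizer layout.
            >>> optim_state = strategy.optim_state_dict()
            >>> if is_main_process():
            >>>     print(sorted(optim_state))  # ['param_groups', 'state']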

        See more details in the `official api documents`_.

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """  # noqa: E501
        return FSDP.optim_state_dict(self.model, self.optim_wrapper)

    def load_checkpoint(self, filename: str, **kwargs) -> dict:
        """Load checkpoint from the given ``filename``.

        Note:
            If ``state_dict_type`` is `local`, ``filename`` should be a
            directory containing ``rank{i}.pth`` files.

        Args:
            filename (str): Accepts a local filepath, URL,
                ``torchvision://xxx``, or ``open-mmlab://xxx``.

        Keyword Args:
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'cpu'.
            strict (bool): Whether to allow different params for the model
                and checkpoint.
            revise_keys (list): A list of customized keywords to modify the
                state_dict in checkpoint. Each item is a (pattern, replacement)
                pair of the regular expression operations. Defaults to strip
                the prefix 'module.' by [(r'^module\\.', '')].
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it. Defaults to None.
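
        Examples:
            >>> # Illustrative sketch: with ``state_dict_cfg='local'`` the
            >>> # checkpoint is a directory of per-rank shards and each rank
            >>> # loads its own file.
            >>> checkpoint = strategy.load_checkpoint('work_dirs/epoch_1.pth')
            >>> strategy.load_model_state_dict(checkpoint['state_dict'])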
""" | |
if self._is_full_state_dict(): | |
return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) | |
else: | |
rank = get_rank() | |
filename = osp.join(filename, f'rank{rank}.pth') | |
return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) | |

    def load_model_state_dict(
        self,
        state_dict: dict,
        *,
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
    ) -> None:  # type: ignore
        """Load model state from a dict.

        Warning:
            ``revise_keys`` is not supported yet.

        Args:
            state_dict (dict): Model state dict returned by
                :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type``
                is ``full``, ``state_dict`` could be the result of
                ``model.state_dict()``.
            strict (bool): Whether to load model state dict strictly.
                Defaults to False.
        """
        # The state dict should be loaded through `FSDP.load_state_dict`.
        self.model.load_state_dict(state_dict, strict=strict)

    def load_optim_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from a dict.

        Args:
            state_dict (dict): The optimizer state dict. If
                ``state_dict_type`` is ``full``, ``state_dict`` could be the
                result of ``optimizer.state_dict()``.
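
        Examples:
            >>> # Illustrative round trip between the two methods.
            >>> optim_state = strategy.optim_state_dict()
            >>> strategy.load_optim_state_dict(optim_state)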
""" | |
# optim_state_dict = FSDP.optim_state_dict_to_load(state_dict, self.model, self.optim_wrapper.optimizer) ## old fsdp | |
## correct order of args in latest pytorch | |
# https://github.com/pytorch/pytorch/blob/f3df7deab8953af76ff1723ed49094208057a834/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1847 | |
optim_state_dict = FSDP.optim_state_dict_to_load( | |
model=self.model, | |
optim=self.optim_wrapper.optimizer, | |
optim_state_dict=state_dict) | |
self.optim_wrapper.load_state_dict(optim_state_dict) | |

    def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
        """Allow ``state_dict_type`` and ``state_dict_config`` to be
        configured with a string."""
        if isinstance(state_dict_cfg, str):
            if state_dict_cfg == 'full':
                self.state_dict_type = StateDictType.FULL_STATE_DICT
                self.state_dict_config = FullStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
                self.optim_state_dict_config = FullOptimStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
            elif state_dict_cfg == 'local':
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
                self.state_dict_config = LocalStateDictConfig()
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            else:
                raise ValueError('FSDP only supports `full` and `local` '
                                 f'state_dict_type, but got {state_dict_cfg}')
        elif isinstance(state_dict_cfg, dict):
            if 'state_dict_type' not in state_dict_cfg:
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
            else:
                state_dict_type = state_dict_cfg['state_dict_type']
                if isinstance(state_dict_type, str):
                    self.state_dict_type = StateDictType[
                        state_dict_cfg['state_dict_type']]
                else:
                    self.state_dict_type = state_dict_type

            state_dict_config = state_dict_cfg.get('state_dict_config')
            if state_dict_config is None:
                self.state_dict_config = LocalStateDictConfig()
            elif isinstance(state_dict_config, dict):
                self.state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['state_dict_config'])
            else:
                self.state_dict_config = state_dict_config

            optim_state_dict_config = state_dict_cfg.get(
                'optim_state_dict_config')
            if optim_state_dict_config is None:
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            elif isinstance(optim_state_dict_config, dict):
                self.optim_state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['optim_state_dict_config'])
            else:
                self.optim_state_dict_config = optim_state_dict_config
        else:
            raise TypeError('state_dict_cfg should be a `str` or a `dict`, '
                            f'but got {type(state_dict_cfg)}')

        if not isinstance(self.state_dict_type, StateDictType):
            raise TypeError('state_dict_type must be StateDictType, but got '
                            f'{type(self.state_dict_type)}')

        if not isinstance(self.state_dict_config, StateDictConfig):
            raise TypeError('state_dict_config must be StateDictConfig, but '
                            f'got {type(self.state_dict_config)}')

        if not isinstance(self.optim_state_dict_config, OptimStateDictConfig):
            raise TypeError('optim_state_dict_config must be '
                            'OptimStateDictConfig, but got '
                            f'{type(self.optim_state_dict_config)}')

    def build_optim_wrapper(
        self,
        optim_wrapper: Union[Optimizer, OptimWrapper, dict],
        model: Optional[nn.Module] = None,
    ) -> BaseOptimWrapper:
        """Support sharding the optimizer state dict when given an already
        built optimizer or optim wrapper.

        See specific usage in :meth:`BaseStrategy.build_optim_wrapper`.
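
        Examples:
            >>> # Illustrative sketch: both a config dict and an already
            >>> # built OptimWrapper are accepted; values are placeholders.
            >>> optim_wrapper = strategy.build_optim_wrapper(
            >>>     dict(type='AmpOptimWrapper',
            >>>          optimizer=dict(type='AdamW', lr=1e-4)),
            >>>     model=model)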
""" | |
if isinstance(optim_wrapper, Optimizer): | |
optim_wrapper = OptimWrapper(optim_wrapper) | |
if isinstance(optim_wrapper, BaseOptimWrapper): | |
assert model is not None | |
# NOTE: The only difference is that FSDPStrategy will shard | |
# the the built OptimWrapper | |
optimizer = optim_wrapper.optimizer | |
param_groups = optimizer.param_groups | |
optim_state_dict = optimizer.state_dict() | |
assert not optim_state_dict['state'], ( | |
'Optimizer state_dict should be empty when giving an built ' | |
'optim_wrapper to FSDPStrategy') | |
# Align the state_dict with state_dict generated by | |
# FSDP.full_optim_state_dict | |
new_param_groups = [] | |
for group in param_groups: | |
new_group = { | |
key: value | |
for key, value in group.items() if key != 'param' | |
} | |
new_group['params'] = [ | |
self.id_to_name[id(param)] for param in group['params'] | |
] | |
new_param_groups.append(new_group) | |
optim_state_dict['param_groups'] = new_param_groups | |
defaults = { | |
k: v | |
for k, v in optimizer.defaults.items() if k != 'differentiable' | |
} | |
params_dict = {} | |
for k, v in model.named_parameters(): | |
if '_fsdp_wrapped_module' in k: | |
k = k.replace('_fsdp_wrapped_module.', '') | |
params_dict[k] = v | |
params = [] | |
for param_group in new_param_groups: | |
_params = [] | |
for param_name in param_group['params']: | |
if param_name not in params_dict: | |
raise RuntimeError( | |
'Failed to reconstruct the sharded optimizer. ' | |
'You can try to set `use_orig_params=True` in ' | |
'`model_wrapper`') | |
_params.append(params_dict[param_name]) | |
param_group = { | |
k: v | |
for k, v in param_group.items() if k != 'param' | |
} | |
param_group['params'] = _params | |
params.append(param_group) | |
new_optimizer = optimizer.__class__(params, **defaults) | |
# Force to load the converted optim_state_dict in full mode. | |
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): | |
optim_state_dict = FSDP.optim_state_dict_to_load( | |
optim_state_dict, model, new_optimizer) | |
new_optimizer.load_state_dict(optim_state_dict) | |
optim_wrapper.optimizer = new_optimizer | |
return optim_wrapper | |

        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
            assert model is not None
            # optimizer must be defined for single-optimizer training.
            optimizer = optim_wrapper.get('optimizer', None)
            optim_wrapper.setdefault('type', 'OptimWrapper')
            if optim_wrapper.get('type',
                                 'AmpOptimWrapper') in ('AmpOptimWrapper',
                                                        AmpOptimWrapper):
                optim_wrapper.setdefault('use_fsdp', True)

            # If optimizer is a built `Optimizer` instance, the optimizer
            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
            if isinstance(optimizer, Optimizer):
                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore

            # If `optimizer` is set or `constructor` is defined, the
            # optimizer wrapper will be built by the optimizer wrapper
            # constructor, so `build_optim_wrapper` should be called.
            if optimizer is not None or 'constructor' in optim_wrapper:
                return build_optim_wrapper(model, optim_wrapper)
            else:
                # If `optimizer` is not defined, it should be the case of
                # training with multiple optimizers. If `constructor` is not
                # defined either, each value of `optim_wrapper` must be an
                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
                # will not handle the case of training with multiple
                # optimizers. `build_optim_wrapper` will directly build the
                # `OptimWrapperDict` instance from `optim_wrapper`.
                optim_wrappers = OrderedDict()
                for name, optim in optim_wrapper.items():
                    if not isinstance(optim, OptimWrapper):
                        raise ValueError(
                            'each item must be an optimizer object when '
                            '"type" and "constructor" are not in '
                            f'optimizer, but got {name}={optim}')
                    optim_wrappers[name] = optim
                return OptimWrapperDict(**optim_wrappers)
        else:
            raise TypeError('optimizer wrapper should be an OptimWrapper '
                            f'object or dict, but got {optim_wrapper}')

    def _build_param_scheduler(
        self,
        scheduler: Union[_ParamScheduler, Dict, List],
        optim_wrapper: BaseOptimWrapper,
        default_args: dict,
    ) -> List[_ParamScheduler]:
        """Override this method to update the scheduler with the
        reconstructed sharded optimizer."""
        if not isinstance(scheduler, Sequence):
            schedulers = [scheduler]
        else:
            schedulers = scheduler

        max_epochs = default_args.pop('max_epochs', None)
        max_iters = default_args.pop('max_iters', None)

        param_schedulers = []
        for scheduler in schedulers:
            # Update an already built scheduler with the sharded optimizer by
            # re-instantiating it with its current state.
            if isinstance(scheduler, (_ParamScheduler, LRScheduler)):
                parameter_keys = inspect.signature(
                    scheduler.__class__).parameters.keys()
                kwargs = {
                    k: v
                    for k, v in scheduler.state_dict().items()
                    if k in parameter_keys
                }
                scheduler = scheduler.__class__(optim_wrapper, **kwargs)
                param_schedulers.append(scheduler)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)

                # Set default end
                if _scheduler.get('by_epoch', True):
                    if max_epochs is None:
                        raise ValueError(
                            'max_epochs must be specified in default_args')
                    default_end = max_epochs
                else:
                    if max_iters is None:
                        raise ValueError(
                            'max_iters must be specified in default_args')
                    default_end = max_iters
                if 'end' not in _scheduler:
                    _scheduler['end'] = default_end
                    self.logger.debug(
                        f'The `end` of {_scheduler["type"]} is not set. '
                        'Use the max epochs/iters of train loop as default.')
                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optim_wrapper, **default_args)))
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')

        return param_schedulers