Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
from contextlib import contextmanager | |
from typing import Dict, List, Optional | |
import torch | |
import torch.nn as nn | |
from torch.optim import Optimizer | |
from mmengine.logging import MessageHub, print_log | |
from mmengine.registry import OPTIM_WRAPPERS | |
from mmengine.utils.dl_utils import has_batch_norm | |
from .base import BaseOptimWrapper | |
# NOTE(review): `OPTIM_WRAPPERS` is imported in this file but this class is
# not registered with it here, so config-based builds such as
# `dict(type='OptimWrapper', ...)` will not resolve to this class through the
# registry. Confirm whether `@OPTIM_WRAPPERS.register_module()` was intended.
class OptimWrapper(BaseOptimWrapper):
    """Optimizer wrapper provides a common interface for updating parameters.

    Optimizer wrapper provides a unified interface for single precision
    training and automatic mixed precision training with different hardware.
    OptimWrapper encapsulates optimizer to provide simplified interfaces
    for commonly used training techniques such as gradient accumulation and
    gradient clipping. ``OptimWrapper`` implements the basic logic of
    gradient accumulation and gradient clipping based on
    ``torch.optim.Optimizer``. The subclasses only need to override some
    methods to implement the mixed precision training. See more information
    in :class:`AmpOptimWrapper`.

    In addition, :meth:`step` sets every non-finite gradient entry (``nan``,
    ``inf``, ``-inf``) of the trainable parameters to zero before gradient
    clipping and ``optimizer.step()`` run. No warning is emitted when this
    happens.

    Args:
        optimizer (Optimizer): Optimizer used to update model parameters.
        accumulative_counts (int): The number of iterations to accumulate
            gradients. The parameters will be updated per
            ``accumulative_counts``.
        clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
            the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
            :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
            non-empty dict, and the keys could be set as follows. The dict
            passed in is copied and is not modified.

            If the key ``type`` is not set, or ``type`` is "norm",
            the accepted keys are as follows:

            - max_norm (float or int): Max norm of the gradients.
            - norm_type (float or int): Type of the used p-norm. Can be
              ``'inf'`` for infinity norm.
            - error_if_nonfinite (bool): If True, an error is thrown if
              the total norm of the gradients from :attr:`parameters` is
              ``nan``, ``inf``, or ``-inf``. Defaults to False (will switch
              to True in the future)

            If the key ``type`` is set to "value", the accepted keys are as
            follows:

            - clip_value (float or int): maximum allowed value of the
              gradients. The gradients are clipped in the range
              ``(-clip_value, +clip_value)``.

    Note:
        If ``accumulative_counts`` is larger than 1, calling
        :meth:`update_params` inside the ``optim_context`` context avoids
        unnecessary gradient synchronization.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        The subclass should ensure that once :meth:`update_params` is called,
        ``_inner_count += 1`` is automatically performed.

    Examples:
        >>> # Config sample of OptimWrapper and enable clipping gradient by
        >>> # norm.
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     accumulative_counts=1,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> # Config sample of OptimWrapper and enable clipping gradient by
        >>> # value.
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     accumulative_counts=1,
        >>>     clip_grad=dict(type='value', clip_value=0.2))
        >>> # Use OptimWrapper to update model.
        >>> import torch.nn as nn
        >>> import torch
        >>> from torch.optim import SGD
        >>> from torch.utils.data import DataLoader
        >>> from mmengine.optim import OptimWrapper
        >>>
        >>> model = nn.Linear(1, 1)
        >>> dataset = torch.randn(10, 1, 1)
        >>> dataloader = DataLoader(dataset)
        >>> optimizer = SGD(model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(optimizer)
        >>>
        >>> for data in dataloader:
        >>>     loss = model(data).sum()
        >>>     optim_wrapper.update_params(loss)
        >>> # Enable gradient accumulation
        >>> from torch.nn.parallel import DistributedDataParallel
        >>> ddp_model = DistributedDataParallel(model)
        >>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(
        >>>     optimizer, accumulative_counts=3,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> optim_wrapper.initialize_count_status(
        >>>     ddp_model, 0, len(dataloader))
        >>> # If model is a subclass instance of DistributedDataParallel,
        >>> # `optim_context` context manager can avoid unnecessary gradient
        >>> # synchronization.
        >>> for data in dataloader:
        >>>     with optim_wrapper.optim_context(ddp_model):
        >>>         loss = ddp_model(data).sum()
        >>>         optim_wrapper.update_params(loss)
    """

    def __init__(self,
                 optimizer: Optimizer,
                 accumulative_counts: int = 1,
                 clip_grad: Optional[dict] = None):
        assert accumulative_counts > 0, (
            '_accumulative_counts at least greater than or equal to 1')
        self._accumulative_counts = accumulative_counts
        self.optimizer = optimizer

        if clip_grad is not None:
            # `clip_grad` must be a non-empty dict.
            assert isinstance(clip_grad, dict) and clip_grad, (
                'If `clip_grad` is not None, it should be a `dict` '
                'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
                'or clip_grad_value_`.')
            # Work on a shallow copy: popping `type` from the caller's dict
            # would make a reused config silently fall back to norm clipping
            # the next time it is used.
            clip_grad = clip_grad.copy()
            # A missing `type` defaults to norm clipping.
            clip_type = clip_grad.pop('type', 'norm')
            if clip_type == 'norm':
                self.clip_func = torch.nn.utils.clip_grad_norm_
                self.grad_name = 'grad_norm'
            elif clip_type == 'value':
                self.clip_func = torch.nn.utils.clip_grad_value_
                self.grad_name = 'grad_value'
            else:
                raise ValueError('type of clip_grad should be "norm" or '
                                 f'"value" but got {clip_type}')
            assert clip_grad, ('`clip_grad` should contain other arguments '
                               'besides `type`. The arguments should match '
                               'with the `torch.nn.utils.clip_grad_norm_` or '
                               'clip_grad_value_`')
        # Remaining kwargs forwarded to `self.clip_func`; None disables
        # clipping in `step`.
        self.clip_grad_kwargs = clip_grad
        # Used to update `grad_norm` / `grad_value` log message.
        self.message_hub = MessageHub.get_current_instance()
        self._inner_count = 0
        # `_max_counts` means the total number of parameter updates. It
        # ensures that the gradient of the last few iterations will not be
        # lost when the `_max_counts` is not divisible by
        # `accumulative_counts`. -1 means "not initialized".
        self._max_counts = -1
        # `_remainder_counts` is used for calculating the loss factor of the
        # last few iterations. While `_max_counts` is uninitialized, the loss
        # factor is always `_accumulative_counts`.
        self._remainder_counts = -1

        # When the optimizer has more than one param group, keep a
        # placeholder param-group-like dict made of a dummy tensor plus the
        # optimizer's default hyper-parameters. NOTE(review): it is not read
        # in this class; presumably callers use it to track the base
        # hyper-parameters (e.g. base learning rate) — verify.
        if len(optimizer.param_groups) > 1:
            self.base_param_settings = {
                'params': torch.tensor([0.0], dtype=torch.float)
            }
            self.base_param_settings.update(**self.optimizer.defaults)
        else:
            self.base_param_settings = None  # type: ignore

    def update_params(  # type: ignore
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Update parameters in :attr:`optimizer`.

        Scales ``loss`` with :meth:`scale_loss`, back-propagates it, and, when
        :meth:`should_update` allows it, steps the optimizer and zeroes the
        gradients.

        Args:
            loss (torch.Tensor): A tensor for back propagation.
            step_kwargs (dict): Arguments for optimizer.step.
                Defaults to None.
                New in version v0.4.0.
            zero_kwargs (dict): Arguments for optimizer.zero_grad.
                Defaults to None.
                New in version v0.4.0.
        """
        if step_kwargs is None:
            step_kwargs = {}
        if zero_kwargs is None:
            zero_kwargs = {}
        loss = self.scale_loss(loss)
        self.backward(loss)
        # Update parameters only if `self._inner_count` is divisible by
        # `self._accumulative_counts` or `self._inner_count` equals to
        # `self._max_counts`
        if self.should_update():
            self.step(**step_kwargs)
            self.zero_grad(**zero_kwargs)

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation.

        Provide unified ``backward`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic. For example, ``torch.cuda.amp`` require some extra
        operation on GradScaler during backward process.

        Note:
            If subclasses inherit from ``OptimWrapper`` override
            ``backward``, ``_inner_count +=1`` must be implemented.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        loss.backward(**kwargs)
        self._inner_count += 1

    def zero_grad(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.zero_grad``.

        Provide unified ``zero_grad`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.zero_grad`.
        """
        self.optimizer.zero_grad(**kwargs)

    def step(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.step``.

        Provide unified ``step`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic. For example, ``torch.cuda.amp`` require some extra
        operation on ``GradScaler`` during step process.

        Before updating, every non-finite gradient entry (``nan``, ``inf`` or
        ``-inf``) of a trainable parameter is set to zero in place. Then the
        gradients are clipped if :attr:`clip_grad_kwargs` is set, and finally
        the parameters are updated.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        # Zero out non-finite gradient entries so they never reach gradient
        # clipping or the optimizer update. `~isfinite` covers nan, +inf and
        # -inf with one mask. This silently masks divergence (nothing is
        # logged).
        for param in self._params_with_grad():
            grad = param.grad.data
            grad[~torch.isfinite(grad)] = 0
        if self.clip_grad_kwargs:
            self._clip_grad()
        self.optimizer.step(**kwargs)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """A Context for gradient accumulation and automatic mix precision
        training.

        If subclasses need to enable the context for mix precision training,
        e.g., :class:`AmpOptimWrapper`, the corresponding context should be
        enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32
        training, ``optim_context`` will only enable the context for
        blocking the unnecessary gradient synchronization during gradient
        accumulation.

        If ``model`` has a ``no_sync`` method (which blocks gradient
        synchronization) and :meth:`should_sync` returns False — i.e. the
        upcoming backward pass will not be followed by a parameter update —
        the body runs inside ``model.no_sync()``. Otherwise, this method
        enables an empty context.

        Args:
            model (nn.Module): The training model.
        """
        # During gradient accumulation process, the gradient synchronize
        # should only happen before updating parameters.
        if not self.should_sync() and hasattr(model, 'no_sync'):
            with model.no_sync():
                yield
        else:
            yield

    def _params_with_grad(self) -> List[torch.Tensor]:
        """Return trainable parameters of all param groups that have a grad."""
        return [
            param for param_group in self.optimizer.param_groups
            for param in param_group['params']
            if param.requires_grad and param.grad is not None
        ]

    def _clip_grad(self) -> None:
        """Clip the gradients of parameters.

        Calls ``self.clip_func`` with :attr:`clip_grad_kwargs` on every
        trainable parameter that has a gradient, and logs the returned value
        (if any) to the message hub as ``train/<grad_name>``.
        """
        params = self._params_with_grad()
        if params:
            grad = self.clip_func(params, **self.clip_grad_kwargs)
            # `torch.nn.utils.clip_grad_value_` will return None.
            if grad is not None:
                self.message_hub.update_scalar(f'train/{self.grad_name}',
                                               float(grad))

    def initialize_count_status(self, model: nn.Module, init_counts: int,
                                max_counts: int) -> None:
        """Initialize gradient accumulation related attributes.

        ``OptimWrapper`` can be used without calling
        ``initialize_count_status``. However, consider the case of
        ``len(dataloader) == 10`` and ``accumulative_counts == 3``. Since 10
        is not divisible by 3, the last iteration will not trigger
        ``optimizer.step()``, resulting in one less parameter updating.

        A warning is logged if ``init_counts`` is not divisible by
        ``accumulative_counts``, and another one if ``model`` has BatchNorm
        layers while gradient accumulation is enabled.

        Args:
            model (nn.Module): Training model
            init_counts (int): The initial value of the inner count.
            max_counts (int): The maximum value of the inner count.
        """
        self._inner_count = init_counts
        self._max_counts = max_counts
        if self._inner_count % self._accumulative_counts != 0:
            print_log(
                'Resumed iteration number is not divisible by '
                '`_accumulative_counts` of the optimizer wrapper, '
                'which means the gradient of some iterations is lost and the '
                'result may be influenced slightly.',
                logger='current',
                level=logging.WARNING)

        if has_batch_norm(model) and self._accumulative_counts > 1:
            print_log(
                'Gradient accumulative may slightly decrease '
                'performance because the model has BatchNorm layers.',
                logger='current',
                level=logging.WARNING)
        # Remainder of `_max_counts` divided by `_accumulative_counts`
        self._remainder_counts = self._max_counts % self._accumulative_counts

    def should_update(self) -> bool:
        """Decide whether the parameters should be updated at the current
        iteration.

        Called by :meth:`update_params` and check whether the optimizer
        wrapper should update parameters at current iteration.

        Returns:
            bool: Whether to update parameters.
        """
        return (self._inner_count % self._accumulative_counts == 0
                or self._inner_count == self._max_counts)

    def should_sync(self) -> bool:
        """Decide whether the automatic gradient synchronization should be
        allowed at the current iteration.

        It takes effect when gradient accumulation is used to skip
        synchronization at the iterations where the parameter is not updated.

        Since ``should_sync`` is called by :meth:`optim_context`, and it is
        called before :meth:`backward` which means ``self._inner_count += 1``
        has not happened yet. Therefore, ``self._inner_count += 1`` should be
        performed manually here.

        Returns:
            bool: Whether gradient synchronization should be allowed, i.e.
            whether the upcoming backward pass will be followed by a
            parameter update.
        """
        return ((self._inner_count + 1) % self._accumulative_counts == 0
                or (self._inner_count + 1) == self._max_counts)

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Get scaled loss according to ``_accumulative_counts``,
        ``_inner_count`` and max_counts.

        Args:
            loss (torch.Tensor): Original loss calculated by model.

        Returns:
            loss (torch.Tensor): Scaled loss.

        Raises:
            AssertionError: If the computed loss factor is not positive,
                e.g. when iterating past ``_max_counts``.
        """
        if self._accumulative_counts == 1:
            # update parameters without gradient accumulation. The gradient
            # should not be rescaled and `loss_factor=1`.
            loss_factor = 1
        elif self._max_counts == -1:
            loss_factor = self._accumulative_counts
        else:
            # if `self._accumulative_counts > 1`, the gradient needs to be
            # rescaled and accumulated. In most cases, `loss_factor` equals to
            # `self._accumulative_counts`. However, `self._max_counts` may not
            # be divisible by `self._accumulative_counts`, so the
            # `loss_scale` for the last few iterations needs to be
            # recalculated.
            if self._inner_count < self._max_counts - self._remainder_counts:
                loss_factor = self._accumulative_counts
            else:
                loss_factor = self._remainder_counts
            assert loss_factor > 0, (
                'loss_factor should be larger than zero! This error could '
                'happen when initialize_count_status is called with an '
                'erroneous `init_counts` or `max_counts`')

        loss = loss / loss_factor
        return loss

    # NOTE(review): exposed as a plain method here (call `inner_count()`);
    # confirm callers do not expect attribute/property access.
    def inner_count(self):
        """Get the number of updating parameters of optimizer wrapper."""
        return self._inner_count

    def __repr__(self):
        """Return the wrapper type, accumulation count and optimizer repr."""
        wrapper_info = (f'Type: {type(self).__name__}\n'
                        f'_accumulative_counts: {self._accumulative_counts}\n'
                        'optimizer: \n')
        optimizer_str = repr(self.optimizer) + '\n'
        return wrapper_info + optimizer_str