# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from .optimizer_wrapper import OptimWrapper

if is_npu_available():
    from torch.npu.amp import GradScaler
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on ``torch.cuda.amp``.

    ``AmpOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so it can be used in the same way as ``OptimWrapper``.

    Warnings:
        ``AmpOptimWrapper`` requires PyTorch >= 1.6.

    Args:
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See the detailed argument
            description at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscaler#torch.cuda.amp.GradScaler>`_.  # noqa: E501
            Defaults to ``dynamic``.

            - "dynamic": Initialize GradScaler without any arguments.
            - float: Initialize GradScaler with ``init_scale``.
            - dict: Initialize GradScaler with a more detailed configuration.

        dtype (str or torch.dtype, optional): The data type to autocast in
            amp. If a ``str`` is given, it will be converted to
            ``torch.dtype``. Valid ``str`` values are `'float16'`,
            `'bfloat16'`, `'float32'` and `'float64'`. If set to ``None``,
            the default data type will be used. Defaults to None.
            `New in version 0.6.1.`
        use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It should
            be enabled when using ``FullyShardedDataParallel``.
            Defaults to False.
            `New in version 0.8.0.`
        **kwargs: Keyword arguments passed to OptimWrapper.

    Warnings:
        The ``dtype`` argument is only available with PyTorch >= 1.10.0;
        with an older version it is ignored.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original ``max_iters`` should be multiplied by
        ``accumulative_counts``.
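
    Examples:
        A minimal usage sketch (``ToyModel`` and ``dataloader`` are
        placeholders, and a CUDA device is assumed):

        >>> import torch.optim as optim
        >>> model = ToyModel().cuda()
        >>> optim_wrapper = AmpOptimWrapper(
        ...     optimizer=optim.SGD(model.parameters(), lr=0.01),
        ...     loss_scale='dynamic', dtype='float16')
        >>> for inputs, targets in dataloader:
        ...     with optim_wrapper.optim_context(model):
        ...         loss = (model(inputs.cuda()) - targets.cuda()).pow(2).mean()
        ...     optim_wrapper.update_params(loss)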
""" | |
    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')

    def __init__(self,
                 loss_scale: Union[str, float, dict] = 'dynamic',
                 dtype: Optional[Union[str, torch.dtype]] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
        assert is_cuda_available() or is_npu_available() or is_mlu_available(
        ), ('``AmpOptimWrapper`` is only available for training '
            'on gpu, npu or mlu')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch >= 2.0.0 is required when setting `use_fsdp=True`')
        else:
            scaler_type = GradScaler

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', then dynamic
            # loss scaling will be used.
            self.loss_scaler = scaler_type()
        elif isinstance(loss_scale, float):
            # Static loss scaling
            self._scale_update_param = loss_scale
            self.loss_scaler = scaler_type(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # More specific configuration.
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

        # Convert a string value to torch.dtype.
        if isinstance(dtype, str):
            assert dtype in self.valid_dtypes, (
                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
            dtype = getattr(torch, dtype)
        assert dtype is None or isinstance(dtype, torch.dtype), (
            f'dtype should be None or instance of torch.dtype, got {dtype}')
        self.cast_dtype = dtype

    def backward(self, loss: torch.Tensor, **kwargs):
        """Perform gradient backpropagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
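
        Examples:
            A sketch, reusing ``optim_wrapper`` and ``loss`` from the
            class-level example; the loss is scaled internally before
            backpropagation:

            >>> optim_wrapper.backward(loss)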
""" | |
self.loss_scaler.scale(loss).backward(**kwargs) | |
self._inner_count += 1 | |

    def step(self, **kwargs):
        """Update parameters with :attr:`loss_scaler`.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        # Zero out non-finite gradient entries (NaN/Inf) before the scaler
        # unscales and steps, so a single bad batch cannot poison the
        # optimizer state. Note this also bypasses GradScaler's built-in
        # inf/nan check, which would otherwise skip the whole step.
        for group in self.optimizer.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    grad = param.grad
                    grad[torch.isnan(grad)] = 0
                    grad[torch.isinf(grad)] = 0

        if self.clip_grad_kwargs:
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer, **kwargs)
        self.loss_scaler.update(self._scale_update_param)

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
                :attr:`optimizer`.
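
        Examples:
            A sketch of the merged keys for a plain SGD optimizer (the
            ``optim_wrapper`` here is the one built in the class-level
            example; the exact keys depend on the wrapped optimizer):

            >>> sorted(optim_wrapper.state_dict().keys())
            ['loss_scaler', 'param_groups', 'state']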
""" | |
# save state_dict of loss_scaler | |
state_dict = super().state_dict() | |
state_dict['loss_scaler'] = self.loss_scaler.state_dict() | |
return state_dict | |

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If state_dict contains the "loss_scaler" key, :attr:`loss_scaler`
        will load the corresponding entry. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`loss_scaler`.
        """
        if 'loss_scaler' in state_dict:
            self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))

        if 'base_param_settings' in state_dict:
            self.base_param_settings = state_dict.pop('base_param_settings')

        # Load the state_dict of the wrapped optimizer.
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and for
        disabling gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
        """
        from mmengine.runner.amp import autocast
        with super().optim_context(model), autocast(dtype=self.cast_dtype):
            yield
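

if __name__ == '__main__':
    # A minimal smoke test, not part of the original module: a sketch that
    # assumes a CUDA device and an mmengine installation. It wraps a toy
    # linear model, runs one mixed-precision update, and prints the
    # resulting loss scale.
    if is_cuda_available():
        model = nn.Linear(4, 2).cuda()
        optim_wrapper = AmpOptimWrapper(
            optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
            loss_scale='dynamic')
        with optim_wrapper.optim_context(model):
            loss = model(torch.randn(8, 4, device='cuda')).mean()
        optim_wrapper.update_params(loss)
        print('loss scale after one step:',
              optim_wrapper.loss_scaler.get_scale())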