# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from typing import Union

import torch
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

if is_npu_available():
    from torch.npu.amp import GradScaler
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
"""A subclass of :class:`OptimWrapper` that supports automatic mixed
precision training based on torch.cuda.amp.
``AmpOptimWrapper`` provides a unified interface with
``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way
as ``OptimWrapper``.
Warnings:
``AmpOptimWrapper`` requires PyTorch >= 1.6.
Args:
loss_scale (float or str or dict): The initial configuration of
`torch.cuda.amp.GradScaler`. See more specific arguments
introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501
Defaults to ``dynamic``.
- "dynamic": Initialize GradScale without any arguments.
- float: Initialize GradScaler with ``init_scale``.
- dict: Initialize GradScaler with more detail configuration.
dtype (str or torch.dtype, optional): The data type to autocast in amp.
If a ``str`` is given, it will be converted to ``torch.dtype``.
Valid ``str`` format are `'float16'`, `'bfloat16'`, `'float32'` and
`'float64'`. If set to ``None``, the default data type will be used.
Defaults to None.
`New in version 0.6.1.`
use_fsdp (bool): Using ``ShardedGradScaler`` when it is True. It should
be enabled when using ``FullyShardedDataParallel``.
Defaults to False.
`New in version 0.8.0.`
**kwargs: Keyword arguments passed to OptimWrapper.
Warnings:
``dtype`` argument is only available with PyTorch version >= 1.10.0. If
you use PyTorch of an older version, it will be ignored.
Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.
"""
    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')

    def __init__(self,
                 loss_scale: str = 'dynamic',
                 dtype: Union[str, torch.dtype] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
        assert is_cuda_available() or is_npu_available() or is_mlu_available(
        ), ('``AmpOptimWrapper`` is only available for training on GPU, NPU '
            'or MLU devices')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch>=2.0.0 is required when setting `use_fsdp=True`')
        else:
            scaler_type = GradScaler

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', then dynamic
            # loss scaling will be used.
            self.loss_scaler = scaler_type()
        elif isinstance(loss_scale, float):
            # Static loss scaling.
            self._scale_update_param = loss_scale
            self.loss_scaler = scaler_type(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # More specific configuration.
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

        # Convert a string value to torch.dtype.
        if isinstance(dtype, str):
            assert dtype in self.valid_dtypes, (
                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
            dtype = getattr(torch, dtype)
        assert dtype is None or isinstance(dtype, torch.dtype), (
            f'dtype should be None or instance of torch.dtype, got {dtype}')
        self.cast_dtype = dtype

    def backward(self, loss: torch.Tensor, **kwargs):
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        self.loss_scaler.scale(loss).backward(**kwargs)
        self._inner_count += 1

    def step(self, **kwargs):
        """Update parameters with :attr:`loss_scaler`.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        # Zero out NaN/Inf entries in the (still scaled) gradients before
        # unscaling and stepping, so invalid values never reach the update.
        params = [
            p for pg in self.optimizer.param_groups for p in pg['params']
        ]
        for p in params:
            if hasattr(p, 'grad') and p.grad is not None:
                p.grad.data[torch.isnan(p.grad.data)] = 0
                p.grad.data[torch.isinf(p.grad.data)] = 0

        if self.clip_grad_kwargs:
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer, **kwargs)
        self.loss_scaler.update(self._scale_update_param)

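    # Hedged note: gradient clipping is configured through the base
    # ``OptimWrapper``, e.g. ``clip_grad=dict(max_norm=1.0)`` in the
    # constructor (an example value, not a default of this module). When it
    # is set, ``step`` above unscales the gradients before clipping them.
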
    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
            :attr:`optimizer`.
        """
        # Save the state_dict of loss_scaler alongside the optimizer's.
        state_dict = super().state_dict()
        state_dict['loss_scaler'] = self.loss_scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If state_dict contains the "loss_scaler" key, :attr:`loss_scaler`
        will load the corresponding entries. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`loss_scaler`.
        """
        if 'loss_scaler' in state_dict:
            self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))

        if 'base_param_settings' in state_dict:
            self.base_param_settings = state_dict.pop('base_param_settings')

        # Load the state_dict of the wrapped optimizer.
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enable the context for mixed precision training and for disabling
        gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
        """
        from mmengine.runner.amp import autocast
        with super().optim_context(model), autocast(dtype=self.cast_dtype):
            yield
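

# A minimal usage sketch, kept behind a ``__main__`` guard so importing this
# module is unaffected. The toy model, SGD optimizer and tensor shapes are
# assumptions for illustration only; ``update_params`` comes from the base
# ``OptimWrapper`` interface. A CUDA device is required, matching the
# availability check in ``__init__``.
if __name__ == '__main__':
    from torch.optim import SGD

    if is_cuda_available():
        model = nn.Linear(4, 2).cuda()
        wrapper = AmpOptimWrapper(
            optimizer=SGD(model.parameters(), lr=0.01),
            loss_scale='dynamic',
            dtype='float16')
        # Run the forward pass under the autocast context, then let the
        # wrapper scale the loss, backpropagate and step the optimizer.
        with wrapper.optim_context(model):
            loss = model(torch.randn(8, 4).cuda()).sum()
        wrapper.update_params(loss)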