File size: 9,157 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Model's parameters and
            buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        # Options forwarded verbatim to `apex_amp.initialize` the first time
        # `optim_context` is entered.
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        # apex_amp state received by `load_state_dict` before amp was
        # initialized; applied (and cleared) in `optim_context`.
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation on the loss scaled by
        ``apex_amp.scale_loss``.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If state_dict contains "apex_amp", the :attr:`apex_amp` will
        load the corresponding keys. Otherwise, only the :attr:`optimizer`
        will load the state dictionary. The passed-in ``state_dict`` is not
        modified.

        Note:
            If ``apex_amp.initialize`` has not been called yet, the
            "apex_amp" part is cached and loaded once :meth:`optim_context`
            initializes amp.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`
        """
        if 'apex_amp' in state_dict:
            # Pop from a shallow copy so the caller's dict (e.g. a loaded
            # checkpoint) keeps its "apex_amp" entry and can be reused.
            state_dict = state_dict.copy()
            apex_amp_state = state_dict.pop('apex_amp')
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(apex_amp_state)
            else:
                self._apex_amp_state_dict = apex_amp_state
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and enables the
        context for disabling gradient synchronization during gradient
        accumulation context.

        On first entry this also calls ``apex_amp.initialize`` on the
        (unwrapped) model and :attr:`optimizer`, then applies any apex_amp
        state cached by :meth:`load_state_dict`.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # once an optimizer has been passed through apex_amp.initialize,
            # it carries an "_amp_stash" attribute; use it to initialize
            # only once
            if not hasattr(self.optimizer, '_amp_stash'):
                # initialize amp on the inner module, not on the wrapper
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # loading apex_amp state_dict after initialization of apex_amp
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield