# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_.

    Args:
        opt_level (str): Pure or mixed precision optimization level.
            Accepted values are "O0", "O1", "O2", and "O3".
            Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be
            a string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your
            script should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Casts the model's
            parameters and buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like
            GEMMs and convolutions in FP16, and any ops that benefit from
            FP32 precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it is often beneficial to keep batchnorm weights in FP32 even
            if the rest of the model is FP16. If passed as a string, must
            be the string "True" or "False". Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are
            stepped by the optimizer to enhance precision and capture
            small gradients. Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular
            type, regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling. The default
            value of None means that no floor is imposed. If dynamic loss
            scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
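
    A minimal usage sketch is shown below; the model, the optimizer and the
    hyper-parameters are illustrative placeholders rather than values
    prescribed by this module.

    Examples:
        >>> import torch
        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> model = nn.Linear(1, 1).cuda()
        >>> optimizer = SGD(model.parameters(), lr=0.01)
        >>> optim_wrapper = ApexOptimWrapper(
        ...     optimizer=optimizer, opt_level='O1', loss_scale='dynamic')
        >>> with optim_wrapper.optim_context(model):
        ...     loss = model(torch.ones(1, 1).cuda()).sum()
        >>> optim_wrapper.update_params(loss)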
| """ # noqa: E501 | |

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation with the loss scaled by
        ``apex_amp.scale_loss``.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If ``state_dict`` contains the key "apex_amp", :attr:`apex_amp`
        will load the corresponding entries. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Note:
            :meth:`load_state_dict` should be called after
            ``apex_amp.initialize`` is called.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`.
        """
        if 'apex_amp' in state_dict:
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
            else:
                self._apex_amp_state_dict = state_dict.pop('apex_amp')
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and for
        disabling gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # when a given optimizer is passed through `apex_amp.initialize`,
            # the `_amp_stash` attribute will be added to it, so its absence
            # means apex has not been initialized yet
            if not hasattr(self.optimizer, '_amp_stash'):
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # loading apex_amp state_dict after initialization of apex_amp
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield