# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from contextlib import contextmanager
from typing import Optional

import torch

from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.logging import print_log
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


@contextmanager
def autocast(device_type: Optional[str] = None,
             dtype: Optional[torch.dtype] = None,
             enabled: bool = True,
             cache_enabled: Optional[bool] = None):
    """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.

    Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in
    mixed precision , and update it to ``torch.autocast`` in 1.10.0.
    Both interfaces have different arguments, and ``torch.autocast``
    support running with cpu additionally.

    This function provides a unified interface by wrapping
    ``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the
    compatibility issues that ``torch.cuda.amp.autocast`` does not support
    running mixed precision with cpu, and both contexts have different
    arguments. We suggest users using this function in the code
    to achieve maximized compatibility of different PyTorch versions.

    Note:
        ``autocast`` requires PyTorch >= 1.5.0. If the PyTorch version is
        earlier than 1.10.0 and no cuda device is available, entering the
        context with ``enabled=True`` raises an error, since
        ``torch.cuda.amp.autocast`` only supports cuda mode.

    Examples:
         >>> # case 1: 1.5.0 <= PyTorch version < 1.10.0
         >>> with autocast():
         >>>    # run in mixed precision context
         >>>    pass
         >>> with autocast(device_type='cpu'):
         >>>    # raises an error, since torch.cuda.amp.autocast only
         >>>    # supports cuda mode.
         >>>    pass
         >>> # case 2: PyTorch version >= 1.10.0
         >>> with autocast():
         >>>    # default cuda mixed precision context
         >>>    pass
         >>> with autocast(device_type='cpu'):
         >>>    # cpu mixed precision context
         >>>    pass
         >>> with autocast(
         >>>     device_type='cuda', enabled=True, cache_enabled=True):
         >>>    # enable the mixed precision context with more specific
         >>>    # arguments.
         >>>    pass
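         >>> # case 3 (illustrative): disabling autocast forces fp32, e.g.
         >>> # for numerically sensitive steps; this works on any supported
         >>> # PyTorch version and device.
         >>> with autocast(enabled=False):
         >>>    # run in fp32, the context is a no-op
         >>>    pass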

    Args:
        device_type (str, optional): Device type to use, e.g. 'cuda' or
            'cpu'. If None, the device is inferred from the current
            environment.
        dtype (torch.dtype, optional): Data type to autocast to, e.g.
            ``torch.float16`` or ``torch.bfloat16``.
        enabled (bool): Whether autocasting should be enabled in the
            region. Defaults to True.
        cache_enabled (bool, optional): Whether the weight cache inside
            autocast should be enabled.
    """
    # If `enabled` is False, enter an empty context and perform all
    # calculations in fp32.
    assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
        'The minimum PyTorch version requirement of mmengine is 1.5.0, but '
        f'got {TORCH_VERSION}')

    if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
            digit_version('1.10.0')):
        # If the PyTorch version is between 1.5.0 and 1.10.0, the default
        # value of dtype for `torch.cuda.amp.autocast` is torch.float16.
        assert (
            device_type == 'cuda' or device_type == 'mlu'
            or device_type is None), (
                'PyTorch versions earlier than 1.10.0 only support running '
                'automatic mixed precision with cuda or mlu')
        if dtype is not None or cache_enabled is not None:
            print_log(
                f'`dtype` ({dtype}) and `cache_enabled` ({cache_enabled}) '
                'will be ignored by `autocast`, since your PyTorch '
                f'version {TORCH_VERSION} is earlier than 1.10.0',
                logger='current',
                level=logging.WARNING)

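        # Dispatch to the backend-specific autocast implementation; only
        # the `enabled` flag is passed to these legacy interfaces.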
        if is_npu_available():
            with torch.npu.amp.autocast(enabled=enabled):
                yield
        elif is_mlu_available():
            with torch.mlu.amp.autocast(enabled=enabled):
                yield
        elif is_cuda_available():
            with torch.cuda.amp.autocast(enabled=enabled):
                yield
        else:
            if not enabled:
                yield
            else:
                raise RuntimeError(
                    'For PyTorch versions between 1.5.0 and 1.10.0, '
                    '`autocast` is only available in gpu mode')

    else:
        # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py # noqa: E501
        # This code should be kept in sync with `torch.autocast`.
        if cache_enabled is None:
            cache_enabled = torch.is_autocast_cache_enabled()
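        # Infer the device type from the current environment when the
        # caller does not specify one.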
        device = get_device()
        device_type = device if device_type is None else device_type

        if device_type == 'cuda':
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()

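            # bfloat16 needs hardware support (roughly, Ampere-class GPUs
            # or newer), so fail early with a clear message otherwise.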
            if dtype == torch.bfloat16 and not \
                    torch.cuda.is_bf16_supported():
                raise RuntimeError(
                    'Current CUDA Device does not support bfloat16. Please '
                    'switch dtype to float16.')

        elif device_type == 'cpu':
            if dtype is None:
                dtype = torch.bfloat16
            assert dtype == torch.bfloat16, (
                'CPU autocast only supports the `torch.bfloat16` dtype')

        elif device_type == 'mlu':
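            # MLU is passed through as-is; validation is left to
            # `torch.autocast` below.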
            pass

        elif device_type == 'npu':
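            # Likewise, NPU is passed through without extra checks.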
            pass

        else:
            # Devices such as MPS do not support fp16 training or testing.
            # If an unsupported device is set and autocast is enabled, an
            # error is raised.
            if enabled is False:
                yield
                return
            else:
                raise ValueError(
                    'User specified autocast device_type must be cuda, cpu, '
                    f'mlu or npu, but got {device_type}')

        with torch.autocast(
                device_type=device_type,
                enabled=enabled,
                dtype=dtype,
                cache_enabled=cache_enabled):
            yield
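

# A minimal, illustrative usage sketch (not part of the mmengine API): it
# runs a small linear layer under `autocast` and prints the output dtype.
# The layer and tensors here are made up for the example; the cpu path
# assumes PyTorch >= 1.10.0, where autocast computes in bfloat16.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(4, 4).to(device)
    inputs = torch.randn(2, 4, device=device)
    with autocast(device_type=device):
        # Matmul-like ops run in float16 on cuda and bfloat16 on cpu.
        outputs = model(inputs)
    print(outputs.dtype)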