# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import ExitStack, contextmanager
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel

from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS
from .distributed import MMDistributedDataParallel


@MODEL_WRAPPERS.register_module()
class MMSeparateDistributedDataParallel(DistributedDataParallel):
    """A DistributedDataParallel wrapper for models in MMGeneration.

    In MMedting and MMGeneration there is a need to wrap different modules in
    the models with separate DistributedDataParallel. Otherwise, it will cause
    errors for GAN training. For example, the GAN model, usually has two
    submodules: generator and discriminator. If we wrap both of them in one
    standard DistributedDataParallel, it will cause errors during training,
    because when we update the parameters of the generator (or discriminator),
    the parameters of the discriminator (or generator) is not updated, which is
    not allowed for DistributedDataParallel. So we design this wrapper to
    separately wrap DistributedDataParallel for generator and discriminator.
    In this wrapper, we perform two operations:

    1. Wraps each module in the models with separate MMDistributedDataParallel.
       Note that only modules with parameters will be wrapped.
    2. Calls ``train_step``, ``val_step`` and ``test_step`` of submodules to
       get losses and predictions.

    Args:
        module (nn.Module): Model containing multiple submodules which have
            separate updating strategies.
        broadcast_buffers (bool): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            Defaults to False.
        find_unused_parameters (bool): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            Traverse the autograd graph of all tensors contained in the return
            value of the wrapped module's ``forward`` function. Defaults to
            False.
        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.

            - device_ids (List[int] or torch.device, optional): CUDA devices
              for module.
            - output_device (int or torch.device, optional): Device location of
              output for single-device CUDA modules.
            - dim (int): Defaults to 0.
            - process_group (ProcessGroup, optional): The process group to be
              used for distributed data all-reduction.
            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
              to 25.
            - check_reduction (bool): This argument is deprecated. Defaults
              to False.
            - gradient_as_bucket_view (bool): Defaults to False.
            - static_graph (bool): Defaults to False.

    See more information about arguments in
    :class:`torch.nn.parallel.DistributedDataParallel`.
    """

    def __init__(self,
                 module: nn.Module,
                 broadcast_buffers: bool = False,
                 find_unused_parameters: bool = False,
                 **kwargs):
        super(DistributedDataParallel, self).__init__()
        self.module = module
        device = get_device()
        # Wrap the submodule with parameters of `self.module` to
        # `MMDistributedDataParallel`
        for name, sub_module in module._modules.items():
            # Submodule without parameters: move it to the target device
            # without DDP wrapping.
            if next(sub_module.parameters(), None) is None:
                sub_module = sub_module.to(device)
            # Submodule whose parameters are all frozen: no gradient needs to
            # be synchronized, so it is not wrapped either.
            elif all(not p.requires_grad for p in sub_module.parameters()):
                sub_module = sub_module.to(device)
            else:
                sub_module = MMDistributedDataParallel(
                    module=sub_module.to(device),
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            module._modules[name] = sub_module

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
        """Interface for model forward, backward and parameters updating during
        training process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            optim_wrapper (OptimWrapperDict): A wrapper of optimizer to
                update parameters.

        Returns:
            Dict[str, torch.Tensor]: A dict of tensor for logging.
        """
        return self.module.train_step(data, optim_wrapper)

    def val_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the prediction of module during validation process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        return self.module.val_step(data)

    def test_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the predictions of module during testing process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        return self.module.test_step(data)

    @contextmanager
    def no_sync(self):
        """Enables ``no_sync`` context of all sub ``MMDistributedDataParallel``
        modules."""
        with ExitStack() as stack:
            # Enter the ``no_sync`` context of every DDP-wrapped submodule,
            # then yield exactly once so that gradient synchronization is
            # skipped for all of them until the caller leaves this context.
            for sub_ddp_model in self.module._modules.values():
                stack.enter_context(sub_ddp_model.no_sync())
            yield

    def train(self, mode: bool = True) -> 'MMSeparateDistributedDataParallel':
        """Sets the module in training mode.

        In order to make the ddp wrapper inheritance hierarchy more uniform,
        ``MMSeparateDistributedDataParallel`` inherits from
        ``DistributedDataParallel``, but does not call its constructor.
        Since the attributes of ``DistributedDataParallel`` have not been
        initialized, calling the ``train`` method of
        ``DistributedDataParallel`` would raise an error with PyTorch <= 1.9.
        Therefore, this method is overridden to call the ``train`` method of
        the submodules.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Defaults to ``True``.

        Returns:
            Module: self.
        """
        self.training = mode
        self.module.train(mode)
        return self
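

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The
# ``_ToyGenerator``/``_ToyDiscriminator``/``_ToyGAN`` classes below are
# hypothetical placeholders standing in for a real MMGeneration model, and a
# real model would also define ``train_step``/``val_step``/``test_step``.
# Constructing the wrapper requires an initialized process group, so the
# sketch only builds it when ``torch.distributed`` has been set up (e.g. via
# ``torchrun`` and ``init_process_group``).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.distributed as dist

    class _ToyGenerator(nn.Module):

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 8)

        def forward(self, x):
            return self.net(x)

    class _ToyDiscriminator(nn.Module):

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 1)

        def forward(self, x):
            return self.net(x)

    class _ToyGAN(nn.Module):
        """Model with two independently optimized submodules."""

        def __init__(self):
            super().__init__()
            self.generator = _ToyGenerator()
            self.discriminator = _ToyDiscriminator()

    if dist.is_available() and dist.is_initialized():
        wrapper = MMSeparateDistributedDataParallel(_ToyGAN())
        # Each submodule with trainable parameters is now wrapped in its own
        # ``MMDistributedDataParallel`` instance.
        assert isinstance(wrapper.module.generator, MMDistributedDataParallel)
        assert isinstance(wrapper.module.discriminator,
                          MMDistributedDataParallel)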