"""
This file includes public APIs for FSDP such as the classes used for the
constructor arguments.
"""

from dataclasses import dataclass
from enum import auto, Enum
from typing import Optional, Sequence, Type

import torch
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
      nodes. This results in reduced communication volume as expensive all-gathers and
      reduce-scatters are only done within a node, which can be more performant for
      medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
      nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
      since the unsharded parameters are not freed after the forward pass, saving the
      all-gathers in the pre-backward.
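
    Example (a minimal usage sketch; ``MyModel`` is a placeholder ``nn.Module``,
    and a default process group is assumed to be already initialized)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(
        >>>     MyModel().cuda(),
        >>>     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        >>> )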

    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory
      usage. This prefetches the next set of parameters *after* the current
      set of parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      the backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using NCCL backend,
    any collectives, even if issued from different streams, contend for the
    same per-device NCCL stream, which implies that the relative order in which
    the collectives are issued matters for overlapping. The two backward
    prefetching values correspond to different issue orders.
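
    Example (a minimal usage sketch; ``MyModel`` is a placeholder ``nn.Module``,
    and a default process group is assumed to be already initialized)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(
        >>>     MyModel().cuda(),
        >>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        >>> )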

    """

    # NOTE: For both modes, the ordering that defines "current" and "next" is
    # not always exact in the current implementation. A mistargeted prefetch
    # simply means that the parameter memory is allocated earlier than needed,
    # possibly increasing peak memory usage, but does not affect correctness.
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)
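
    Example (a minimal usage sketch; ``MyModel`` is a placeholder ``nn.Module``)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> mp = MixedPrecision(
        >>>     param_dtype=torch.bfloat16,
        >>>     reduce_dtype=torch.float32,  # illustrative: keep gradient reduction in full precision
        >>>     buffer_dtype=torch.bfloat16,
        >>> )
        >>> fsdp_model = FSDP(MyModel().cuda(), mixed_precision=mp)
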
    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
            >>> model[1] = FSDP(
            >>>     model[1],
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
            >>> )
            >>> model = FSDP(
            >>>     model,
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
            >>> )

        The above shows a working example. On the other hand, if ``model[1]``
        were replaced with ``model[0]``, meaning that the submodule using
        different ``MixedPrecision`` ran its forward first, then ``model[1]``
        would incorrectly see ``float16`` activations instead of ``bfloat16``
        ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
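
    Example (a minimal usage sketch; ``MyModel`` is a placeholder ``nn.Module``)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(MyModel(), cpu_offload=CPUOffload(offload_params=True))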

    """

    offload_params: bool = False


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:
            1. ``state_dict``/``load_state_dict``: this pair of APIs returns and loads
               the non-sharded, unflattened parameters. The semantics are the
               same as using DDP.
            2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs returns
               and loads local sharded, flattened parameters. The values returned
               by ``_local_state_dict`` can be directly used by FSDP and are only
               meaningful to FSDP (because parameters are flattened). Note that
               these APIs are meant for use via the :func:`state_dict_type`
               context manager as follows:
                   >>> # xdoctest: +SKIP("undefined variables")
                   >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
                   ...     state = fsdp.state_dict()  # gets the local state dict
            3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of APIs
               returns and loads sharded, unflattened parameters. The ``state_dict``
               returned by ``_sharded_state_dict`` can be used by all other parallel
               schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
      and it is used by FSDP to determine the type of state dict values. Users should not
      manually modify ``_use_dtensor``.
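
    Example (a minimal usage sketch; ``fsdp_model`` is assumed to be a module
    already wrapped with :class:`FullyShardedDataParallel`)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     sharded_state = fsdp_model.state_dict()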

    """

    _use_dtensor: bool = False


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
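
    Example (a minimal usage sketch; ``fsdp_model`` and ``optim`` are assumed to
    be an already-wrapped FSDP model and its optimizer)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> FSDP.set_state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        >>>     FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
        >>> )
        >>> optim_state = FSDP.optim_state_dict(fsdp_model, optim)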

    """

    rank0_only: bool = False


@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
      and it is used by FSDP to determine the type of state dict values. Users should not
      manually modify ``_use_dtensor``.
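
    Example (a minimal usage sketch; ``fsdp_model`` and ``optim`` are assumed to
    be an already-wrapped FSDP model and its optimizer)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>>     ShardedOptimStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     model_state = fsdp_model.state_dict()
        >>>     optim_state = FSDP.optim_state_dict(fsdp_model, optim)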

    """

    _use_dtensor: bool = False


@dataclass
class StateDictSettings:
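    """
    Bundles a ``state_dict_type`` together with its corresponding model and
    optimizer state dict configs (e.g. as returned by
    ``FullyShardedDataParallel.get_state_dict_type``).
    """
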
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig