""" | |
This file includes public APIs for FSDP such as the classes used for the | |
constructor arguments. | |
""" | |
from dataclasses import dataclass | |
from enum import auto, Enum | |
from typing import Optional, Sequence, Type | |
import torch | |
from torch.nn.modules.batchnorm import _BatchNorm | |

__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
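
    A minimal usage sketch, assuming ``FullyShardedDataParallel`` is imported
    as ``FSDP`` and ``model`` is an ``nn.Module``; the chosen strategy is
    passed via the FSDP constructor's ``sharding_strategy`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(
        >>>     model,
        >>>     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        >>> )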
""" | |
FULL_SHARD = auto() | |
SHARD_GRAD_OP = auto() | |
NO_SHARD = auto() | |
HYBRID_SHARD = auto() | |
_HYBRID_SHARD_ZERO2 = auto() | |


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory
      usage. This prefetches the next set of parameters *after* the current
      set of parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      the backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using NCCL backend,
    any collectives, even if issued from different streams, contend for the
    same per-device NCCL stream, which implies that the relative order in which
    the collectives are issued matters for overlapping. The two backward
    prefetching values correspond to different issue orders.
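
    A minimal usage sketch, assuming ``FullyShardedDataParallel`` is imported
    as ``FSDP`` and ``model`` is defined; the value is passed via the FSDP
    constructor's ``backward_prefetch`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp_model = FSDP(
        >>>     model,
        >>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        >>> )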
""" | |
# NOTE: For both modes, the ordering that defines "current" and "next" is | |
# not always exact in the current implementation. A mistargeted prefetch | |
# simply means that the parameter memory is allocated earlier than needed, | |
# possibly increasing peak memory usage, but does not affect correctness. | |
BACKWARD_PRE = auto() | |
BACKWARD_POST = auto() | |


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
            >>> model[1] = FSDP(
            >>>     model[1],
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
            >>> )
            >>> model = FSDP(
            >>>     model,
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
            >>> )

        The above shows a working example. On the other hand, if ``model[1]``
        were replaced with ``model[0]``, meaning that the submodule using
        different ``MixedPrecision`` ran its forward first, then ``model[1]``
        would incorrectly see ``float16`` activations instead of ``bfloat16``
        ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
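
    A minimal usage sketch, assuming ``FullyShardedDataParallel`` is imported
    as ``FSDP`` and ``model`` is defined; the config is passed via the FSDP
    constructor's ``cpu_offload`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp_model = FSDP(
        >>>     model,
        >>>     cpu_offload=CPUOffload(offload_params=True),
        >>> )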
""" | |
offload_params: bool = False | |


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading). The default value is
    ``FULL_STATE_DICT`` to comply with the PyTorch convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict/load_state_dict``: this pair of APIs return and load
           the non-sharded, unflattened parameters. The semantics are the
           same as using DDP.

        2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs
           return and load local sharded, flattened parameters. The values
           returned by ``_local_state_dict`` can be directly used by FSDP and
           are only meaningful to FSDP (because parameters are flattened).
           Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

               >>> # xdoctest: +SKIP("undefined variables")
               >>> with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
               ...     state = fsdp.state_dict()  # gets the local state dict

        3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
           return and load sharded, unflattened parameters. The ``state_dict``
           returned by ``sharded_state_dict`` can be used by all other parallel
           schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedStateDictConfig` and it is used by FSDP to determine the
        type of state dict values. Users should not manually modify
        ``_use_dtensor``.
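
    A minimal usage sketch, assuming ``FullyShardedDataParallel`` is imported
    as ``FSDP`` and ``fsdp_model`` is an FSDP-wrapped module; the config is
    passed through the :func:`state_dict_type` context manager::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     sharded_state = fsdp_model.state_dict()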
""" | |
_use_dtensor: bool = False | |


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
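
    A minimal usage sketch, assuming ``FullyShardedDataParallel`` is imported
    as ``FSDP`` and ``fsdp_model`` and ``optim`` are defined; the config is
    passed through the :func:`state_dict_type` context manager before calling
    ``FSDP.optim_state_dict``::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True),
        >>> ):
        >>>     optim_state = FSDP.optim_state_dict(fsdp_model, optim)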
""" | |
rank0_only: bool = False | |


@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedOptimStateDictConfig` and it is used by FSDP to
        determine the type of state dict values. Users should not manually
        modify ``_use_dtensor``.
    """

    _use_dtensor: bool = False


@dataclass
class StateDictSettings:
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
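

# A minimal usage sketch of where ``StateDictSettings`` typically comes from:
# ``FullyShardedDataParallel.get_state_dict_type(module)`` (assumed here; it is
# defined outside this file) returns the currently configured type and configs
# bundled in a ``StateDictSettings``, e.g.:
#
#     >>> # xdoctest: +SKIP("undefined variables")
#     >>> settings = FSDP.get_state_dict_type(fsdp_model)
#     >>> assert isinstance(settings, StateDictSettings)
#     >>> settings.state_dict_type, settings.state_dict_config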