Fully Sharded Data Parallel utilities
enable_fsdp_ram_efficient_loading
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
disable_fsdp_ram_efficient_loading
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
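A minimal sketch of toggling these helpers around a Transformers from_pretrained call (the checkpoint name is illustrative):

```python
from accelerate.utils import (
    enable_fsdp_ram_efficient_loading,
    disable_fsdp_ram_efficient_loading,
)
from transformers import AutoModelForCausalLM

# Enable RAM-efficient loading in the environment so that, under FSDP,
# only the first process loads the full pretrained weights.
enable_fsdp_ram_efficient_loading()
model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

# Restore the default loading behavior for subsequent loads.
disable_fsdp_ram_efficient_loading()
```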
merge_fsdp_weights
accelerate.utils.merge_fsdp_weights
< source >( checkpoint_dir: str output_path: str safe_serialization: bool = True remove_checkpoint_dir: bool = False )
Parameters
- checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
- output_path (str) — The path to save the merged checkpoint to.
- safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.
Note: this is a CPU-bound process.
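A minimal usage sketch, assuming a sharded checkpoint directory produced with SHARDED_STATE_DICT (the paths are illustrative):

```python
from accelerate.utils import merge_fsdp_weights

# Merge the sharded FSDP checkpoint into a single file at
# output/model.safetensors (pytorch_model.bin if safe_serialization=False).
merge_fsdp_weights(
    checkpoint_dir="checkpoints/pytorch_model_fsdp_0",  # illustrative path
    output_path="output",
    safe_serialization=True,
    remove_checkpoint_dir=False,
)
```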
FullyShardedDataParallelPlugin
class accelerate.FullyShardedDataParallelPlugin
< source >( fsdp_version: int = None sharding_strategy: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy')] = None reshard_after_forward: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy'), bool] = None backward_prefetch: typing.Union[str, ForwardRef('torch.distributed.fsdp.BackwardPrefetch'), NoneType] = None mixed_precision_policy: typing.Union[dict, ForwardRef('torch.distributed.fsdp.MixedPrecision'), ForwardRef('torch.distributed.fsdp.MixedPrecisionPolicy'), NoneType] = None auto_wrap_policy: typing.Union[typing.Callable, typing.Literal['transformer_based_wrap', 'size_based_wrap', 'no_wrap'], NoneType] = None cpu_offload: typing.Union[bool, ForwardRef('torch.distributed.fsdp.CPUOffload'), ForwardRef('torch.distributed.fsdp.CPUOffloadPolicy')] = None ignored_modules: typing.Union[collections.abc.Iterable[torch.nn.modules.module.Module], str, NoneType] = None state_dict_type: typing.Union[str, ForwardRef('torch.distributed.fsdp.StateDictType')] = None state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedStateDictConfig'), NoneType] = None optim_state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullOptimStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedOptimStateDictConfig'), NoneType] = None limit_all_gathers: bool = True use_orig_params: typing.Optional[bool] = None param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None sync_module_states: typing.Optional[bool] = None forward_prefetch: bool = None activation_checkpointing: bool = None cpu_ram_efficient_loading: bool = None transformer_cls_names_to_wrap: typing.Optional[list[str]] = None min_num_params: typing.Optional[int] = None )
Parameters
- fsdp_version (int, defaults to 1) — The version of FSDP to use. Defaults to 1. If set to 2, the launcher expects the config to be converted to FSDP2 format.
- sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD') — Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy. Is deprecated in favor of reshard_after_forward.
- reshard_after_forward (Union[str, torch.distributed.fsdp.ShardingStrategy, bool], defaults to 'FULL_SHARD' for fsdp_version=1 and True for fsdp_version=2) — Sharding strategy to use. Should be a bool if fsdp_version is set to 2, else a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
- backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH') — Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
- mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]], defaults to None) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the keys param_dtype, reduce_dtype, and buffer_dtype. Can be an instance of torch.distributed.fsdp.MixedPrecisionPolicy if fsdp_version is set to 2.
- auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP) — A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for a direction on what it should look like.
- cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy], defaults to False) — Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload, or of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy if fsdp_version is set to 2.
- ignored_modules (Optional[Union[Iterable[torch.nn.Module], str]], defaults to None) — A list of modules to ignore when wrapping with FSDP. When passing a string, modules are matched by name using a regex fullmatch.
- state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT') — State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
- state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None) — State dict config to use. Is determined based on the state_dict_type if not passed in.
- optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None) — Optim state dict config to use. Is determined based on the state_dict_type if not passed in.
- limit_all_gathers (bool, defaults to True) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
- use_orig_params (bool, defaults to False) — Whether to use the original parameters for the optimizer.
- param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None) — A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. By default it is a lambda which calls to_empty on the module.
- sync_module_states (bool, defaults to False) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False unless cpu_ram_efficient_loading is True, in which case it is forcibly enabled.
- forward_prefetch (bool, defaults to False) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
- activation_checkpointing (bool, defaults to False) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
- cpu_ram_efficient_loading (bool, defaults to None) — If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, sync_module_states needs to be True.
- transformer_cls_names_to_wrap (Optional[List[str]], defaults to None) — A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
- min_num_params (Optional[int], defaults to None) — The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.
This plugin is used to enable fully sharded data parallelism.
The plugin also provides helpers that:

- given a model, create an auto_wrap_policy based on the passed-in policy and, when applicable, the transformer classes to wrap;
- set the mixed precision policy for FSDP;
- set the state dict config based on the StateDictType;
- validate the mixed precision policy (abstracted away to avoid pulling in the FSDP imports when not needed).
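A minimal configuration sketch, assuming the plugin is built in code rather than via accelerate config (the wrapped layer class name is illustrative):

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    reshard_after_forward="FULL_SHARD",
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["GPT2Block"],  # illustrative layer class name
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,  # forces sync_module_states=True
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```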
fsdp2_load_full_state_dict
accelerate.utils.fsdp2_load_full_state_dict
< source >( accelerator model: Module full_sd: dict )
Loads the full state dict (which may exist only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place.
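A hedged sketch of calling this utility directly, assuming accelerator and an FSDP2-sharded model already exist and the checkpoint path is illustrative:

```python
import torch
from accelerate.utils import fsdp2_load_full_state_dict

# Only the main process needs the full (unsharded) state dict; the other
# ranks can pass an empty dict and receive the parameters via broadcast.
full_sd = {}
if accelerator.is_main_process:
    full_sd = torch.load("full_model.bin", map_location="cpu")  # illustrative path

fsdp2_load_full_state_dict(accelerator, model, full_sd)
```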
fsdp2_switch_optimizer_parameters
accelerate.utils.fsdp2_switch_optimizer_parameters
< source >( optimizer: Optimizer mapping: dict )
Parameters
- optimizer (torch.optim.Optimizer) — Optimizer instance which contains the original model parameters.
- mapping (dict) — Mapping from the original parameter (specified by its data_ptr) to the sharded parameter.
Raises
KeyError — If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically correct and weights wouldn't get updated.
Switches the parameters of the optimizer to new ones (sharded parameters, in the usual case). This function modifies the optimizer in-place.
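A hedged sketch of building the mapping this function expects: the keys are the data_ptr() of the original parameters and the values are the corresponding sharded parameters (original_params and sharded_params are assumed to exist in matching order):

```python
from accelerate.utils import fsdp2_switch_optimizer_parameters

# Map each original parameter's data_ptr() to its sharded counterpart.
mapping = {
    original.data_ptr(): sharded
    for original, sharded in zip(original_params, sharded_params)
}

# After this call the optimizer's param groups reference the sharded parameters.
fsdp2_switch_optimizer_parameters(optimizer, mapping)
```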
fsdp2_prepare_model
accelerate.utils.fsdp2_prepare_model
< source >( accelerator model: Module ) → torch.nn.Module
Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
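A hedged sketch; in typical training code this step is handled when the model goes through Accelerator.prepare() with an FSDP plugin configured for fsdp_version=2, but the utility can also be called directly:

```python
from accelerate.utils import fsdp2_prepare_model

# Shards the model in-place for FSDP2; keep the returned reference and stop
# using the original (pre-preparation) model object.
model = fsdp2_prepare_model(accelerator, model)
```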