Accelerate documentation
Fully Sharded Data Parallel utilities
enable_fsdp_ram_efficient_loading
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
disable_fsdp_ram_efficient_loading
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
merge_fsdp_weights
accelerate.utils.merge_fsdp_weights(checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False)
Parameters
- checkpoint_dir (str): The directory containing the FSDP checkpoints (either the model or the optimizer checkpoint).
- output_path (str): The path to save the merged checkpoint.
- safe_serialization (bool, optional, defaults to True): Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (bool, optional, defaults to False): Whether to remove the checkpoint directory after merging.
Merges the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this if SHARDED_STATE_DICT was used for the model. Weights are saved to {output_path}/model.safetensors if safe_serialization is True, and to {output_path}/pytorch_model.bin otherwise.
Note: this is a CPU-bound process.
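Here is a minimal sketch of calling merge_fsdp_weights from a script and locating the result. merge_sharded_checkpoint and merged_weights_file are hypothetical helper names. The file-name logic restates the rule above and does not inspect the output directory.

```python
import os


def merged_weights_file(output_path: str, safe_serialization: bool = True) -> str:
    """Return the file that merge_fsdp_weights is documented to write."""
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return os.path.join(output_path, filename)


def merge_sharded_checkpoint(checkpoint_dir: str, output_path: str, safe_serialization: bool = True) -> str:
    """Merge a SHARDED_STATE_DICT checkpoint (CPU-bound) and return the merged file path."""
    # Imported lazily so this module can be imported without Accelerate installed.
    from accelerate.utils import merge_fsdp_weights

    merge_fsdp_weights(
        checkpoint_dir=checkpoint_dir,
        output_path=output_path,
        safe_serialization=safe_serialization,
        remove_checkpoint_dir=False,  # keep the shards until the merged file is verified
    )
    return merged_weights_file(output_path, safe_serialization)
```

For example, merge_sharded_checkpoint("path/to/fsdp_shards", "path/to/merged") would return path/to/merged/model.safetensors. Both paths are placeholders. Keeping remove_checkpoint_dir=False is a cautious default. You can delete the shards yourself once the merged weights load correctly.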
FullyShardedDataParallelPlugin
class accelerate.FullyShardedDataParallelPlugin(sharding_strategy: Union = None, backward_prefetch: Union = None, mixed_precision_policy: Union = None, auto_wrap_policy: Union = None, cpu_offload: Union = None, ignored_modules: Optional = None, state_dict_type: Union = None, state_dict_config: Union = None, optim_state_dict_config: Union = None, limit_all_gathers: bool = True, use_orig_params: bool = None, param_init_fn: Optional = None, sync_module_states: bool = None, forward_prefetch: bool = None, activation_checkpointing: bool = None, cpu_ram_efficient_loading: bool = None, transformer_cls_names_to_wrap: Optional = None, min_num_params: Optional = None)
Parameters
- sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD'): Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
- backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH'): Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
- mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]], defaults to None): A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the following keys: param_dtype, reduce_dtype, and buffer_dtype.
- auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP): A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for an example of what a callable policy should look like.
- cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload], defaults to False): Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload.
- ignored_modules (Optional[Iterable[torch.nn.Module]], defaults to None): A list of modules to ignore when wrapping with FSDP.
- state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT'): State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
- state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None): State dict config to use. Determined from state_dict_type if not passed in.
- optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None): Optimizer state dict config to use. Determined from state_dict_type if not passed in.
- limit_all_gathers (bool, defaults to True): Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This only affects the sharded strategies that schedule all-gathers. Enabling it can help lower the number of CUDA malloc retries.
- use_orig_params (bool, defaults to False): Whether to use the original parameters for the optimizer.
- param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None): A Callable[torch.nn.Module] -> None that specifies how modules currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. By default, it is a lambda that calls to_empty on the module.
- sync_module_states (bool, defaults to False): Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False unless cpu_ram_efficient_loading is True, in which case it is forcibly enabled.
- forward_prefetch (bool, defaults to False): Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing the forward pass. Only use with static graphs.
- activation_checkpointing (bool, defaults to False): A technique to reduce memory usage by clearing activations of certain layers and recomputing them during the backward pass. Effectively, this trades extra computation time for reduced memory usage.
- cpu_ram_efficient_loading (bool, defaults to None): If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers models. When using this, sync_module_states needs to be True.
- transformer_cls_names_to_wrap (Optional[List[str]], defaults to None): A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
- min_num_params (Optional[int], defaults to None): The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.
This plugin is used to enable fully sharded data parallelism.
Given a model, creates an auto_wrap_policy based on the passed-in policy and on whether transformer_cls_to_wrap can be used.
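The auto_wrap_policy parameter also accepts a callable. Here is a minimal sketch that binds PyTorch's torch.distributed.fsdp.wrap.size_based_auto_wrap_policy with functools.partial. should_wrap is a hypothetical helper that asks the policy about a single module (recurse=False). This only approximates how FSDP walks the module tree.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Custom policy: a module qualifies once it holds at least one million
# parameters that are not already wrapped.
my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)


def should_wrap(module: nn.Module) -> bool:
    """Hypothetical helper: would the policy make `module` its own FSDP unit?"""
    # For a module whose children are not wrapped yet, every parameter counts.
    numel = sum(p.numel() for p in module.parameters())
    return my_auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=numel)
```

The resulting callable could be passed as FullyShardedDataParallelPlugin(auto_wrap_policy=my_auto_wrap_policy). This assumes the plugin forwards custom callables to FSDP unchanged. The built-in alternative is auto_wrap_policy="size_based_wrap" together with min_num_params.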
Sets the mixed precision policy for FSDP.
Sets the state dict config based on the StateDictType.
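The mixed_precision_policy parameter accepts either a dict with param_dtype, reduce_dtype and buffer_dtype keys or a torch.distributed.fsdp.MixedPrecision instance. Here is a minimal sketch showing that the two forms describe the same bf16 policy. The variable names are illustrative.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Dict form, using the three keys listed for mixed_precision_policy.
bf16_policy_dict = {
    "param_dtype": torch.bfloat16,   # dtype of parameters during forward/backward
    "reduce_dtype": torch.bfloat16,  # dtype used for gradient reduction
    "buffer_dtype": torch.bfloat16,  # dtype of module buffers
}

# Equivalent explicit PyTorch object.
bf16_policy = MixedPrecision(**bf16_policy_dict)
```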