## Below is an example YAML for BF16 mixed-precision training using PyTorch Fully Sharded Data Parallelism (FSDP) with CPU offloading on 8 GPUs.
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: true  # offload parameters to CPU
  fsdp_sharding_strategy: 1  # 1 = FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: T5Block
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false
```
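For the settings used in this example, the YAML keys map onto `accelerate launch` flags of the same name. `distributed_type: FSDP` becomes `--use_fsdp`, and the remaining keys become `--key=value`. The sketch below shows that correspondence with a hypothetical helper, `config_to_launch_flags`, which is not part of Accelerate. It assumes the YAML has already been parsed into a plain dict. It covers only `num_processes`, `mixed_precision`, `fsdp_sharding_strategy`, `fsdp_auto_wrap_policy`, `fsdp_transformer_layer_cls_to_wrap` and `fsdp_offload_params`.

```python
# Hypothetical helper, NOT part of Accelerate: it only illustrates how the
# YAML keys line up with the `accelerate launch` flags of the same name.


def config_to_launch_flags(config: dict) -> list:
    flags = []
    # `distributed_type: FSDP` corresponds to the `--use_fsdp` switch.
    if config.get("distributed_type") == "FSDP":
        flags.append("--use_fsdp")
    # Top-level keys are passed as `--key=value`.
    for key in ("num_processes", "mixed_precision"):
        if key in config:
            flags.append(f"--{key}={config[key]}")
    # Only the FSDP keys used in the CLI example are mapped here.
    fsdp = config.get("fsdp_config", {})
    for key in (
        "fsdp_sharding_strategy",
        "fsdp_auto_wrap_policy",
        "fsdp_transformer_layer_cls_to_wrap",
        "fsdp_offload_params",
    ):
        if key in fsdp:
            value = fsdp[key]
            # YAML booleans are written in lowercase on the command line.
            if isinstance(value, bool):
                value = str(value).lower()
            flags.append(f"--{key}={value}")
    return flags


# A subset of the YAML config, written as the dict a YAML parser would produce.
config = {
    "distributed_type": "FSDP",
    "mixed_precision": "bf16",
    "num_processes": 8,
    "fsdp_config": {
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_offload_params": True,
        "fsdp_sharding_strategy": 1,
        "fsdp_transformer_layer_cls_to_wrap": "T5Block",
    },
}

# Print the flags as a multi-line `accelerate launch` command.
print(" \\\n  ".join(["accelerate launch", *config_to_launch_flags(config)]))
```

Keys such as `fsdp_backward_prefetch_policy` and `fsdp_state_dict_type` are deliberately left out. This sketch does not claim to know their exact CLI spelling.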
## In the training script, prepare the model first, then create and prepare the optimizer:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# `model`, `dataloader` and `lr` are assumed to be defined earlier in the script.
# Wrap and shard the model first, on its own.
model = accelerator.prepare(model)

# Optimizer can be any PyTorch optimizer class; build it from the prepared model's parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
# If you use an LR scheduler, create `scheduler` from this new `optimizer` before preparing it.
optimizer, dataloader, scheduler = accelerator.prepare(
    optimizer, dataloader, scheduler
)
```

## If the YAML was generated with the `accelerate config` command:

```
accelerate launch {script_name.py} {--arg1} {--arg2} ...
```

If the YAML is saved to a `~/config.yaml` file:

```
accelerate launch --config_file ~/config.yaml {script_name.py} {--arg1} {--arg2} ...
```

Alternatively, you can pass the configuration parameters directly to `accelerate launch` and skip the `config.yaml` file:

```
accelerate launch \
  --use_fsdp \
  --num_processes=8 \
  --mixed_precision=bf16 \
  --fsdp_sharding_strategy=1 \
  --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
  --fsdp_transformer_layer_cls_to_wrap=T5Block \
  --fsdp_offload_params=true \
  {script_name.py} {--arg1} {--arg2} ...
```

## For PyTorch FSDP, you need to prepare the model **before** preparing the optimizer, because FSDP shards the parameters in place, which breaks any optimizer that was initialized before that.

For transformer models, use the `TRANSFORMER_BASED_WRAP` auto wrap policy, as shown in the config above.

## To learn more, check out the related documentation:

- How to use Fully Sharded Data Parallelism
- Accelerate Large Model Training using PyTorch Fully Sharded Data Parallel
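A toy simulation in plain Python can illustrate why the model has to be prepared before the optimizer. It uses no PyTorch and no Accelerate. `Param`, `ToyModel`, `ToySGD` and `shard_in_place` are hypothetical stand-ins and are not the real FSDP machinery. The one property it relies on holds for PyTorch optimizers as well: an optimizer keeps references to the parameter objects it was constructed with. If wrapping then replaces the model's parameters with new (sharded) objects, the optimizer keeps updating the stale ones.

```python
# Toy simulation with hypothetical classes (not PyTorch or FSDP): an optimizer
# keeps references to the parameter objects it received at construction time.


class Param:
    def __init__(self, value: float, grad: float = 0.0):
        self.value = value
        self.grad = grad


class ToyModel:
    def __init__(self, values):
        self.params = [Param(v) for v in values]


class ToySGD:
    def __init__(self, params, lr: float):
        self.params = list(params)  # stores references to the Param objects
        self.lr = lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def shard_in_place(model: ToyModel, rank: int, world_size: int):
    # Stand-in for FSDP wrapping: the model's parameters are replaced by new
    # objects that hold only this rank's shard of the values.
    values = [p.value for p in model.params]
    model.params = [Param(v) for v in values[rank::world_size]]


def train_once(prepare_model_first: bool) -> list:
    model = ToyModel([1.0, 2.0, 3.0, 4.0])
    if prepare_model_first:
        shard_in_place(model, rank=0, world_size=2)
        optimizer = ToySGD(model.params, lr=0.5)
    else:
        optimizer = ToySGD(model.params, lr=0.5)  # holds the pre-shard objects
        shard_in_place(model, rank=0, world_size=2)
    for p in model.params:
        p.grad = 1.0  # pretend backward() produced these gradients
    optimizer.step()
    return [p.value for p in model.params]


print(train_once(prepare_model_first=True))   # [0.5, 2.5]: the shard is updated
print(train_once(prepare_model_first=False))  # [1.0, 3.0]: the shard never changes
```

In the second run, the optimizer steps the original, now-unused parameter objects, so training silently makes no progress on the sharded model.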