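# Under FSDP, switch to a full (unsharded) state dict before saving
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

# Save the complete model weights to the output directory
trainer.save_model(script_args.output_dir)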
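Setting the state dict type to FULL_STATE_DICT matters because under FSDP each process holds only a shard of every parameter. A full state dict gathers those shards, so the saved checkpoint contains complete weights instead of one rank's slice. The toy sketch below models that idea with plain Python lists. The helpers `shard_state_dict` and `gather_full_state_dict` are hypothetical and are not part of PyTorch or Accelerate. Real FSDP flattens, pads, and all-gathers tensors across devices rather than lists in one process.

```python
from typing import Dict, List

# A "state dict" here is just parameter name -> flat list of values (toy stand-in for tensors).
StateDict = Dict[str, List[float]]


def shard_state_dict(full: StateDict, world_size: int) -> List[StateDict]:
    """Hypothetical: split every parameter into contiguous chunks, one per rank."""
    shards: List[StateDict] = [{} for _ in range(world_size)]
    for name, values in full.items():
        chunk = -(-len(values) // world_size)  # ceiling division
        for rank in range(world_size):
            # Trailing ranks may receive a short or empty slice.
            shards[rank][name] = values[rank * chunk:(rank + 1) * chunk]
    return shards


def gather_full_state_dict(shards: List[StateDict]) -> StateDict:
    """Hypothetical: concatenate each parameter's slices in rank order."""
    return {name: [v for shard in shards for v in shard[name]] for name in shards[0]}


full = {"weight": [1.0, 2.0, 3.0, 4.0, 5.0], "bias": [0.5]}
shards = shard_state_dict(full, world_size=2)
print(shards[0]["weight"])  # rank 0 alone sees an incomplete parameter
print(gather_full_state_dict(shards) == full)
```

Saving `shards[0]` on its own would lose the values held by rank 1. The gathered dictionary is the analogue of what a FULL_STATE_DICT checkpoint contains.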
TPU
PyTorch/XLA supports FSDP training on TPUs. To enable it, modify the FSDP configuration file generated by the accelerate config command.
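The following sketch shows one way to sanity-check the XLA-specific settings before adding them to the configuration. It assumes the key names documented for the `fsdp_config` option of transformers' `TrainingArguments`: `xla`, `xla_fsdp_settings`, and `xla_fsdp_grad_ckpt`. That documentation states that gradient checkpointing requires `xla` to be enabled and an auto-wrap policy to be set. Treat those key names as an assumption and confirm them against your installed version. The function `check_xla_fsdp_config` is hypothetical, and the layer class name in the example is only illustrative.

```python
from typing import Any, Dict, List

# Auto-wrap policy keys; assumption: transformers accepts these with or
# without the "fsdp_" prefix. Check your installed version.
_WRAP_KEYS = (
    "min_num_params",
    "fsdp_min_num_params",
    "transformer_layer_cls_to_wrap",
    "fsdp_transformer_layer_cls_to_wrap",
)


def check_xla_fsdp_config(cfg: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return a list of problems (empty if none found)."""
    problems: List[str] = []
    # PyTorch/XLA FSDP is only used when the xla flag is explicitly True.
    if cfg.get("xla") is not True:
        problems.append("xla must be True to enable PyTorch/XLA FSDP")
    # Extra XLA FSDP arguments are passed as a mapping.
    if not isinstance(cfg.get("xla_fsdp_settings", {}), dict):
        problems.append("xla_fsdp_settings must be a dict")
    # Gradient checkpointing wraps nested FSDP layers, so it needs an auto-wrap policy.
    if cfg.get("xla_fsdp_grad_ckpt", False) and not any(cfg.get(k) for k in _WRAP_KEYS):
        problems.append("xla_fsdp_grad_ckpt requires an auto-wrap policy")
    return problems


cfg = {
    "xla": True,
    "xla_fsdp_settings": {},
    "xla_fsdp_grad_ckpt": True,
    "transformer_layer_cls_to_wrap": ["GPT2Block"],  # illustrative layer class name
}
print(check_xla_fsdp_config(cfg))
```

Returning a list instead of raising on the first problem lets you report every misconfiguration at once before launching a TPU job.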