Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. | |
| import torch | |
| import os | |
| import torch.distributed as dist | |
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | |
| checkpoint_wrapper, | |
| CheckpointImpl, | |
| apply_activation_checkpointing, | |
| ) | |
| from transformers.models.t5.modeling_t5 import T5Block | |
| from transformers.models.llama.modeling_llama import LlamaDecoderLayer | |
| from functools import partial | |
# Pre-configured checkpoint wrapper. NO_REENTRANT is the recommended
# implementation: it supports more autograd features than the reentrant
# variant and is the default going forward in PyTorch.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def check_fn(submodule):
    """Return True if *submodule* should be activation-checkpointed.

    Used as the selection predicate for ``apply_activation_checkpointing``;
    only Llama decoder layers are wrapped.
    """
    # NOTE: was previously a lambda assigned to a name (PEP 8 E731);
    # a def keeps the same callable behavior with a proper name/docstring.
    return isinstance(submodule, LlamaDecoderLayer)
def apply_fsdp_checkpointing(model):
    """Apply activation checkpointing to *model* in place.

    Wraps every submodule selected by ``check_fn`` (Llama decoder layers)
    with the module-level ``non_reentrant_wrapper``.

    Args:
        model: the (FSDP-wrapped) model to modify.

    Returns:
        None — the model is updated directly.
    """
    # Fixed typo in the user-facing message ("fdsp" -> "fsdp") and dropped
    # the needless f-string prefix (no placeholders).
    print("--> applying fsdp activation checkpointing...")
    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )