import torch
import os
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from functools import partial

# Checkpoint wrapper using the non-reentrant implementation, the variant
# generally recommended by PyTorch over reentrant checkpointing.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Apply checkpointing only to LlamaDecoderLayer submodules.
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

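# A broader check_fn could target several transformer block types at once,
# e.g. T5Block (imported above) for T5 models. A sketch, not used below:
#
#     check_fn = lambda submodule: isinstance(
#         submodule, (LlamaDecoderLayer, T5Block)
#     )
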
def apply_fsdp_checkpointing(model):
    """Apply activation checkpointing to the model in place.

    Returns None, as the model is updated directly.
    """
    print("--> applying fsdp activation checkpointing...")

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )

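# Minimal usage sketch (an assumption, not part of this module): activation
# checkpointing is typically applied after the model has been wrapped with
# FSDP, so the checkpoint wrappers end up inside the FSDP units.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     model = FSDP(model)              # shard parameters across ranks
#     apply_fsdp_checkpointing(model)  # then wrap each LlamaDecoderLayer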