Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import wraps | |
from operator import attrgetter | |
from typing import List, Union | |
import torch | |
from torch.utils.checkpoint import checkpoint | |
def wrap_forward(forward):
    """Wrap ``forward`` so that each call runs under activation checkpointing.

    The returned callable routes calls through
    ``torch.utils.checkpoint.checkpoint``, which discards the intermediate
    activations of ``forward`` and recomputes them during the backward pass,
    trading extra compute for lower memory use.

    Args:
        forward: The callable to wrap, typically a module's bound ``forward``.

    Returns:
        A callable accepting the same positional and keyword arguments as
        ``forward`` and returning what ``forward`` returns.
    """

    @wraps(forward)  # keep the wrapped function's name/doc for debugging
    def wrapper(*args, **kwargs):
        # Checkpointing only helps when an autograd graph is being recorded;
        # with gradients disabled (e.g. inference) just call forward directly.
        if not torch.is_grad_enabled():
            return forward(*args, **kwargs)
        if kwargs:
            # ``checkpoint`` reserves some keyword names for itself (e.g.
            # ``preserve_rng_state``) and its reentrant variant rejects other
            # keyword arguments, so bind them in a closure instead.
            # NOTE(review): tensors passed by keyword are therefore not
            # treated as checkpoint inputs; pass tensors that need gradients
            # positionally.
            def run(*inputs):
                return forward(*inputs, **kwargs)

            return checkpoint(run, *args)
        return checkpoint(forward, *args)

    return wrapper
def turn_on_activation_checkpointing(model: torch.nn.Module,
                                     modules: Union[List[str], str]):
    """Enable activation checkpointing on selected submodules of ``model``.

    Each named submodule gets its ``forward`` replaced, on the instance
    rather than the class, by the callable returned from ``wrap_forward``.
    ``model`` is modified in place and nothing is returned.

    Repeated requests for the same submodule are ignored. This covers calling
    this function again or listing a path twice. Re-wrapping would nest
    checkpoints and recompute that submodule several times during backward.

    Args:
        model: The root module whose submodules should be checkpointed.
        modules: A dotted attribute path or a list of such paths, resolved
            relative to ``model`` with ``operator.attrgetter``.

    Raises:
        AttributeError: If a path does not resolve to an attribute of
            ``model``.
    """
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        # Skip forwards this function has already wrapped (see docstring).
        if getattr(module.forward, "_activation_checkpointed", False):
            continue
        wrapped = wrap_forward(module.forward)
        wrapped._activation_checkpointed = True
        module.forward = wrapped