import contextlib | |
import platform | |
import uuid | |
import warnings | |
import weakref | |
from collections import defaultdict | |
from itertools import count | |
from typing import ( | |
Any, | |
Callable, | |
ContextManager, | |
DefaultDict, | |
Dict, | |
Iterable, | |
List, | |
Optional, | |
Tuple, | |
) | |
from weakref import ReferenceType | |
import torch | |
import torch.fx.traceback as fx_traceback | |
from torch._functorch._aot_autograd.functional_utils import is_fun | |
from torch.utils._pytree import tree_map | |
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode | |
from torch.utils._python_dispatch import TorchDispatchMode | |
__all__ = [ | |
"checkpoint", | |
"checkpoint_sequential", | |
"CheckpointError", | |
"CheckpointFunction", | |
"check_backward_validity", | |
"detach_variable", | |
"get_device_states", | |
"set_device_states", | |
"noop_context_fn", | |
"set_checkpoint_early_stop", | |
"DefaultDeviceType", | |
"set_checkpoint_debug_enabled", | |
] | |
_DEFAULT_DETERMINISM_MODE = "default" | |
_checkpoint_debug_enabled: Optional[bool] = None | |
@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
""" | |
Context manager that sets whether checkpoint should print additional debug | |
information when running. See the ``debug`` flag for | |
:func:`~torch.utils.checkpoint.checkpoint` for more information. Note that | |
when set, this context manager overrides the value of ``debug`` passed to | |
checkpoint. To defer to the local setting, pass ``None`` to this context. | |
Args: | |
enabled (bool): Whether checkpoint should print debug information. | |
Default is 'None'. | |
""" | |
global _checkpoint_debug_enabled | |
try: | |
prev = _checkpoint_debug_enabled | |
_checkpoint_debug_enabled = enabled | |
yield | |
finally: | |
_checkpoint_debug_enabled = prev | |
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: | |
if isinstance(inputs, tuple): | |
out = [] | |
for inp in inputs: | |
if not isinstance(inp, torch.Tensor): | |
out.append(inp) | |
continue | |
x = inp.detach() | |
x.requires_grad = inp.requires_grad | |
out.append(x) | |
return tuple(out) | |
else: | |
raise RuntimeError( | |
"Only tuple of tensors is supported. Got Unsupported input type: ", | |
type(inputs).__name__, | |
) | |
def check_backward_validity(inputs: Iterable[Any]) -> None: | |
if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): | |
warnings.warn( | |
"None of the inputs have requires_grad=True. Gradients will be None" | |
) | |
def _get_device_module(device="cuda"): | |
device_module = getattr(torch, device) | |
return device_module | |
class DefaultDeviceType: | |
r""" | |
A class that manages the default device type for checkpointing. | |
If no non-CPU tensors are present, the default device type will | |
be used. The default value is 'cuda'. The device type is used in | |
the checkpointing process when determining which device states | |
to save and restore for recomputation. | |
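    Example (an illustrative sketch; any device string with a matching
    ``torch.<device>`` module, e.g. ``"cuda"``, can be passed)::
        >>> # xdoctest: +SKIP("illustrative only")
        >>> DefaultDeviceType.set_device_type("cuda")
        >>> DefaultDeviceType.get_device_type()
        'cuda'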
""" | |
_default_device_type = "cuda" | |
    @staticmethod
    def set_device_type(device: str = "cuda"):
""" | |
Set the default device type for checkpointing. | |
Args: | |
device (str): The device type to be set as default. Default is 'cuda'. | |
""" | |
DefaultDeviceType._default_device_type = device | |
    @staticmethod
    def get_device_type() -> str:
""" | |
Get the current default device type for checkpointing. | |
Returns: | |
str: The current default device type. | |
""" | |
return DefaultDeviceType._default_device_type | |
def _infer_device_type(*args): | |
device_types = list( | |
{ | |
arg.device.type | |
for arg in args | |
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu" | |
} | |
) | |
if len(device_types) > 1: | |
warnings.warn( | |
"Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. " | |
"Device state will only be saved for devices of a single device type, and the remaining " | |
"devices will be ignored. Consequently, if any checkpointed functions involve randomness, " | |
"this may result in incorrect gradients. (Note that if CUDA devices are among the devices " | |
"detected, it will be prioritized; otherwise, the first device encountered will be selected.)" | |
) | |
if len(device_types) == 0: | |
return DefaultDeviceType.get_device_type() | |
elif "cuda" in device_types: | |
return "cuda" | |
else: | |
return device_types[0] | |
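# For example (illustrative): _infer_device_type(cpu_tensor) falls back to
# DefaultDeviceType.get_device_type() (typically "cuda"), while
# _infer_device_type(cuda_tensor, cpu_tensor) returns "cuda".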
# We can't know if the run_fn will internally move some args to different devices, | |
# which would require logic to preserve rng states for those devices as well. | |
# We could paranoically stash and restore ALL the rng states for all visible devices, | |
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for | |
# the device of all Tensor args. | |
# | |
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py? | |
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: | |
# This will not error out if "arg" is a CPU tensor or a non-tensor type because | |
# the conditionals short-circuit. | |
fwd_device_ids = list( | |
{ | |
arg.get_device() | |
for arg in args | |
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu" | |
} | |
) | |
fwd_device_states = [] | |
device_module = _get_device_module(_infer_device_type(*args)) | |
for device_id in fwd_device_ids: | |
with device_module.device(device_id): | |
fwd_device_states.append(device_module.get_rng_state()) | |
return fwd_device_ids, fwd_device_states | |
def set_device_states(devices, states) -> None: | |
device_module = _get_device_module(_infer_device_type(*states)) | |
for device, state in zip(devices, states): | |
with device_module.device(device): | |
device_module.set_rng_state(state) | |
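# Illustrative round-trip (a sketch; assumes at least one non-CPU tensor arg so
# that a device state is actually stashed):
#     devices, states = get_device_states(t)
#     ...  # ops that consume the device RNG
#     set_device_states(devices, states)  # per-device RNG restored to the stashed state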
def _get_autocast_kwargs(device="cuda"): | |
if device == "cuda": | |
device_autocast_kwargs = { | |
"enabled": torch.is_autocast_enabled(), | |
"dtype": torch.get_autocast_gpu_dtype(), | |
"cache_enabled": torch.is_autocast_cache_enabled(), | |
} | |
elif _supports_autocast(device): | |
device_module = _get_device_module(device) | |
device_autocast_kwargs = { | |
"enabled": device_module.is_autocast_enabled(), | |
"dtype": device_module.get_autocast_dtype(), | |
"cache_enabled": torch.is_autocast_cache_enabled(), | |
} | |
else: | |
device_autocast_kwargs = None | |
cpu_autocast_kwargs = { | |
"enabled": torch.is_autocast_cpu_enabled(), | |
"dtype": torch.get_autocast_cpu_dtype(), | |
"cache_enabled": torch.is_autocast_cache_enabled(), | |
} | |
return device_autocast_kwargs, cpu_autocast_kwargs | |
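# For example (a sketch; the actual values depend on the ambient autocast state):
#     _get_autocast_kwargs("cuda") might return
#     ({"enabled": False, "dtype": torch.float16, "cache_enabled": True},
#      {"enabled": False, "dtype": torch.bfloat16, "cache_enabled": True})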
def _supports_autocast(device): | |
device_module = _get_device_module(device) | |
return device == "cuda" or (hasattr(device_module, "is_autocast_enabled") | |
and hasattr(device_module, "get_autocast_dtype")) | |
class CheckpointFunction(torch.autograd.Function): | |
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args) | |
ctx.run_function = run_function | |
ctx.preserve_rng_state = preserve_rng_state | |
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | |
ctx.device = _infer_device_type(*args) | |
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( | |
ctx.device | |
) | |
if preserve_rng_state: | |
ctx.fwd_cpu_state = torch.get_rng_state() | |
# Don't eagerly initialize the cuda context by accident. | |
# (If the user intends that the context is initialized later, within their | |
# run_function, we SHOULD actually stash the cuda state here. Unfortunately, | |
# we have no way to anticipate this will happen before we run the function.) | |
ctx.had_device_in_fwd = False | |
device_module = _get_device_module(ctx.device) | |
if getattr(device_module, "_initialized", False): | |
ctx.had_device_in_fwd = True | |
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) | |
# Save non-tensor inputs in ctx, keep a placeholder None for tensors | |
# to be filled out during the backward. | |
ctx.inputs = [] | |
ctx.tensor_indices = [] | |
tensor_inputs = [] | |
for i, arg in enumerate(args): | |
if torch.is_tensor(arg): | |
tensor_inputs.append(arg) | |
ctx.tensor_indices.append(i) | |
ctx.inputs.append(None) | |
else: | |
ctx.inputs.append(arg) | |
ctx.save_for_backward(*tensor_inputs) | |
with torch.no_grad(): | |
outputs = run_function(*args) | |
return outputs | |
    @staticmethod
    def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid(): | |
raise RuntimeError( | |
"Checkpointing is not compatible with .grad() or when an `inputs` parameter" | |
" is passed to .backward(). Please use .backward() and do not pass its `inputs`" | |
" argument." | |
) | |
# Copy the list to avoid modifying original list. | |
inputs = list(ctx.inputs) | |
tensor_indices = ctx.tensor_indices | |
tensors = ctx.saved_tensors | |
device_module = _get_device_module(ctx.device) | |
# Fill in inputs with appropriate saved tensors. | |
for i, idx in enumerate(tensor_indices): | |
inputs[idx] = tensors[i] | |
# Stash the surrounding rng state, and mimic the state that was | |
# present at this time during forward. Restore the surrounding state | |
# when we're done. | |
rng_devices = [] | |
if ctx.preserve_rng_state and ctx.had_device_in_fwd: | |
rng_devices = ctx.fwd_devices | |
with torch.random.fork_rng( | |
devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device | |
): | |
if ctx.preserve_rng_state: | |
torch.set_rng_state(ctx.fwd_cpu_state) | |
if ctx.had_device_in_fwd: | |
set_device_states(ctx.fwd_devices, ctx.fwd_device_states) | |
detached_inputs = detach_variable(tuple(inputs)) | |
device_autocast_ctx = device_module.amp.autocast( | |
**ctx.device_autocast_kwargs | |
) if _supports_autocast(ctx.device) else contextlib.nullcontext() | |
with torch.enable_grad(), device_autocast_ctx, \ | |
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): | |
outputs = ctx.run_function(*detached_inputs) | |
if isinstance(outputs, torch.Tensor): | |
outputs = (outputs,) | |
# run backward() with only tensor that requires grad | |
outputs_with_grad = [] | |
args_with_grad = [] | |
for i in range(len(outputs)): | |
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: | |
outputs_with_grad.append(outputs[i]) | |
args_with_grad.append(args[i]) | |
if len(outputs_with_grad) == 0: | |
raise RuntimeError( | |
"none of output has requires_grad=True," | |
" this checkpoint() is not necessary" | |
) | |
torch.autograd.backward(outputs_with_grad, args_with_grad) | |
grads = tuple( | |
inp.grad if isinstance(inp, torch.Tensor) else None | |
for inp in detached_inputs | |
) | |
return (None, None) + grads | |
def noop_context_fn(): | |
return contextlib.nullcontext(), contextlib.nullcontext() | |
# TorchDynamo does not step inside utils.checkpoint function. The flow | |
# looks likes this | |
# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by | |
# speculatively checking if the forward function is safe to trace. | |
# 2) If yes, then Dynamo-generated Fx graph has the wrapped higher | |
# order op. As a result, TorchDynamo does not look inside utils.checkpoint. | |
# 3) If not, then TorchDynamo falls back to eager by performing a graph | |
# break. And here, the following disable wrapper ensures that | |
# TorchDynamo does not trigger again on the frames created by | |
# utils.checkpoint innards. | |
@torch._disable_dynamo
def checkpoint(
function, | |
*args, | |
use_reentrant: Optional[bool] = None, | |
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, | |
determinism_check: str = _DEFAULT_DETERMINISM_MODE, | |
debug: bool = False, | |
**kwargs | |
): | |
r"""Checkpoint a model or part of the model. | |
Activation checkpointing is a technique that trades compute for memory. | |
Instead of keeping tensors needed for backward alive until they are used in | |
gradient computation during backward, forward computation in checkpointed | |
regions omits saving tensors for backward and recomputes them during the | |
backward pass. Activation checkpointing can be applied to any part of a | |
model. | |
There are currently two checkpointing implementations available, determined | |
by the :attr:`use_reentrant` parameter. It is recommended that you use | |
    ``use_reentrant=False``. Please refer to the note below for a discussion of
their differences. | |
.. warning:: | |
If the :attr:`function` invocation during the backward pass differs | |
from the forward pass, e.g., due to a global variable, the checkpointed | |
        version may not be equivalent, potentially causing an error to be
        raised or leading to silently incorrect gradients.
.. warning:: | |
The ``use_reentrant`` parameter should be passed explicitly. In version | |
2.4 we will raise an exception if ``use_reentrant`` is not passed. | |
If you are using the ``use_reentrant=True`` variant, please refer to the | |
note below for important considerations and potential limitations. | |
.. note:: | |
The reentrant variant of checkpoint (``use_reentrant=True``) and | |
the non-reentrant variant of checkpoint (``use_reentrant=False``) | |
differ in the following ways: | |
* Non-reentrant checkpoint stops recomputation as soon as all needed | |
intermediate activations have been recomputed. This feature is enabled | |
by default, but can be disabled with :func:`set_checkpoint_early_stop`. | |
Reentrant checkpoint always recomputes :attr:`function` in its | |
entirety during the backward pass. | |
* The reentrant variant does not record the autograd graph during the | |
forward pass, as it runs with the forward pass under | |
:func:`torch.no_grad`. The non-reentrant version does record the | |
autograd graph, allowing one to perform backward on the graph within | |
checkpointed regions. | |
* The reentrant checkpoint only supports the | |
:func:`torch.autograd.backward` API for the backward pass without its | |
`inputs` argument, while the non-reentrant version supports all ways | |
of performing the backward pass. | |
* At least one input and output must have ``requires_grad=True`` for the | |
reentrant variant. If this condition is unmet, the checkpointed part | |
of the model will not have gradients. The non-reentrant version does | |
not have this requirement. | |
* The reentrant version does not consider tensors in nested structures | |
(e.g., custom objects, lists, dicts, etc) as participating in | |
autograd, while the non-reentrant version does. | |
* The reentrant checkpoint does not support checkpointed regions with | |
detached tensors from the computational graph, whereas the | |
non-reentrant version does. For the reentrant variant, if the | |
checkpointed segment contains tensors detached using ``detach()`` or | |
with :func:`torch.no_grad`, the backward pass will raise an error. | |
This is because ``checkpoint`` makes all the outputs require gradients | |
and this causes issues when a tensor is defined to have no gradient in | |
the model. To avoid this, detach the tensors outside of the | |
``checkpoint`` function. | |
Args: | |
function: describes what to run in the forward pass of the model or | |
part of the model. It should also know how to handle the inputs | |
passed as the tuple. For example, in LSTM, if user passes | |
``(activation, hidden)``, :attr:`function` should correctly use the | |
first input as ``activation`` and the second input as ``hidden`` | |
preserve_rng_state(bool, optional): Omit stashing and restoring | |
the RNG state during each checkpoint. Note that under torch.compile, | |
this flag doesn't take effect and we always preserve RNG state. | |
Default: ``True`` | |
use_reentrant(bool): | |
specify whether to use the activation checkpoint variant that | |
requires reentrant autograd. This parameter should be passed | |
explicitly. In version 2.4 we will raise an exception if | |
``use_reentrant`` is not passed. If ``use_reentrant=False``, | |
``checkpoint`` will use an implementation that does not require | |
reentrant autograd. This allows ``checkpoint`` to support additional | |
functionality, such as working as expected with | |
``torch.autograd.grad`` and support for keyword arguments input into | |
the checkpointed function. | |
context_fn(Callable, optional): A callable returning a tuple of two | |
context managers. The function and its recomputation will be run | |
under the first and second context managers respectively. | |
This argument is only supported if ``use_reentrant=False``. | |
determinism_check(str, optional): A string specifying the determinism | |
check to perform. By default it is set to ``"default"`` which | |
compares the shapes, dtypes, and devices of the recomputed tensors | |
            against those of the saved tensors. To turn off this check, specify
``"none"``. Currently these are the only two supported values. | |
Please open an issue if you would like to see more determinism | |
checks. This argument is only supported if ``use_reentrant=False``, | |
if ``use_reentrant=True``, the determinism check is always disabled. | |
debug(bool, optional): If ``True``, error messages will also include | |
            a trace of the operators that ran during the original forward computation
as well as the recomputation. This argument is only supported if | |
``use_reentrant=False``. | |
args: tuple containing inputs to the :attr:`function` | |
Returns: | |
Output of running :attr:`function` on :attr:`*args` | |
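    Example (a minimal sketch; the module and tensor shapes below are arbitrary)::
        >>> # xdoctest: +SKIP("illustrative only")
        >>> import torch.nn as nn
        >>> block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
        >>> x = torch.randn(4, 16, requires_grad=True)
        >>> # Activations inside ``block`` are recomputed during backward rather
        >>> # than being kept alive after the forward pass.
        >>> out = checkpoint(block, x, use_reentrant=False)
        >>> out.sum().backward()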
""" | |
if use_reentrant is None: | |
warnings.warn( | |
"torch.utils.checkpoint: the use_reentrant parameter should be " | |
"passed explicitly. In version 2.4 we will raise an exception " | |
"if use_reentrant is not passed. use_reentrant=False is " | |
"recommended, but if you need to preserve the current default " | |
"behavior, you can pass use_reentrant=True. Refer to docs for more " | |
"details on the differences between the two variants." | |
) | |
use_reentrant = True | |
# Hack to mix *args with **kwargs in a python 2.7-compliant way | |
preserve = kwargs.pop("preserve_rng_state", True) | |
if kwargs and use_reentrant: | |
raise ValueError( | |
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) | |
) | |
if use_reentrant: | |
if context_fn is not noop_context_fn or debug is not False: | |
raise ValueError( | |
"Passing `context_fn` or `debug` is only supported when " | |
"use_reentrant=False." | |
) | |
return CheckpointFunction.apply(function, preserve, *args) | |
else: | |
gen = _checkpoint_without_reentrant_generator( | |
function, preserve, context_fn, determinism_check, debug, *args, **kwargs | |
) | |
# Runs pre-forward logic | |
next(gen) | |
ret = function(*args, **kwargs) | |
# Runs post-forward logic | |
try: | |
next(gen) | |
except StopIteration: | |
return ret | |
def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): | |
r"""Checkpoint a sequential model to save memory. | |
Sequential models execute a list of modules/functions in order | |
(sequentially). Therefore, we can divide such a model in various segments | |
and checkpoint each segment. All segments except the last will not store | |
the intermediate activations. The inputs of each checkpointed segment will | |
be saved for re-running the segment in the backward pass. | |
.. warning:: | |
The ``use_reentrant`` parameter should be passed explicitly. In version | |
2.4 we will raise an exception if ``use_reentrant`` is not passed. | |
        If you are using the ``use_reentrant=True`` variant, please see
:func:`~torch.utils.checkpoint.checkpoint` for | |
the important considerations and limitations of this variant. It is | |
recommended that you use ``use_reentrant=False``. | |
    .. warning::
Since PyTorch 1.4, it allows only one Tensor as the input and | |
intermediate outputs, just like :class:`torch.nn.Sequential`. | |
Args: | |
functions: A :class:`torch.nn.Sequential` or the list of modules or | |
functions (comprising the model) to run sequentially. | |
segments: Number of chunks to create in the model | |
input: A Tensor that is input to :attr:`functions` | |
preserve_rng_state(bool, optional): Omit stashing and restoring | |
the RNG state during each checkpoint. | |
Default: ``True`` | |
use_reentrant(bool): | |
specify whether to use the activation checkpoint variant that | |
requires reentrant autograd. This parameter should be passed | |
explicitly. In version 2.4 we will raise an exception if | |
``use_reentrant`` is not passed. If ``use_reentrant=False``, | |
``checkpoint`` will use an implementation that does not require | |
reentrant autograd. This allows ``checkpoint`` to support additional | |
functionality, such as working as expected with | |
``torch.autograd.grad`` and support for keyword arguments input into | |
the checkpointed function. | |
Returns: | |
Output of running :attr:`functions` sequentially on :attr:`*inputs` | |
Example: | |
>>> # xdoctest: +SKIP("stub") | |
>>> model = nn.Sequential(...) | |
>>> input_var = checkpoint_sequential(model, chunks, input_var) | |
""" | |
if use_reentrant is None: | |
warnings.warn( | |
"torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " | |
"parameter should be passed explicitly. " | |
"In version 2.4 we will raise an exception if use_reentrant " | |
"is not passed. use_reentrant=False is " | |
"recommended, but if you need to preserve the current default " | |
"behavior, you can pass use_reentrant=True. Refer to docs for more " | |
"details on the differences between the two variants." | |
) | |
use_reentrant = True | |
# Hack for keyword-only parameter in a python 2.7-compliant way | |
preserve = kwargs.pop("preserve_rng_state", True) | |
if kwargs: | |
raise ValueError( | |
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) | |
) | |
def run_function(start, end, functions): | |
def forward(input): | |
for j in range(start, end + 1): | |
input = functions[j](input) | |
return input | |
return forward | |
if isinstance(functions, torch.nn.Sequential): | |
functions = list(functions.children()) | |
segment_size = len(functions) // segments | |
# the last chunk has to be non-volatile | |
end = -1 | |
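    # For example (illustrative): with 5 functions and segments=2, segment_size
    # is 2, so functions[0:2] form the single checkpointed chunk below and
    # functions[2:5] run as the final, non-checkpointed segment.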
for start in range(0, segment_size * (segments - 1), segment_size): | |
end = start + segment_size - 1 | |
input = checkpoint( | |
run_function(start, end, functions), | |
input, | |
use_reentrant=use_reentrant, | |
preserve_rng_state=preserve, | |
) | |
return run_function(end + 1, len(functions) - 1, functions)(input) | |
def _internal_assert(cond): | |
if not cond: | |
raise AssertionError( | |
"Something went unexpectedly wrong in activation checkpoint. " | |
"Please report this bug by filing an issue to PyTorch." | |
) | |
# NOTE [ Nestable Checkpoint ] | |
# | |
# The semantics of nested checkpoint can be defined by two basic rules. | |
# Following the two rules leads to an important implication that is central | |
# to motivating the design. | |
# | |
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden | |
# from any outer layers of checkpoint. | |
# | |
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its | |
# parent checkpoint. | |
# | |
# Implication: To recompute any given saved tensor, we need to recompute all of | |
# the checkpoints wrapping it. | |
# | |
# Why is this implied? To unpack a saved tensor X during backward we need to | |
# recompute the inner-most checkpoint (#1), and in order to recompute that | |
# checkpoint I need to have its inputs, which are managed by that checkpoint's | |
# parent (#2), which thus also needs to be recomputed first. Continue this line | |
# of reasoning and we realize that in order to unpack X, all checkpoints that | |
# were active at the time X was saved need to be recomputed. (unless we have | |
# already done so in that backward for some other saved tensor). | |
# | |
# In practice, we use a noop autograd Function to save inputs as saved tensors. | |
# During unpack calling ctx.saved_tensor triggers the parent checkpoint to | |
# recompute. | |
# | |
# Rule 3. We should start recomputation as if there are no checkpoints currently | |
# active. Checkpoints encountered during recomputation are still | |
# respected. | |
# | |
# When we start recomputation, we push the saved variable hook meant for | |
# recomputation on the stack. See examples in Rule 6 for more context. | |
# | |
# * * * * | |
# | |
# Beyond the basic semantics specific to nested checkpoint, we impose several | |
# more constraints that may apply to checkpointing in general. | |
# | |
# Rule 4. Lifetime of recomputed tensors | |
# | |
# Recomputed tensors are considered specific to particular invocations | |
# of backward and are always cleared immediately as they are unpacked | |
# Particularly, we require this to happen even if retain_graph=True. | |
# | |
# [ Implementation details of Rule 4 ] | |
# | |
# If we were okay with recomputed tensors staying alive after backward is run | |
# with retain_graph=True, we would store recomputed variables as the values of a | |
# WeakKeyDictionary and pack strong references to the keys, so that as we | |
# backward, those packed keys would be cleared as long as retain_graph=False. | |
# Clearing the packed key clears the corresponding entry in the WKD. | |
# | |
# If we wish recomputed variables to be immediately cleared as we unpack them in | |
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by | |
# backward automatically. Instead of packing the strong reference to the key | |
# directly, we pack a container object, which we manually clear as we unpack. | |
# | |
# An important detail is that if a second backward happens, the second | |
# recomputation needs to reset the container with a newly created key. | |
# | |
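# A minimal sketch of this pattern (the names mirror the classes defined below):
#
#     handle = _Handle()
#     holder.handles[gid] = handle                  # strong ref, packed via a _Holder
#     frame.recomputed[gid][handle] = recomputed_x  # WeakKeyDictionary keyed by handle
#     ...
#     ret = frame.recomputed[gid][holder.handles[gid]]
#     holder.handles[gid] = None                    # manual clear, even if retain_graph=True
#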
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we | |
# know we need. | |
# | |
# [ Implementation details of Rule 5 ] | |
# | |
# During recomputation, raise an exception if the number of recomputed tensors | |
# matches the number of tensors that we expected to recompute. We wrap the | |
# recomputation call with a try-catch to catch this specific exception. See | |
# Rule #6 below for some examples. | |
# | |
# Rule 6. We support doing backward inside checkpoint context | |
# | |
# [ retain_graph is True ]
# | |
# def fn(x): | |
# y = x.sin() | |
# z = y.cos() | |
#     gx, = torch.autograd.grad(z, x, retain_graph=True)
# return gx, z | |
# | |
# out = checkpoint(fn)(inp) | |
# out.backward() | |
# | |
# Because z is saved by cos while checkpoint is enabled, it would not be | |
# actually saved, and so the .grad() call inside must trigger a recomputation. | |
# | |
# During recomputation the "inner pack hook" has two responsibilities: | |
# | |
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors | |
# 2) Pack the actual tensor (detached) so that one may perform backward on the | |
# recomputed graph. The tensors saved to this graph will live until the end | |
# of recomputation, or die earlier if someone performs backward with | |
# retain_graph=False. | |
# | |
# More generally performing backward on the recomputed graph occurs in the | |
# following cases: | |
# - If backward is performed inside forward, | |
# - During the original forward IF early-stop is disabled | |
# - During the original backward | |
# - If there are multiple .grad()/.backward() calls, we would perform backward | |
# on the recomputed graph even if early-stop is enabled (see the example below) | |
# | |
# [ retain_graph is False ] | |
# | |
# The example below shows what happens if during recomputation we find that some | |
# of the tensors we are trying to recompute have already been cleared. | |
# | |
# Spoiler: we don't do anything special, we just skip over them! | |
# | |
# def fn(x): | |
# y = x.sin() # (1) | |
# z = y.cos() # (2) | |
# gx, = torch.autograd.grad(z, x) # (3) | |
# return x.cos() * gx # (4) | |
# | |
# out = checkpoint(fn)(inp) | |
# out.backward() # (5) | |
# | |
# 1, 2. Don't save x and y since we are inside a checkpoint. | |
# 3. Trigger a recompute of fn since x and y weren't saved. | |
# And depending on whether early stop is enabled, either stop at (2) or | |
# continue running the function. | |
# Because we are running backward with retain_graph=False, we clear x and y's | |
# holders. | |
# 4. Don't save x since we are inside a checkpoint. | |
# 5. Calling backward triggers another recompute of fn. During recompute, we see | |
# that x and y have already been cleared in the original graph as indicated | |
# by holder=None. We skip over them. We still save x at (4) (since its holder | |
# is still alive.) | |
_enable_checkpoint_early_stop = True | |
@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
"""Context manager that sets whether checkpoint should stop recomputation early. | |
By default, non-reentrant checkpoint stops recomputation as soon as it | |
has computed all needed Tensors. This context manager can be used to disable | |
that feature if it is problematic for your specific application. | |
This context manager only needs to be active when forward is run. It does | |
not need to be active during backward. | |
Example:: | |
>>> # xdoctest: +SKIP(failing) | |
>>> message = "saved tensors default hooks are disabled" | |
>>> with set_checkpoint_early_stop(False): | |
... # Any checkpoint under this context manager will respect this | |
... # context manager, even if its backward is performed outside. | |
... out = checkpoint(fn, inputs) | |
... | |
>>> out.backward() | |
""" | |
global _enable_checkpoint_early_stop | |
try: | |
prev = _enable_checkpoint_early_stop | |
_enable_checkpoint_early_stop = enable | |
yield | |
finally: | |
_enable_checkpoint_early_stop = prev | |
class _Handle: | |
pass | |
class _Holder: | |
def __init__(self): | |
self.handles: Dict[int, Optional[_Handle]] = dict() | |
class _NoopSaveInputs(torch.autograd.Function): | |
    @staticmethod
    def forward(*args):
return torch.empty((0,)) | |
    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
# Only tensors can be saved with ctx.save_for_backward, everything else | |
# is captured by get_args, which is saved directly on ctx | |
tensor_indices, tensors = zip( | |
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] | |
) | |
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} | |
# args but with tensors replaced with None as placeholders | |
args = [None if isinstance(o, torch.Tensor) else o for o in inputs] | |
def get_args(saved_tensors): | |
# restore the placeholders with the original tensors grabbed from | |
# ctx.saved_tensors (which may be saved on a parent checkpoint if | |
# this checkpoint is nested, and that would trigger a recursive | |
# unpack!) | |
ret = [ | |
saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o | |
for i, o in enumerate(args) | |
] | |
# grab the tail since we also saved the dummy to avoid having to explicitly | |
# handle the case where there are no tensor inputs | |
return ret[1:] | |
ctx.get_args = get_args | |
ctx.save_for_backward(*tensors) | |
    @staticmethod
    def backward(ctx, *grad_outputs):
raise AssertionError("Did not expect to backward on this graph") | |
class _CheckpointFrame: | |
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): | |
self.recompute_fn = recompute_fn | |
self.input_saver = None | |
self.weak_holders: List[ReferenceType] = [] | |
# We store this as a weakkeydictionary so that in the case of a partial | |
# backward, the entries in the dict are cleared alongside the Holder | |
# which will be removed when the SavedVariable is cleared. | |
self.recomputed: DefaultDict[ | |
int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] | |
] = defaultdict(weakref.WeakKeyDictionary) | |
# We need both recomp_counter and recomputed since they can diverge | |
# https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885 | |
self.recomp_counter: DefaultDict[int, int] = defaultdict(int) | |
self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) | |
# See Rule 5 | |
self.early_stop = early_stop | |
# Debugging | |
self.metadata_fn = metadata_fn | |
self.unpack_error_cb = unpack_error_cb | |
self.x_metadatas = [] | |
self.forward_completed = False | |
self.ignore_saved_mismatch = False | |
def check_recomputed_tensors_match(self, gid): | |
if self.ignore_saved_mismatch: | |
# TODO: we can probably make this check stricter by checking that | |
# the metadata of the first tensors still match. | |
return | |
# NOTE [ Error handling for checkpoint ] | |
# | |
# At a high level, we need to check that the tensors saved | |
        # during the original forward match the tensors saved during recompute
# This means handling 3 cases: | |
# | |
# 1. During recompute, more tensors were saved. | |
# | |
# Usually this is hidden due to the StopRecomputationError | |
        #    but if early stop is not enabled, we would have errored
        #    anyway because there aren't enough weak_holders. Either way
        #    we do want to have a nice error. See the _recomputation_hook
# for details. | |
if not len(self.weak_holders) == self.recomp_counter[gid]: | |
# 2. During recompute, fewer tensors were saved | |
# | |
            #    We know that every time we save something during the original
            #    forward we append to weak_holders, and every time we save a
            #    tensor during recompute we increment recomp_counter.
raise CheckpointError( | |
"torch.utils.checkpoint: A different number of tensors was saved " | |
"during the original forward and recomputation.\n" | |
f"Number of tensors saved during forward: {len(self.weak_holders)}\n" | |
f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}" | |
) | |
# 3. During recompute, the same tensors were saved, but they | |
# have different metadata | |
nb_meta_different = [] | |
for idx, weak_holder in enumerate(self.weak_holders): | |
holder = weak_holder() | |
if holder is None: | |
continue | |
# We've seen all holders since we iterate over them in order | |
# For every holder that is still alive now, it must've been | |
# alive when we saw it during recompute, therefore, the | |
# gid must be set. | |
_internal_assert(gid in holder.handles) | |
# We know this is the first unpack, so it couldn't have been set | |
# to None yet. | |
_internal_assert(holder.handles[gid] is not None) | |
# We always set these together in the recomputation hook | |
_internal_assert(holder.handles[gid] in self.recomputed[gid]) | |
# see pack hook, x_metadata is 1:1 with weak_holders. | |
x_meta = self.x_metadatas[idx] | |
recomputed_x = self.recomputed[gid][holder.handles[gid]] | |
if x_meta != self.metadata_fn(recomputed_x): | |
nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x))) | |
if len(nb_meta_different) > 0: | |
mismatched_tensors = "" | |
for idx, x_meta, recomputed_meta in nb_meta_different: | |
mismatched_tensors += ( | |
f"tensor at position {idx}:\n" | |
f"saved metadata: {x_meta}\n" | |
f"recomputed metadata: {recomputed_meta}\n" | |
) | |
raise CheckpointError( | |
"torch.utils.checkpoint: Recomputed values for the following tensors " | |
"have different metadata than during the forward pass.\n" | |
f"{mismatched_tensors}" | |
) | |
_checkpoint_error_template = """ \ | |
An error happened while unpacking tensors; dumping logs of latest computation | |
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. | |
Scroll all the way down for guidance on how to navigate these logs. | |
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | |
| 1. Stack traces of the operators that ran in the original forward | | |
+------------------------------------------------------------------------------+ | |
{forward_traces} | |
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | |
| 2. Stack traces of the operators that ran during recomputation | | |
+------------------------------------------------------------------------------+ | |
{recompute_traces} | |
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | |
| 3. Log of operators in the original forward and recomputation | | |
+------------------------------------------------------------------------------+ | |
(Scroll up to correlate stack traces with each operation listed below. This | |
helps identify their source in the code.) | |
IMPORTANT: Differences in "detach" calls between the original forward and the | |
recomputation are expected. They are introduced by the checkpointing | |
mechanism and can be ignored. | |
Operations executed during the original forward: | |
{forward_ops} | |
Operations executed during recomputation: | |
{recompute_ops} | |
+------------------------------------------------------------------------------+ | |
ERROR: Detected non-determinism while running activation checkpointing | |
You are seeing this error because you passed `debug=True` to checkpoint and | |
the tensors saved during the original forward differ from those saved during
recomputation. This can happen if different operators were run in the
original forward and in the recomputation. | |
To identify where the mismatch may be coming from, you can do the following: | |
1) Compare the operators run during the original forward and the recomputation
   to see where they differ. These operators are printed above in the order
   they were executed.
2) Review the stack trace for each operator to locate its invocation source.
   Each operator's stack trace is printed in execution order.
Note that the logs can be quite long. Here's how they are structured: | |
(Tip: you can Ctrl-f for these headers) | |
1. Stack traces of the operators that ran in the original forward | |
2. Stack traces of the operators that ran during recomputation | |
3. Log of operators in the original forward and recomputation | |
4. Error message <--- You are here | |
-------------------------------------------------------------------------------- | |
""" | |
class CheckpointError(RuntimeError): | |
pass | |
def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: | |
# This function returns the context_fn and error_cb to be used by the | |
# checkpointing mechanism. error_cb is invoked when an error is detected | |
# during unpack. | |
    # record_context_cpp is not supported on non-Linux, non-x86_64 platforms
cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' | |
class CaptureLogs: | |
def __init__(self): | |
self.logs = None | |
self.tbs = None | |
def get_context_manager(self): | |
            @contextlib.contextmanager
            def logging_mode():
with LoggingTensorMode(), \ | |
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: | |
self.logs, self.tbs = logs_and_tb | |
yield logs_and_tb | |
return logging_mode() | |
capture_logs_fwd = CaptureLogs() | |
capture_logs_recompute = CaptureLogs() | |
def unpack_error_cb(e: CheckpointError): | |
def get_str_tb(label, capture_logs): | |
out = "" | |
total_len = len(capture_logs.logs) | |
for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): | |
out += f"{log} ({i + 1} of {total_len} in {label})\n\n" | |
found_torch_dispatch = False | |
for line in tb: | |
# Start printing stack trace only after __torch_dispatch__ is found | |
is_torch_dispatch = line['name'] == '__torch_dispatch__' | |
if not found_torch_dispatch and not is_torch_dispatch: | |
continue | |
elif is_torch_dispatch: | |
found_torch_dispatch = True | |
continue | |
out += f"{line['filename']}:{line['line']}:{line['name']}\n" | |
out += "\n\n" | |
return out | |
assert capture_logs_fwd.logs is not None | |
assert capture_logs_recompute.logs is not None | |
raise CheckpointError( | |
_checkpoint_error_template.format( | |
forward_traces=get_str_tb("original", capture_logs_fwd), | |
recompute_traces=get_str_tb("recompute", capture_logs_recompute), | |
forward_ops="\n".join(capture_logs_fwd.logs), | |
recompute_ops="\n".join(capture_logs_recompute.logs) | |
) | |
) from e | |
def context_fn(): | |
return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() | |
return context_fn, unpack_error_cb | |
def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: | |
# These properties are fast to check, easy to understand | |
return { | |
"shape": x.shape, | |
"dtype": x.dtype, | |
"device": x.device | |
} | |
_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { | |
_DEFAULT_DETERMINISM_MODE: _default_meta_extractor, | |
"none": lambda _: None, | |
} | |
# See Rule 5 | |
class _StopRecomputationError(Exception): | |
pass | |
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): | |
def __init__(self, target_frame_ref: ReferenceType, gid: int): | |
def pack_hook(x): | |
target_frame = target_frame_ref() | |
assert target_frame is not None # appease mypy | |
recomp_idx = target_frame.recomp_counter[gid] | |
target_frame.recomp_counter[gid] += 1 | |
if recomp_idx >= len(target_frame.weak_holders): | |
assert not target_frame.early_stop | |
if not target_frame.forward_completed: | |
# We run into this case when early stop is not enabled and do | |
# grad within checkpoint. | |
# We need to set this flag, so we don't error out later when | |
# we check if the number of tensors saved during forward and | |
# recomputation match. | |
target_frame.ignore_saved_mismatch = True | |
return x.detach() | |
raise CheckpointError( | |
"torch.utils.checkpoint: trying to save more tensors during " | |
"recomputation than during the original forward pass." | |
) | |
holder = target_frame.weak_holders[recomp_idx]() | |
# This holder may have been cleared because someone may have called | |
# backward within forward. If so, we don't need to save. | |
if holder is not None: | |
_internal_assert(holder.handles.get(gid, None) is None) | |
holder.handles[gid] = _Handle() | |
target_frame.recomputed[gid][holder.handles[gid]] = x.detach() | |
if target_frame.early_stop and target_frame.recomp_counter[gid] == len( | |
target_frame.weak_holders | |
): | |
raise _StopRecomputationError() | |
# See Rule 6: [ retain_graph is True ] above | |
return x.detach() | |
def unpack_hook(x): | |
# See Rule 6: [ retain_graph is True ] above for an example of when | |
# the graph created during recomputation could be backwarded. | |
return x | |
super().__init__(pack_hook, unpack_hook) | |
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): | |
def __init__(self, frame): | |
def pack_hook(x): | |
# See Rule 4 above | |
holder = _Holder() | |
frame.weak_holders.append(weakref.ref(holder)) | |
# Save metadata to detect non-determinism | |
if frame.metadata_fn is not None: | |
with torch.no_grad(): | |
frame.x_metadatas.append(frame.metadata_fn(x)) | |
return holder | |
def unpack_hook(holder): | |
gid = torch._C._current_graph_task_id() | |
if gid == -1: | |
# generate a temporary id if we trigger unpack outside of a backward call | |
gid = int(uuid.uuid4()) | |
if not frame.is_recomputed[gid]: | |
ctx = frame.input_saver.grad_fn | |
args = ctx.get_args(ctx.saved_tensors) | |
try: | |
with _recomputation_hook( | |
weakref.ref(frame), gid | |
), torch.autograd.enable_grad(): | |
frame.recompute_fn(*args) | |
except _StopRecomputationError: | |
pass | |
frame.is_recomputed[gid] = True | |
frame.check_recomputed_tensors_match(gid) | |
_internal_assert(gid in holder.handles) | |
if holder.handles[gid] is None: | |
raise CheckpointError( | |
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " | |
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " | |
"so only once. Otherwise please open an issue with details on your use case." | |
) | |
_internal_assert(holder.handles[gid] in frame.recomputed[gid]) | |
ret = frame.recomputed[gid][holder.handles[gid]] | |
holder.handles[gid] = None | |
return ret | |
if frame.unpack_error_cb is not None: | |
def unpack_hook_with_error_cb(holder): | |
try: | |
return unpack_hook(holder) | |
except CheckpointError as e: | |
frame.unpack_error_cb(e) | |
super().__init__(pack_hook, unpack_hook_with_error_cb) | |
else: | |
super().__init__(pack_hook, unpack_hook) | |
def _is_compiling(func, args, kwargs): | |
# Check if we are under AOTAutograd tracing | |
# There should probably be a better way to do this... | |
# TODO: unify _is_compiling across all compile stacks | |
for arg in args: | |
if isinstance(arg, torch.Tensor) and is_fun(arg): | |
return True | |
return False | |
def _detach(x): | |
if isinstance(x, torch.Tensor): | |
return x.detach() | |
return x | |
uid = count(1) | |
# NOTE: torch.utils.checkpoint internal logic will call these two functions an unknown number of times
# (i.e. there could be _CachedTorchDispatchMode calls that don't map to a _CachingTorchDispatchMode call),
# so we ignore these ops and just always recompute them. | |
_ignored_ops = { | |
torch.ops.prim.device.default, | |
torch.ops.aten.detach.default, | |
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) | |
class _CachingTorchDispatchMode(TorchDispatchMode): | |
r""" | |
A :class:`TorchDispatchMode` to implement selective activation checkpointing | |
that's compatible with torch.compile. Used together with _CachedTorchDispatchMode. | |
""" | |
def __init__(self, policy_fn, storage): | |
self.policy_fn = policy_fn | |
self.storage = storage | |
def push_into_storage(self, out, func, args, kwargs): | |
out_detached = tree_map(_detach, out) | |
self.storage[func].append(out_detached) | |
def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs): | |
if func in _ignored_ops: | |
return func(*args, **kwargs) | |
if should_not_recompute: | |
fx_traceback.current_meta["recompute"] = 0 | |
# NOTE: Here we just store and reuse output of all ops, since in torch.compile mode | |
# we decide and handle recomputation in the partitioner. | |
out = func(*args, **kwargs) | |
self.push_into_storage(out, func, args, kwargs) | |
return out | |
def __torch_dispatch__(self, func, types, args=(), kwargs=None): | |
if kwargs is None: | |
kwargs = {} | |
should_not_recompute = self.policy_fn("forward", func, *args, **kwargs) | |
if _is_compiling(func, args, kwargs): | |
return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs) | |
else: | |
if should_not_recompute: | |
out = func(*args, **kwargs) | |
self.push_into_storage(out, func, args, kwargs) | |
else: | |
out = func(*args, **kwargs) | |
return out | |
class _CachedTorchDispatchMode(TorchDispatchMode): | |
r""" | |
A :class:`TorchDispatchMode` to implement selective activation checkpointing | |
that's compatible with torch.compile. Used together with _CachingTorchDispatchMode. | |
""" | |
def __init__(self, policy_fn, storage): | |
self.policy_fn = policy_fn | |
self.storage = storage | |
def pop_from_storage(self, func, args, kwargs): | |
assert func in self.storage | |
out = self.storage[func].pop(0) | |
return out | |
def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs): | |
if func in _ignored_ops: | |
return func(*args, **kwargs) | |
out = self.pop_from_storage(func, args, kwargs) | |
return out | |
def __torch_dispatch__(self, func, types, args=(), kwargs=None): | |
if kwargs is None: | |
kwargs = {} | |
should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs) | |
if _is_compiling(func, args, kwargs): | |
return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs) | |
else: | |
if should_not_recompute: | |
out = self.pop_from_storage(func, args, kwargs) | |
else: | |
out = func(*args, **kwargs) | |
return out | |
def _pt2_selective_checkpoint_context_fn_gen(policy_fn): | |
""" | |
A helper function that generates a pair of contexts to be later passed into | |
    `torch.utils.checkpoint` API to implement selective checkpointing.
    .. warning::
        This context_fn is intended for use with torch.compile only.
Args: | |
policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function | |
to decide whether a particular op should be recomputed in backward pass or not. | |
In eager mode: | |
If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. | |
If policy_fn(...) returns False, the op is guaranteed to be recomputed. | |
In torch.compile mode: | |
If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. | |
If policy_fn(...) returns False, the op may or may not be recomputed | |
(it's up to the partitioner to decide). | |
Returns: | |
A pair of generated contexts. | |
Example: | |
>>> # xdoctest: +REQUIRES(LINUX) | |
>>> | |
>>> def get_custom_policy(): | |
>>> no_recompute_list = [ | |
>>> torch.ops.aten.mm.default, | |
>>> ] | |
>>> def custom_policy(mode, func, *args, **kwargs): | |
>>> return func in no_recompute_list | |
>>> return custom_policy | |
>>> | |
>>> def selective_checkpointing_context_fn(): | |
>>> return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy()) | |
>>> | |
>>> def gn(x, y): | |
>>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y | |
>>> | |
>>> def fn(x, y): | |
>>> return torch.utils.checkpoint.checkpoint( | |
>>> gn, x, y, | |
>>> use_reentrant=False, | |
>>> context_fn=selective_checkpointing_context_fn, | |
>>> ) | |
>>> | |
>>> x = torch.randn(4, 4, requires_grad=True) | |
>>> y = torch.randn(4, 4, requires_grad=True) | |
>>> | |
>>> compiled_fn = torch.compile(fn) | |
""" | |
storage: Dict[Any, List[Any]] = defaultdict(list) | |
return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage) | |
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and | |
# saving/restoring of global state is handled here. | |
def _checkpoint_without_reentrant_generator( | |
fn, | |
preserve_rng_state=True, | |
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, | |
determinism_check: str = _DEFAULT_DETERMINISM_MODE, | |
debug: bool = False, | |
*args, | |
**kwargs | |
): | |
"""Checkpointing without reentrant autograd. | |
Args: | |
function: describes what to run in the forward pass of the model or | |
part of the model. It should also know how to handle the inputs | |
passed as the tuple. For example, in LSTM, if user passes | |
``(activation, hidden)``, :attr:`function` should correctly use the | |
first input as ``activation`` and the second input as ``hidden`` | |
preserve_rng_state(bool, optional): Omit stashing and restoring | |
the RNG state during each checkpoint. | |
Default: ``True`` | |
context_fn(Callable, optional): A callable returning a tuple of two | |
context managers. The function and its recomputation will be run | |
under the first and second context managers respectively. | |
determinism_check(str, optional): A string specifying the determinism | |
check to perform. By default it is set to ``"default"`` which | |
compares the shapes, dtypes, and devices of the recomputed tensors | |
            against those of the saved tensors. To turn off this check, specify
``"none"``. Currently these are the only two supported values. | |
Please open an issue if you would like to see more determinism | |
checks. | |
debug(bool, optional): If ``True``, error messages will also include | |
            a trace of the operators that ran during the original forward computation
as well as the recomputation. | |
*args: Arguments to pass in to the given ``function``. | |
**kwargs: Keyword arguments to pass into the given ``function``. | |
""" | |
unpack_error_cb = None | |
if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: | |
if context_fn != noop_context_fn: | |
raise ValueError( | |
"debug=True is incompatible with non-default context_fn" | |
) | |
context_fn, unpack_error_cb = _get_debug_context_and_cb() | |
if determinism_check in _allowed_determinism_checks_to_fns: | |
metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] | |
else: | |
raise ValueError( | |
f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " | |
f"but got {determinism_check}" | |
) | |
device = _infer_device_type(*args) | |
device_module = _get_device_module(device) | |
forward_context, recompute_context = context_fn() | |
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: | |
assert ( | |
isinstance(forward_context, TorchDispatchMode) and | |
isinstance(recompute_context, TorchDispatchMode) | |
), \ | |
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \ | |
"must generate a tuple of two `TorchDispatchMode`s." | |
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | |
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device) | |
if preserve_rng_state: | |
fwd_cpu_state = torch.get_rng_state() | |
# Don't eagerly initialize the cuda context by accident. | |
# (If the user intends that the context is initialized later, within their | |
# run_function, we SHOULD actually stash the cuda state here. Unfortunately, | |
# we have no way to anticipate this will happen before we run the function. | |
# If they do so, we raise an error.) | |
had_device_in_fwd = False | |
if getattr(device_module, "_initialized", False): | |
had_device_in_fwd = True | |
fwd_devices, fwd_device_states = get_device_states(*args) | |
def recompute_fn(*inputs): | |
kwargs, *args = inputs | |
# This will be called later during recomputation. This wrapping enables | |
# the necessary global state to be captured. | |
rng_devices = [] | |
if preserve_rng_state and had_device_in_fwd: | |
rng_devices = fwd_devices | |
with torch.random.fork_rng( | |
devices=rng_devices, enabled=preserve_rng_state, device_type=device | |
): | |
if preserve_rng_state: | |
torch.set_rng_state(fwd_cpu_state) | |
if had_device_in_fwd: | |
set_device_states(fwd_devices, fwd_device_states) | |
device_autocast_ctx = device_module.amp.autocast( | |
**device_autocast_kwargs | |
) if _supports_autocast(device) else contextlib.nullcontext() | |
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \ | |
recompute_context: | |
fn(*args, **kwargs) | |
new_frame = _CheckpointFrame( | |
recompute_fn, | |
_enable_checkpoint_early_stop, | |
unpack_error_cb, | |
metadata_fn | |
) | |
dummy = torch.empty((0,), requires_grad=True) | |
new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) | |
# When ambient grad_mode is False | |
if new_frame.input_saver.grad_fn is None: | |
yield | |
return | |
with _checkpoint_hook(new_frame), forward_context: | |
yield | |
new_frame.forward_completed = True | |
if getattr(device_module, "_initialized", False) and \ | |
preserve_rng_state and not had_device_in_fwd: # type: ignore[possibly-undefined] | |
# Device was not initialized before running the forward, so we didn't | |
# stash the device state. | |
raise RuntimeError( | |
"PyTorch's device state was initialized in the forward pass " | |
"of a Checkpoint, which is not allowed. Please open an issue " | |
"if you need this feature." | |
) | |
return | |