import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)

__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        # A leaf tensor has no grad_fn; taking a trivial view and stepping one
        # node back through its graph yields the tensor's AccumulateGrad node.
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ("node output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
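
    Example (an illustrative sketch; the ``pow``/``sum`` loss is an arbitrary choice,
    and the ``grad`` call on the edge relies on the equivalence stated above)::

        >>> x = torch.randn(3, requires_grad=True)
        >>> edge = torch.autograd.graph.get_gradient_edge(x)
        >>> edge.output_nr
        0
        >>> loss = (x ** 2).sum()
        >>> g_edge = torch.autograd.grad(loss, edge, retain_graph=True)[0]
        >>> g_tensor = torch.autograd.grad(loss, x)[0]
        >>> g_edge.equal(g_tensor)
        True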
""" | |
if not tensor.requires_grad: | |
raise RuntimeError( | |
"It is not possible to get the gradient edge for a Tensor that does not require gradients" | |
) | |
grad_fn = _get_grad_fn_or_grad_acc(tensor) | |
# Note that output_nr default to 0 which is the right value | |
# for the AccumulateGrad node. | |
return GradientEdge(grad_fn, tensor.output_nr) | |
def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are mutating the Tensor's data in a way that PyTorch doesn't know about;
    for example, a custom kernel that reads the Tensor's data_ptr and modifies
    the memory in place based on that pointer.

    Note that incrementing the version counter multiple times for a single in-place operation
    is not problematic.
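
    Example (an illustrative sketch; the external mutation is only indicated by a
    comment since it happens outside of PyTorch)::

        >>> x = torch.randn(3)
        >>> before = x._version
        >>> # ... some external kernel mutates x's memory via its data_ptr ...
        >>> torch.autograd.graph.increment_version(x)
        >>> x._version > before
        True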
""" | |
torch._C._increment_version(tensor) | |
class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b  # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a  # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             gets raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
""" | |
supported_modes = ("all", "any") | |
if mode not in supported_modes: | |
raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}") | |
class Handle(RemovableHandle): | |
handles: Tuple[RemovableHandle, ...] | |
def __init__(self, handles: Tuple[RemovableHandle, ...]): | |
self.handles = handles | |
def remove(self): | |
for handle in self.handles: | |
handle.remove() | |
def __getstate__(self): | |
return self.handles | |
def __setstate__(self, state): | |
self.handles = state | |
if mode == "all": | |
count: Dict[int, int] = dict() | |
nb_calls = None | |
buffer: Dict[int, List[Optional[torch.Tensor]]] = dict() | |
grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors)) | |
len_tensors = len(tensors) | |
def get_inner_hook(idx): | |
def inner_hook(grad: torch.Tensor): | |
nonlocal count, nb_calls, buffer, fn | |
id = torch._C._current_graph_task_id() | |
assert ( | |
id != -1 | |
), "expected this hook to be called inside a backward call" | |
count[id] = count.get(id, 0) | |
buffer[id] = buffer.get(id, [None] * len_tensors) | |
if count[id] == 0: | |
# On the first call, compute the actual nb_calls and buffer | |
nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns) # type: ignore[attr-defined] | |
buffer[id][idx] = grad | |
count[id] += 1 | |
if count[id] == nb_calls: | |
fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn) | |
fn(buffer[id]) | |
del count[id] | |
del buffer[id] | |
return inner_hook | |
handles: Tuple[RemovableHandle] = tuple( | |
t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors) | |
) | |
elif mode == "any": | |
fn = cast(Callable[[torch.Tensor], None], fn) | |
lock = threading.Lock() | |
ran_hook: Dict[int, bool] = defaultdict(bool) | |
def wrapped_fn(grad: torch.Tensor): | |
nonlocal ran_hook | |
id = torch._C._current_graph_task_id() | |
assert id != -1, "expected this hook to be called inside a backward call" | |
with lock: | |
prev, ran_hook[id] = ran_hook[id], True | |
if prev: | |
return | |
fn(grad) | |
handles = tuple( | |
tensor.register_hook(wrapped_fn) | |
for tensor in tensors | |
if tensor.requires_grad | |
) | |
return Handle(handles) # type: ignore[possibly-undefined] | |
# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
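#
# An illustrative walk-through of the three steps above (values are arbitrary; this
# only works inside ``allow_mutation_on_saved_tensors``):
#
#   >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
#   ...     a = torch.randn(2, 3, requires_grad=True).clone()
#   ...     out = (a ** 2).sum()  # step 1: ``a`` is saved for backward
#   ...     a.sin_()              # step 2: ``a`` is cloned before being mutated
#   ...     out.backward()        # step 3: backward uses the pre-mutation clone
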
_allow_mutation_on_saved_tensors_enabled = False


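# Shorthands used below: a "tid" identifies one tensor object at one version
# (object id, data pointer, version counter), while a "sid" identifies the underlying
# data at one version and serves as the aliasing key described in the NOTE above.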
def _get_tid(t) -> Tuple[int, int, int]:
    return (id(t), t.data_ptr(), t._version)


def _get_sid(t) -> Tuple[int, int]:
    return (t.data_ptr(), t._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled
    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)
            yield node

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks


def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        # Calls into the C++ engine to run the backward pass
        return Variable._execution_engine.run_backward(
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]
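
# A sketch of how the logging hooks above can be observed (assuming the standard
# ``logging`` module is configured to emit DEBUG output): raising this module's
# logger to DEBUG makes ``_engine_run_backward`` attach the per-Node prehooks.
#
#   >>> import logging
#   >>> logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#   >>> # subsequent .backward() calls now log one "Executing: ..." line per Node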