Spaces:
Sleeping
Sleeping
File size: 9,780 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple
# Public API of this module: names exported by ``from ... import *``.
__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]
class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    # Class-wide counter used to hand out a distinct ``id`` to every handle.
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        # Only weak references are stored, so a handle never keeps the hook
        # dictionaries alive by itself.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        # Any value that is neither a dict nor a list is ignored.
        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        """Delete this handle's ``id`` from ``hooks_dict`` and from every extra dict still alive.

        Dictionaries that have been garbage-collected, or that no longer
        contain the id, are skipped silently.
        """
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        """Return ``(hooks_dict, id, extra_dicts)`` for pickling.

        Entries whose dictionary is already dead are stored as ``None``.
        (``extra_dict_ref`` is always a tuple, so the three-element form is
        always produced.)
        """
        return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        """Restore from either a two-element ``(hooks_dict, id)`` state or the three-element form."""
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        # Keep freshly created handles from reusing an id restored here.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            # Extra dicts that were already garbage-collected were pickled as
            # None. weakref.ref(None) raises TypeError, and a dead reference is
            # a no-op in remove(), so such entries are dropped.
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2] if d is not None)

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Leaving a ``with`` block removes the hook.
        self.remove()
def unserializable_hook(f):
    """
    Flag ``f`` as a hook that is deliberately not serialized.

    Sets the ``__torch_unserializable__`` attribute on ``f``. The flag is meant
    to silence the warning that is otherwise emitted when a tensor carrying
    this hook is serialized. ``f`` itself is returned, so this can be used as
    a decorator.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
def warn_if_has_hooks(tensor):
    """
    Emit a warning for each backward hook on ``tensor`` that will not be serialized.

    Hooks carrying the ``__torch_unserializable__`` flag (set by
    ``unserializable_hook``) are skipped.

    Args:
        tensor: object with a ``_backward_hooks`` mapping of hooks. A falsy
            value means no hooks are registered.
    """
    if not tensor._backward_hooks:
        return
    for hook in tensor._backward_hooks.values():
        # The opt-out flag is set on the hook callable itself. The mapping's
        # keys are hook ids, so checking the key never finds the flag.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning")
class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensor's gradients
      - Linking the gradients captured for the outputs with the gradients captured for the input
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        # None-padded gradients w.r.t. the module outputs. The output hook sets
        # them, and the input hook (or the no-input special case) clears them.
        self.grad_outputs = None
        # -1 / None until setup_output_hook / setup_input_hook have run
        # (the index stays None if nothing needed gradients).
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        """Return a ``size``-tuple with ``values`` placed at ``indices`` and None elsewhere."""
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val
        return tuple(res)

    def _unpack_none(self, indices, values):
        """Inverse of ``_pack_with_none``: pick ``values[i]`` for every ``i`` in ``indices``."""
        return tuple(values[idx] for idx in indices)

    @staticmethod
    def _rebuild_like(args, arg_list):
        """Build a container of the same type as ``args`` holding ``arg_list``.

        - Plain tuples go through ``tuple``.
        - Struct sequences (types exposing ``n_sequence_fields``, such as
          ``time.struct_time`` or the named results of many torch ops) take
          one sequence argument.
        - Any other type, e.g. a namedtuple, gets the elements positionally,
          as before.
        """
        cls = type(args)
        if cls is tuple:
            return tuple(arg_list)
        if hasattr(cls, "n_sequence_fields"):
            return cls(arg_list)
        return cls(*arg_list)

    def _set_user_hook(self, grad_fn):
        """Register on ``grad_fn`` the hook that runs the user's full backward hooks."""

        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return

            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for user_hook in self.user_hooks:
                out = user_hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       f"got {len(out)}, but expected {len(res)}")

                res = out

            # Consume the output gradients. A later call that did not go
            # through the output hook then takes the early return above.
            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        """Route the tensors in ``args`` through ``BackwardHookFunction`` and call ``fn`` on its grad_fn.

        Returns ``(new_args, tensor_indices)``. If no tensor requires grad,
        or grad mode is disabled, ``args`` is returned unchanged together
        with ``None``.
        """
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        return self._rebuild_like(args, arg_list), tensors_idx

    def setup_input_hook(self, args):
        """Hook the module inputs ``args``; return them with tensors replaced by hooked versions."""
        res, input_idx = self._apply_on_tensors(self._set_user_hook, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        """Hook the module output(s) ``args``.

        Accepts a single value or a tuple, and returns a value of the same
        shape with tensors replaced by hooked versions.
        """

        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                               f"got {actual_len}, but expected {expected_len}")
                        self.grad_outputs = hook_grad_outputs

                # We need to be able to clear self.grad_outputs but also return it
                local_grad_outputs = self.grad_outputs

                # Special case if no input required gradients, this hook should call the user
                # hook directly
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

                if local_grad_outputs is not None:
                    assert self.output_tensors_index is not None  # mypy
                    return self._unpack_none(self.output_tensors_index, local_grad_outputs)

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
|