from collections import OrderedDict, namedtuple | |
import itertools | |
import warnings | |
import functools | |
import weakref | |
import torch | |
from torch._prims_common import DeviceLikeType | |
from ..parameter import Parameter | |
import torch.utils.hooks as hooks | |
from torch import Tensor, device, dtype | |
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List | |
from typing_extensions import Self | |
from ...utils.hooks import RemovableHandle | |
from torch.utils._python_dispatch import is_traceable_wrapper_subclass | |
__all__ = ['register_module_forward_pre_hook', 'register_module_forward_hook', | |
'register_module_full_backward_pre_hook', 'register_module_backward_hook', | |
'register_module_full_backward_hook', 'register_module_buffer_registration_hook', | |
'register_module_module_registration_hook', 'register_module_parameter_registration_hook', 'Module'] | |
_grad_t = Union[Tuple[Tensor, ...], Tensor] | |
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use | |
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be | |
# the type of the subclass, not the looser type of `Module`. | |
T = TypeVar('T', bound='Module') | |
class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): | |
def __repr__(self): | |
if not self.missing_keys and not self.unexpected_keys: | |
return '<All keys matched successfully>' | |
return super().__repr__() | |
__str__ = __repr__ | |
def _addindent(s_, numSpaces): | |
s = s_.split('\n') | |
# don't do anything for single-line stuff | |
if len(s) == 1: | |
return s_ | |
first = s.pop(0) | |
s = [(numSpaces * ' ') + line for line in s] | |
s = '\n'.join(s) | |
s = first + '\n' + s | |
return s | |
r"""This tracks hooks common to all modules that are executed immediately before | |
.registering the buffer/module/parameter""" | |
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() | |
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict() | |
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() | |
class _WrappedHook: | |
def __init__(self, hook: Callable, module: Optional["Module"] = None): | |
self.hook: Callable = hook | |
functools.update_wrapper(self, hook) | |
self.with_module: bool = False | |
if module is not None: | |
self.module: weakref.ReferenceType[Module] = weakref.ref(module) | |
self.with_module = True | |
def __call__(self, *args: Any, **kwargs: Any) -> Any: | |
if self.with_module: | |
module = self.module() | |
if module is None: | |
raise RuntimeError("You are trying to call the hook of a dead Module!") | |
return self.hook(module, *args, **kwargs) | |
return self.hook(*args, **kwargs) | |
def __getstate__(self) -> Dict: | |
result = {"hook": self.hook, "with_module": self.with_module} | |
if self.with_module: | |
result["module"] = self.module() | |
return result | |
def __setstate__(self, state: Dict): | |
self.hook = state["hook"] | |
self.with_module = state["with_module"] | |
if self.with_module: | |
if state["module"] is None: | |
raise RuntimeError("You are trying to revive the hook of a dead Module!") | |
self.module = weakref.ref(state["module"]) | |
r"""This tracks hooks common to all modules that are executed before/after | |
calling forward and backward. This is global state used for debugging/profiling | |
purposes""" | |
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() | |
_global_backward_hooks: Dict[int, Callable] = OrderedDict() | |
_global_is_full_backward_hook: Optional[bool] = None | |
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() | |
_global_forward_hooks: Dict[int, Callable] = OrderedDict() | |
_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() | |
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' | |
def register_module_buffer_registration_hook(hook: Callable[..., None]) -> RemovableHandle: | |
r"""Register a buffer registration hook common to all modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module | |
The hook will be called every time :func:`register_buffer` is invoked. | |
It should have the following signature:: | |
hook(module, name, buffer) -> None or new buffer | |
The hook can modify the input or return a single modified value in the hook. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
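Example (an illustrative sketch; ``log_buffer`` is a hypothetical hook that only logs registrations)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def log_buffer(module, name, buffer):
...     # returning None keeps the buffer unchanged
...     print(f"registering buffer {name!r} on {type(module).__name__}")
>>> handle = register_module_buffer_registration_hook(log_buffer)
>>> bn = torch.nn.BatchNorm2d(4)  # registers ``running_mean``, ``running_var``, ...
>>> handle.remove()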
""" | |
handle = hooks.RemovableHandle(_global_buffer_registration_hooks) | |
_global_buffer_registration_hooks[handle.id] = hook | |
return handle | |
def register_module_module_registration_hook(hook: Callable[..., None]) -> RemovableHandle: | |
r"""Register a module registration hook common to all modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module | |
The hook will be called every time :func:`register_module` is invoked. | |
It should have the following signature:: | |
hook(module, name, submodule) -> None or new submodule | |
The hook can modify the input or return a single modified value in the hook. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
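Example (an illustrative sketch; ``log_submodule`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def log_submodule(module, name, submodule):
...     # returning None keeps the submodule unchanged
...     print(f"registering submodule {name!r}: {type(submodule).__name__}")
>>> handle = register_module_module_registration_hook(log_submodule)
>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2))  # triggers the hook
>>> handle.remove()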
""" | |
handle = hooks.RemovableHandle(_global_module_registration_hooks) | |
_global_module_registration_hooks[handle.id] = hook | |
return handle | |
def register_module_parameter_registration_hook(hook: Callable[..., None]) -> RemovableHandle: | |
r"""Register a parameter registration hook common to all modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module | |
The hook will be called every time :func:`register_parameter` is invoked. | |
It should have the following signature:: | |
hook(module, name, param) -> None or new parameter | |
The hook can modify the input or return a single modified value in the hook. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
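Example (an illustrative sketch; ``record_param_names`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> seen = []
>>> def record_param_names(module, name, param):
...     # returning None keeps the parameter unchanged
...     seen.append(name)
>>> handle = register_module_parameter_registration_hook(record_param_names)
>>> layer = torch.nn.Linear(2, 2)  # registers ``weight`` and ``bias``
>>> handle.remove()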
""" | |
handle = hooks.RemovableHandle(_global_parameter_registration_hooks) | |
_global_parameter_registration_hooks[handle.id] = hook | |
return handle | |
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: | |
r"""Register a forward pre-hook common to all modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module
and it is only intended for debugging/profiling purposes. | |
The hook will be called every time before :func:`forward` is invoked. | |
It should have the following signature:: | |
hook(module, input) -> None or modified input | |
The input contains only the positional arguments given to the module. | |
Keyword arguments won't be passed to the hooks and only to the ``forward``. | |
The hook can modify the input. The user can either return a tuple or a
single modified value in the hook. We will wrap the value into a tuple
if a single value is returned (unless that value is already a tuple).
This hook has precedence over the specific module hooks registered with | |
``register_forward_pre_hook``. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
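Example (an illustrative sketch; ``log_shapes`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def log_shapes(module, args):
...     # returning None leaves the positional arguments unchanged
...     print([tuple(a.shape) for a in args if isinstance(a, torch.Tensor)])
>>> handle = register_module_forward_pre_hook(log_shapes)
>>> out = torch.nn.Linear(2, 2)(torch.randn(1, 2))  # prints [(1, 2)]
>>> handle.remove()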
""" | |
handle = hooks.RemovableHandle(_global_forward_pre_hooks) | |
_global_forward_pre_hooks[handle.id] = hook | |
return handle | |
def register_module_forward_hook(hook: Callable[..., None], *, always_call: bool = False) -> RemovableHandle: | |
r"""Register a global forward hook for all the modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module
and it is only intended for debugging/profiling purposes. | |
The hook will be called every time after :func:`forward` has computed an output. | |
It should have the following signature:: | |
hook(module, input, output) -> None or modified output | |
The input contains only the positional arguments given to the module. | |
Keyword arguments won't be passed to the hooks and only to the ``forward``. | |
The hook can modify the output. It can modify the input inplace but | |
it will not have effect on forward since this is called after | |
:func:`forward` is called. | |
Parameters: | |
hook (Callable): The user defined hook to be registered. | |
always_call (bool): If ``True`` the ``hook`` will be run regardless of | |
whether an exception is raised while calling the Module. | |
Default: ``False`` | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
This hook will be executed before specific module hooks registered with | |
``register_forward_hook``. | |
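Example (an illustrative sketch; ``record_outputs`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> outputs = []
>>> def record_outputs(module, args, output):
...     # returning None keeps the output unchanged
...     outputs.append(output.detach())
>>> handle = register_module_forward_hook(record_outputs)
>>> _ = torch.nn.Linear(2, 2)(torch.randn(1, 2))
>>> handle.remove()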
""" | |
handle = hooks.RemovableHandle(_global_forward_hooks, | |
extra_dict=_global_forward_hooks_always_called) | |
_global_forward_hooks[handle.id] = hook | |
if always_call: | |
_global_forward_hooks_always_called[handle.id] = True | |
return handle | |
def register_module_backward_hook( | |
hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] | |
) -> RemovableHandle: | |
r"""Register a backward hook common to all the modules. | |
This function is deprecated in favor of | |
:func:`torch.nn.modules.module.register_module_full_backward_hook` | |
and the behavior of this function will change in future versions. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
""" | |
global _global_is_full_backward_hook | |
if _global_is_full_backward_hook is True: | |
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " | |
"global Module hook. Please use only one of them.") | |
_global_is_full_backward_hook = False | |
handle = hooks.RemovableHandle(_global_backward_hooks) | |
_global_backward_hooks[handle.id] = hook | |
return handle | |
def register_module_full_backward_pre_hook( | |
hook: Callable[['Module', _grad_t], Union[None, _grad_t]] | |
) -> RemovableHandle: | |
r"""Register a backward pre-hook common to all the modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module
and it is only intended for debugging/profiling purposes. | |
Hooks registered using this function behave in the same way as those | |
registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`. | |
Refer to its documentation for more details. | |
Hooks registered using this function will be called before hooks registered | |
using :meth:`torch.nn.Module.register_full_backward_pre_hook`. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
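Example (an illustrative sketch; ``halve_grad_output`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def halve_grad_output(module, grad_output):
...     # a returned tuple replaces ``grad_output`` in the backward pass
...     return tuple(g * 0.5 if g is not None else None for g in grad_output)
>>> handle = register_module_full_backward_pre_hook(halve_grad_output)
>>> torch.nn.Linear(2, 2)(torch.randn(1, 2, requires_grad=True)).sum().backward()
>>> handle.remove()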
""" | |
handle = hooks.RemovableHandle(_global_backward_pre_hooks) | |
_global_backward_pre_hooks[handle.id] = hook | |
return handle | |
def register_module_full_backward_hook( | |
hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] | |
) -> RemovableHandle: | |
r"""Register a backward hook common to all the modules. | |
.. warning :: | |
This adds global state to the `nn.Module` module
and it is only intended for debugging/profiling purposes. | |
Hooks registered using this function behave in the same way as those | |
registered by :meth:`torch.nn.Module.register_full_backward_hook`. | |
Refer to its documentation for more details. | |
Hooks registered using this function will be called before hooks registered | |
using :meth:`torch.nn.Module.register_full_backward_hook`. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
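Example (an illustrative sketch; ``report_grad_norms`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def report_grad_norms(module, grad_input, grad_output):
...     # returning None keeps ``grad_input`` unchanged
...     print([g.norm().item() for g in grad_output if g is not None])
>>> handle = register_module_full_backward_hook(report_grad_norms)
>>> torch.nn.Linear(2, 2)(torch.randn(1, 2, requires_grad=True)).sum().backward()
>>> handle.remove()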
""" | |
global _global_is_full_backward_hook | |
if _global_is_full_backward_hook is False: | |
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " | |
"global Module hook. Please use only one of them.") | |
_global_is_full_backward_hook = True | |
handle = hooks.RemovableHandle(_global_backward_hooks) | |
_global_backward_hooks[handle.id] = hook | |
return handle | |
# Trick mypy into not applying contravariance rules to inputs by defining | |
# forward as a value, rather than a function. See also | |
# https://github.com/python/mypy/issues/8795 | |
def _forward_unimplemented(self, *input: Any) -> None: | |
r"""Define the computation performed at every call. | |
Should be overridden by all subclasses. | |
.. note:: | |
Although the recipe for forward pass needs to be defined within | |
this function, one should call the :class:`Module` instance afterwards | |
instead of this since the former takes care of running the | |
registered hooks while the latter silently ignores them. | |
""" | |
raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function") | |
class Module: | |
r"""Base class for all neural network modules. | |
Your models should also subclass this class. | |
Modules can also contain other Modules, allowing them to be nested in
a tree structure. You can assign the submodules as regular attributes:: | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Model(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.conv1 = nn.Conv2d(1, 20, 5) | |
self.conv2 = nn.Conv2d(20, 20, 5) | |
def forward(self, x): | |
x = F.relu(self.conv1(x)) | |
return F.relu(self.conv2(x)) | |
Submodules assigned in this way will be registered, and will have their | |
parameters converted too when you call :meth:`to`, etc. | |
.. note:: | |
As per the example above, an ``__init__()`` call to the parent class | |
must be made before assignment on the child. | |
:ivar training: Boolean representing whether this module is in training or
evaluation mode. | |
:vartype training: bool | |
""" | |
dump_patches: bool = False | |
_version: int = 1 | |
r"""This allows better BC support for :meth:`load_state_dict`. In | |
:meth:`state_dict`, the version number will be saved in the attribute
`_metadata` of the returned state dict, and thus pickled. `_metadata` is a | |
dictionary with keys that follow the naming convention of state dict. See | |
``_load_from_state_dict`` on how to use this information in loading. | |
If new parameters/buffers are added/removed from a module, this number shall | |
be bumped, and the module's `_load_from_state_dict` method can compare the | |
version number and do appropriate changes if the state dict is from before | |
the change.""" | |
training: bool | |
_parameters: Dict[str, Optional[Parameter]] | |
_buffers: Dict[str, Optional[Tensor]] | |
_non_persistent_buffers_set: Set[str] | |
_backward_pre_hooks: Dict[int, Callable] | |
_backward_hooks: Dict[int, Callable] | |
_is_full_backward_hook: Optional[bool] | |
_forward_hooks: Dict[int, Callable] | |
# Marks whether the corresponding _forward_hooks accept kwargs or not. | |
# As JIT does not support Set[int], this dict is used as a set, where all | |
# hooks represented in this dict accept kwargs. | |
_forward_hooks_with_kwargs: Dict[int, bool] | |
# forward hooks that should always be called even if an exception is raised | |
_forward_hooks_always_called: Dict[int, bool] | |
_forward_pre_hooks: Dict[int, Callable] | |
# Marks whether the corresponding _forward_hooks accept kwargs or not. | |
# As JIT does not support Set[int], this dict is used as a set, where all | |
# hooks represented in this dict accept kwargs. | |
_forward_pre_hooks_with_kwargs: Dict[int, bool] | |
_state_dict_hooks: Dict[int, Callable] | |
_load_state_dict_pre_hooks: Dict[int, Callable] | |
_state_dict_pre_hooks: Dict[int, Callable] | |
_load_state_dict_post_hooks: Dict[int, Callable] | |
_modules: Dict[str, Optional['Module']] | |
call_super_init: bool = False | |
_compiled_call_impl : Optional[Callable] = None | |
def __init__(self, *args, **kwargs) -> None: | |
"""Initialize internal Module state, shared by both nn.Module and ScriptModule.""" | |
torch._C._log_api_usage_once("python.nn_module") | |
# Backward compatibility: no args used to be allowed when call_super_init=False | |
if self.call_super_init is False and bool(kwargs): | |
raise TypeError("{}.__init__() got an unexpected keyword argument '{}'" | |
"".format(type(self).__name__, next(iter(kwargs)))) | |
if self.call_super_init is False and bool(args): | |
raise TypeError(f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" | |
" given") | |
""" | |
Calls super().__setattr__('a', a) instead of the typical self.a = a | |
to avoid Module.__setattr__ overhead. Module's __setattr__ has special | |
handling for parameters, submodules, and buffers but simply calls into | |
super().__setattr__ for all other attributes. | |
""" | |
super().__setattr__('training', True) | |
super().__setattr__('_parameters', OrderedDict()) | |
super().__setattr__('_buffers', OrderedDict()) | |
super().__setattr__('_non_persistent_buffers_set', set()) | |
super().__setattr__('_backward_pre_hooks', OrderedDict()) | |
super().__setattr__('_backward_hooks', OrderedDict()) | |
super().__setattr__('_is_full_backward_hook', None) | |
super().__setattr__('_forward_hooks', OrderedDict()) | |
super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) | |
super().__setattr__('_forward_hooks_always_called', OrderedDict()) | |
super().__setattr__('_forward_pre_hooks', OrderedDict()) | |
super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) | |
super().__setattr__('_state_dict_hooks', OrderedDict()) | |
super().__setattr__('_state_dict_pre_hooks', OrderedDict()) | |
super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) | |
super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) | |
super().__setattr__('_modules', OrderedDict()) | |
if self.call_super_init: | |
super().__init__(*args, **kwargs) | |
forward: Callable[..., Any] = _forward_unimplemented | |
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: | |
r"""Add a buffer to the module. | |
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean`` | |
is not a parameter, but is part of the module's state. Buffers, by | |
default, are persistent and will be saved alongside parameters. This | |
behavior can be changed by setting :attr:`persistent` to ``False``. The | |
only difference between a persistent buffer and a non-persistent buffer | |
is that the latter will not be a part of this module's | |
:attr:`state_dict`. | |
Buffers can be accessed as attributes using given names. | |
Args: | |
name (str): name of the buffer. The buffer can be accessed | |
from this module using the given name | |
tensor (Tensor or None): buffer to be registered. If ``None``, then operations | |
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, | |
the buffer is **not** included in the module's :attr:`state_dict`. | |
persistent (bool): whether the buffer is part of this module's | |
:attr:`state_dict`. | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> self.register_buffer('running_mean', torch.zeros(num_features)) | |
""" | |
if persistent is False and isinstance(self, torch.jit.ScriptModule): | |
raise RuntimeError("ScriptModule does not support non-persistent buffers") | |
if '_buffers' not in self.__dict__: | |
raise AttributeError( | |
"cannot assign buffer before Module.__init__() call") | |
elif not isinstance(name, str): | |
raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}") | |
elif '.' in name: | |
raise KeyError("buffer name can't contain \".\"") | |
elif name == '': | |
raise KeyError("buffer name can't be empty string \"\"") | |
elif hasattr(self, name) and name not in self._buffers: | |
raise KeyError(f"attribute '{name}' already exists") | |
elif tensor is not None and not isinstance(tensor, torch.Tensor): | |
raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " | |
"(torch Tensor or None required)" | |
) | |
else: | |
for hook in _global_buffer_registration_hooks.values(): | |
output = hook(self, name, tensor) | |
if output is not None: | |
tensor = output | |
self._buffers[name] = tensor | |
if persistent: | |
self._non_persistent_buffers_set.discard(name) | |
else: | |
self._non_persistent_buffers_set.add(name) | |
def register_parameter(self, name: str, param: Optional[Parameter]) -> None: | |
r"""Add a parameter to the module. | |
The parameter can be accessed as an attribute using given name. | |
Args: | |
name (str): name of the parameter. The parameter can be accessed | |
from this module using the given name | |
param (Parameter or None): parameter to be added to the module. If | |
``None``, then operations that run on parameters, such as :attr:`cuda`, | |
are ignored. If ``None``, the parameter is **not** included in the | |
module's :attr:`state_dict`. | |
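Example (an illustrative sketch, assuming it runs inside a module's ``__init__`` where ``num_features`` is defined)::
>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_parameter('scale', torch.nn.Parameter(torch.ones(num_features)))
>>> self.register_parameter('placeholder', None)  # excluded from state_dict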
""" | |
if '_parameters' not in self.__dict__: | |
raise AttributeError( | |
"cannot assign parameter before Module.__init__() call") | |
elif not isinstance(name, str): | |
raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}") | |
elif '.' in name: | |
raise KeyError("parameter name can't contain \".\"") | |
elif name == '': | |
raise KeyError("parameter name can't be empty string \"\"") | |
elif hasattr(self, name) and name not in self._parameters: | |
raise KeyError(f"attribute '{name}' already exists") | |
if param is None: | |
self._parameters[name] = None | |
elif not isinstance(param, Parameter): | |
raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " | |
"(torch.nn.Parameter or None required)" | |
) | |
elif param.grad_fn: | |
raise ValueError( | |
f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " | |
f"parameters must be created explicitly. To express '{name}' " | |
"as a function of another Tensor, compute the value in " | |
"the forward() method.") | |
else: | |
for hook in _global_parameter_registration_hooks.values(): | |
output = hook(self, name, param) | |
if output is not None: | |
param = output | |
self._parameters[name] = param | |
def add_module(self, name: str, module: Optional['Module']) -> None: | |
r"""Add a child module to the current module. | |
The module can be accessed as an attribute using the given name. | |
Args: | |
name (str): name of the child module. The child module can be | |
accessed from this module using the given name | |
module (Module): child module to be added to the module. | |
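Example (an illustrative sketch; the name ``encoder`` is arbitrary)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> net = nn.Module()
>>> net.add_module('encoder', nn.Linear(10, 5))
>>> net.encoder  # the submodule is now accessible as an attribute
Linear(in_features=10, out_features=5, bias=True)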
""" | |
if not isinstance(module, Module) and module is not None: | |
raise TypeError(f"{torch.typename(module)} is not a Module subclass") | |
elif not isinstance(name, str): | |
raise TypeError(f"module name should be a string. Got {torch.typename(name)}") | |
elif hasattr(self, name) and name not in self._modules: | |
raise KeyError(f"attribute '{name}' already exists") | |
elif '.' in name: | |
raise KeyError(f"module name can't contain \".\", got: {name}") | |
elif name == '': | |
raise KeyError("module name can't be empty string \"\"") | |
for hook in _global_module_registration_hooks.values(): | |
output = hook(self, name, module) | |
if output is not None: | |
module = output | |
self._modules[name] = module | |
def register_module(self, name: str, module: Optional['Module']) -> None: | |
r"""Alias for :func:`add_module`.""" | |
self.add_module(name, module) | |
def get_submodule(self, target: str) -> "Module": | |
"""Return the submodule given by ``target`` if it exists, otherwise throw an error. | |
For example, let's say you have an ``nn.Module`` ``A`` that | |
looks like this: | |
.. code-block:: text | |
A( | |
(net_b): Module( | |
(net_c): Module( | |
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) | |
) | |
(linear): Linear(in_features=100, out_features=200, bias=True) | |
) | |
) | |
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested | |
submodule ``net_b``, which itself has two submodules ``net_c`` | |
and ``linear``. ``net_c`` then has a submodule ``conv``.) | |
To check whether or not we have the ``linear`` submodule, we | |
would call ``get_submodule("net_b.linear")``. To check whether | |
we have the ``conv`` submodule, we would call | |
``get_submodule("net_b.net_c.conv")``. | |
The runtime of ``get_submodule`` is bounded by the degree | |
of module nesting in ``target``. A query against | |
``named_modules`` achieves the same result, but it is O(N) in | |
the number of transitive modules. So, for a simple check to see | |
if some submodule exists, ``get_submodule`` should always be | |
used. | |
Args: | |
target: The fully-qualified string name of the submodule | |
to look for. (See above example for how to specify a | |
fully-qualified string.) | |
Returns: | |
torch.nn.Module: The submodule referenced by ``target`` | |
Raises: | |
AttributeError: If the target string references an invalid | |
path or resolves to something that is not an | |
``nn.Module`` | |
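Example (an illustrative sketch mirroring the diagram above; only ``net_b.linear`` is built here)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> A = nn.Module()
>>> A.net_b = nn.Module()
>>> A.net_b.linear = nn.Linear(100, 200)
>>> A.get_submodule("net_b.linear")
Linear(in_features=100, out_features=200, bias=True)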
""" | |
if target == "": | |
return self | |
atoms: List[str] = target.split(".") | |
mod: torch.nn.Module = self | |
for item in atoms: | |
if not hasattr(mod, item): | |
raise AttributeError(mod._get_name() + " has no " | |
"attribute `" + item + "`") | |
mod = getattr(mod, item) | |
if not isinstance(mod, torch.nn.Module): | |
raise AttributeError("`" + item + "` is not " | |
"an nn.Module") | |
return mod | |
def get_parameter(self, target: str) -> "Parameter": | |
"""Return the parameter given by ``target`` if it exists, otherwise throw an error. | |
See the docstring for ``get_submodule`` for a more detailed | |
explanation of this method's functionality as well as how to | |
correctly specify ``target``. | |
Args: | |
target: The fully-qualified string name of the Parameter | |
to look for. (See ``get_submodule`` for how to specify a | |
fully-qualified string.) | |
Returns: | |
torch.nn.Parameter: The Parameter referenced by ``target`` | |
Raises: | |
AttributeError: If the target string references an invalid | |
path or resolves to something that is not an | |
``nn.Parameter`` | |
""" | |
module_path, _, param_name = target.rpartition(".") | |
mod: torch.nn.Module = self.get_submodule(module_path) | |
if not hasattr(mod, param_name): | |
raise AttributeError(mod._get_name() + " has no attribute `" | |
+ param_name + "`") | |
param: torch.nn.Parameter = getattr(mod, param_name) | |
if not isinstance(param, torch.nn.Parameter): | |
raise AttributeError("`" + param_name + "` is not an " | |
"nn.Parameter") | |
return param | |
def get_buffer(self, target: str) -> "Tensor": | |
"""Return the buffer given by ``target`` if it exists, otherwise throw an error. | |
See the docstring for ``get_submodule`` for a more detailed | |
explanation of this method's functionality as well as how to | |
correctly specify ``target``. | |
Args: | |
target: The fully-qualified string name of the buffer | |
to look for. (See ``get_submodule`` for how to specify a | |
fully-qualified string.) | |
Returns: | |
torch.Tensor: The buffer referenced by ``target`` | |
Raises: | |
AttributeError: If the target string references an invalid | |
path or resolves to something that is not a | |
buffer | |
""" | |
module_path, _, buffer_name = target.rpartition(".") | |
mod: torch.nn.Module = self.get_submodule(module_path) | |
if not hasattr(mod, buffer_name): | |
raise AttributeError(mod._get_name() + " has no attribute `" | |
+ buffer_name + "`") | |
buffer: torch.Tensor = getattr(mod, buffer_name) | |
if buffer_name not in mod._buffers: | |
raise AttributeError("`" + buffer_name + "` is not a buffer") | |
return buffer | |
def get_extra_state(self) -> Any: | |
"""Return any extra state to include in the module's state_dict. | |
Implement this and a corresponding :func:`set_extra_state` for your module | |
if you need to store extra state. This function is called when building the | |
module's `state_dict()`. | |
Note that extra state should be picklable to ensure working serialization | |
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if | |
their serialized pickled form changes. | |
Returns: | |
object: Any extra state to store in the module's state_dict | |
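Example (an illustrative sketch; ``RunningCounter`` is a hypothetical subclass)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> class RunningCounter(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.count = 0  # plain Python state, not a Tensor
...     def get_extra_state(self):
...         return {"count": self.count}
...     def set_extra_state(self, state):
...         self.count = state["count"]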
""" | |
raise RuntimeError( | |
"Reached a code path in Module.get_extra_state() that should never be called. " | |
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " | |
"to report this bug.") | |
def set_extra_state(self, state: Any) -> None: | |
"""Set extra state contained in the loaded `state_dict`. | |
This function is called from :func:`load_state_dict` to handle any extra state | |
found within the `state_dict`. Implement this function and a corresponding | |
:func:`get_extra_state` for your module if you need to store extra state within its | |
`state_dict`. | |
Args: | |
state (dict): Extra state from the `state_dict` | |
""" | |
raise RuntimeError( | |
"Reached a code path in Module.set_extra_state() that should never be called. " | |
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " | |
"to report this bug.") | |
def _apply(self, fn, recurse=True): | |
if recurse: | |
for module in self.children(): | |
module._apply(fn) | |
def compute_should_use_set_data(tensor, tensor_applied): | |
if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): | |
# If the new tensor has compatible tensor type as the existing tensor, | |
# the current behavior is to change the tensor in-place using `.data =`, | |
# and the future behavior is to overwrite the existing tensor. However, | |
# changing the current behavior is a BC-breaking change, and we want it | |
# to happen in future releases. So for now we introduce the | |
# `torch.__future__.get_overwrite_module_params_on_conversion()` | |
# global flag to let the user control whether they want the future | |
# behavior of overwriting the existing tensor or not. | |
return not torch.__future__.get_overwrite_module_params_on_conversion() | |
else: | |
return False | |
should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() | |
for key, param in self._parameters.items(): | |
if param is None: | |
continue | |
# Tensors stored in modules are graph leaves, and we don't want to | |
# track autograd history of `param_applied`, so we have to use | |
# `with torch.no_grad():` | |
with torch.no_grad(): | |
param_applied = fn(param) | |
p_should_use_set_data = compute_should_use_set_data(param, param_applied) | |
# subclasses may have multiple child tensors so we need to use swap_tensors | |
p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) | |
param_grad = param.grad | |
if p_should_use_swap_tensors: | |
try: | |
if param_grad is not None: | |
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. | |
# Decrement use count of the gradient by setting to None | |
param.grad = None | |
param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad) | |
torch.utils.swap_tensors(param, param_applied) | |
except Exception as e: | |
if param_grad is not None: | |
param.grad = param_grad | |
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e | |
out_param = param | |
elif p_should_use_set_data: | |
param.data = param_applied | |
out_param = param | |
else: | |
assert isinstance(param, Parameter) | |
assert param.is_leaf | |
out_param = Parameter(param_applied, param.requires_grad) | |
self._parameters[key] = out_param | |
if param_grad is not None: | |
with torch.no_grad(): | |
grad_applied = fn(param_grad) | |
g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied) | |
if p_should_use_swap_tensors: | |
grad_applied.requires_grad_(param_grad.requires_grad) | |
try: | |
torch.utils.swap_tensors(param_grad, grad_applied) | |
except Exception as e: | |
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e | |
out_param.grad = param_grad | |
elif g_should_use_set_data: | |
assert out_param.grad is not None | |
out_param.grad.data = grad_applied | |
else: | |
assert param_grad.is_leaf | |
out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad) | |
for key, buf in self._buffers.items(): | |
if buf is not None: | |
self._buffers[key] = fn(buf) | |
return self | |
def apply(self: T, fn: Callable[['Module'], None]) -> T: | |
r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. | |
Typical use includes initializing the parameters of a model | |
(see also :ref:`nn-init-doc`). | |
Args: | |
fn (:class:`Module` -> None): function to be applied to each submodule | |
Returns: | |
Module: self | |
Example:: | |
>>> @torch.no_grad() | |
>>> def init_weights(m): | |
>>> print(m) | |
>>> if type(m) == nn.Linear: | |
>>> m.weight.fill_(1.0) | |
>>> print(m.weight) | |
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) | |
>>> net.apply(init_weights) | |
Linear(in_features=2, out_features=2, bias=True) | |
Parameter containing: | |
tensor([[1., 1.], | |
[1., 1.]], requires_grad=True) | |
Linear(in_features=2, out_features=2, bias=True) | |
Parameter containing: | |
tensor([[1., 1.], | |
[1., 1.]], requires_grad=True) | |
Sequential( | |
(0): Linear(in_features=2, out_features=2, bias=True) | |
(1): Linear(in_features=2, out_features=2, bias=True) | |
) | |
""" | |
for module in self.children(): | |
module.apply(fn) | |
fn(self) | |
return self | |
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: | |
r"""Move all model parameters and buffers to the GPU. | |
This also makes associated parameters and buffers different objects. So | |
it should be called before constructing the optimizer if the module will
live on GPU while being optimized. | |
.. note:: | |
This method modifies the module in-place. | |
Args: | |
device (int, optional): if specified, all parameters will be | |
copied to that device | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.cuda(device)) | |
def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: | |
r"""Move all model parameters and buffers to the IPU. | |
This also makes associated parameters and buffers different objects. So | |
it should be called before constructing the optimizer if the module will
live on IPU while being optimized. | |
.. note:: | |
This method modifies the module in-place. | |
Arguments: | |
device (int, optional): if specified, all parameters will be | |
copied to that device | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.ipu(device)) | |
def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: | |
r"""Move all model parameters and buffers to the XPU. | |
This also makes associated parameters and buffers different objects. So | |
it should be called before constructing the optimizer if the module will
live on XPU while being optimized. | |
.. note:: | |
This method modifies the module in-place. | |
Arguments: | |
device (int, optional): if specified, all parameters will be | |
copied to that device | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.xpu(device)) | |
def cpu(self: T) -> T: | |
r"""Move all model parameters and buffers to the CPU. | |
.. note:: | |
This method modifies the module in-place. | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.cpu()) | |
def type(self: T, dst_type: Union[dtype, str]) -> T: | |
r"""Casts all parameters and buffers to :attr:`dst_type`. | |
.. note:: | |
This method modifies the module in-place. | |
Args: | |
dst_type (type or string): the desired type | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.type(dst_type)) | |
def float(self: T) -> T: | |
r"""Casts all floating point parameters and buffers to ``float`` datatype. | |
.. note:: | |
This method modifies the module in-place. | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.float() if t.is_floating_point() else t) | |
def double(self: T) -> T: | |
r"""Casts all floating point parameters and buffers to ``double`` datatype. | |
.. note:: | |
This method modifies the module in-place. | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.double() if t.is_floating_point() else t) | |
def half(self: T) -> T: | |
r"""Casts all floating point parameters and buffers to ``half`` datatype. | |
.. note:: | |
This method modifies the module in-place. | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.half() if t.is_floating_point() else t) | |
def bfloat16(self: T) -> T: | |
r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. | |
.. note:: | |
This method modifies the module in-place. | |
Returns: | |
Module: self | |
""" | |
return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) | |
def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T: | |
r"""Move the parameters and buffers to the specified device without copying storage. | |
Args: | |
device (:class:`torch.device`): The desired device of the parameters | |
and buffers in this module. | |
recurse (bool): Whether parameters and buffers of submodules should | |
be recursively moved to the specified device. | |
Returns: | |
Module: self | |
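Example (an illustrative sketch: materializing a module created on the meta device)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> with torch.device("meta"):
...     m = nn.Linear(10, 10)  # parameters have shapes/dtypes but no storage
>>> m = m.to_empty(device="cpu")  # allocate uninitialized storage on the CPU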
""" | |
return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse) | |
@overload
def to(self, device: Optional[DeviceLikeType] = ..., dtype: Optional[dtype] = ...,
non_blocking: bool = ...) -> Self:
...
@overload
def to(self, dtype: dtype, non_blocking: bool = ...) -> Self:
...
@overload
def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self:
...
def to(self, *args, **kwargs): | |
r"""Move and/or cast the parameters and buffers. | |
This can be called as | |
.. function:: to(device=None, dtype=None, non_blocking=False) | |
:noindex: | |
.. function:: to(dtype, non_blocking=False) | |
:noindex: | |
.. function:: to(tensor, non_blocking=False) | |
:noindex: | |
.. function:: to(memory_format=torch.channels_last) | |
:noindex: | |
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts | |
floating point or complex :attr:`dtype`\ s. In addition, this method will | |
only cast the floating point or complex parameters and buffers to :attr:`dtype` | |
(if given). The integral parameters and buffers will be moved to
:attr:`device`, if that is given, but with dtypes unchanged. When | |
:attr:`non_blocking` is set, it tries to convert/move asynchronously | |
with respect to the host if possible, e.g., moving CPU Tensors with | |
pinned memory to CUDA devices. | |
See below for examples. | |
.. note:: | |
This method modifies the module in-place. | |
Args: | |
device (:class:`torch.device`): the desired device of the parameters | |
and buffers in this module | |
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of | |
the parameters and buffers in this module | |
tensor (torch.Tensor): Tensor whose dtype and device are the desired | |
dtype and device for all parameters and buffers in this module | |
memory_format (:class:`torch.memory_format`): the desired memory | |
format for 4D parameters and buffers in this module (keyword | |
only argument) | |
Returns: | |
Module: self | |
Examples:: | |
>>> # xdoctest: +IGNORE_WANT("non-deterministic") | |
>>> linear = nn.Linear(2, 2) | |
>>> linear.weight | |
Parameter containing: | |
tensor([[ 0.1913, -0.3420], | |
[-0.5113, -0.2325]]) | |
>>> linear.to(torch.double) | |
Linear(in_features=2, out_features=2, bias=True) | |
>>> linear.weight | |
Parameter containing: | |
tensor([[ 0.1913, -0.3420], | |
[-0.5113, -0.2325]], dtype=torch.float64) | |
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) | |
>>> gpu1 = torch.device("cuda:1") | |
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True) | |
Linear(in_features=2, out_features=2, bias=True) | |
>>> linear.weight | |
Parameter containing: | |
tensor([[ 0.1914, -0.3420], | |
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') | |
>>> cpu = torch.device("cpu") | |
>>> linear.to(cpu) | |
Linear(in_features=2, out_features=2, bias=True) | |
>>> linear.weight | |
Parameter containing: | |
tensor([[ 0.1914, -0.3420], | |
[-0.5112, -0.2324]], dtype=torch.float16) | |
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) | |
>>> linear.weight | |
Parameter containing: | |
tensor([[ 0.3741+0.j, 0.2382+0.j], | |
[ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) | |
>>> linear(torch.ones(3, 2, dtype=torch.cdouble)) | |
tensor([[0.6122+0.j, 0.1150+0.j], | |
[0.6122+0.j, 0.1150+0.j], | |
[0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) | |
""" | |
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) | |
if dtype is not None: | |
if not (dtype.is_floating_point or dtype.is_complex): | |
raise TypeError('nn.Module.to only accepts floating point or complex ' | |
f'dtypes, but got desired dtype={dtype}') | |
if dtype.is_complex: | |
warnings.warn( | |
"Complex modules are a new feature under active development whose design may change, " | |
"and some modules might not work as expected when using complex tensors as parameters or buffers. " | |
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " | |
"if a complex module does not work as expected.") | |
def convert(t): | |
try: | |
if convert_to_format is not None and t.dim() in (4, 5): | |
return t.to( | |
device, | |
dtype if t.is_floating_point() or t.is_complex() else None, | |
non_blocking, | |
memory_format=convert_to_format, | |
) | |
return t.to( | |
device, | |
dtype if t.is_floating_point() or t.is_complex() else None, | |
non_blocking, | |
) | |
except NotImplementedError as e: | |
if str(e) == "Cannot copy out of meta tensor; no data!": | |
raise NotImplementedError( | |
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " | |
f"when moving module from meta to a different device." | |
) from None | |
else: | |
raise | |
return self._apply(convert) | |
def register_full_backward_pre_hook( | |
self, | |
hook: Callable[["Module", _grad_t], Union[None, _grad_t]], | |
prepend: bool = False, | |
) -> RemovableHandle: | |
r"""Register a backward pre-hook on the module. | |
The hook will be called every time the gradients for the module are computed. | |
The hook should have the following signature:: | |
hook(module, grad_output) -> tuple[Tensor] or None | |
The :attr:`grad_output` is a tuple. The hook should | |
not modify its arguments, but it can optionally return a new gradient with | |
respect to the output that will be used in place of :attr:`grad_output` in | |
subsequent computations. Entries in :attr:`grad_output` will be ``None`` for | |
all non-Tensor arguments. | |
For technical reasons, when this hook is applied to a Module, its forward function will | |
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view | |
of each Tensor returned by the Module's forward function. | |
.. warning :: | |
Modifying inputs inplace is not allowed when using backward hooks and | |
will raise an error. | |
Args: | |
hook (Callable): The user-defined hook to be registered. | |
prepend (bool): If true, the provided ``hook`` will be fired before | |
all existing ``backward_pre`` hooks on this | |
:class:`torch.nn.modules.Module`. Otherwise, the provided | |
``hook`` will be fired after all existing ``backward_pre`` hooks | |
on this :class:`torch.nn.modules.Module`. Note that global | |
``backward_pre`` hooks registered with | |
:func:`register_module_full_backward_pre_hook` will fire before | |
all hooks registered by this method. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
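Example (an illustrative sketch; ``clamp_grad_output`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> layer = nn.Linear(2, 2)
>>> def clamp_grad_output(module, grad_output):
...     # a returned tuple replaces ``grad_output``
...     return tuple(g.clamp(-1.0, 1.0) if g is not None else None for g in grad_output)
>>> handle = layer.register_full_backward_pre_hook(clamp_grad_output)
>>> layer(torch.randn(1, 2, requires_grad=True)).sum().backward()
>>> handle.remove()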
""" | |
handle = hooks.RemovableHandle(self._backward_pre_hooks) | |
self._backward_pre_hooks[handle.id] = hook | |
if prepend: | |
self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] | |
return handle | |
def register_backward_hook( | |
self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] | |
) -> RemovableHandle: | |
r"""Register a backward hook on the module. | |
This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and | |
the behavior of this function will change in future versions. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
""" | |
if self._is_full_backward_hook is True: | |
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " | |
"single Module. Please use only one of them.") | |
self._is_full_backward_hook = False | |
handle = hooks.RemovableHandle(self._backward_hooks) | |
self._backward_hooks[handle.id] = hook | |
return handle | |
def register_full_backward_hook( | |
self, | |
hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], | |
prepend: bool = False, | |
) -> RemovableHandle: | |
r"""Register a backward hook on the module. | |
The hook will be called every time the gradients with respect to a module | |
are computed, i.e. the hook will execute if and only if the gradients with | |
respect to module outputs are computed. The hook should have the following | |
signature:: | |
hook(module, grad_input, grad_output) -> tuple(Tensor) or None | |
The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients | |
with respect to the inputs and outputs respectively. The hook should | |
not modify its arguments, but it can optionally return a new gradient with | |
respect to the input that will be used in place of :attr:`grad_input` in | |
subsequent computations. :attr:`grad_input` will only correspond to the inputs given | |
as positional arguments and all kwarg arguments are ignored. Entries | |
in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor | |
arguments. | |
For technical reasons, when this hook is applied to a Module, its forward function will | |
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view | |
of each Tensor returned by the Module's forward function. | |
.. warning :: | |
Modifying inputs or outputs inplace is not allowed when using backward hooks and | |
will raise an error. | |
Args: | |
hook (Callable): The user-defined hook to be registered. | |
prepend (bool): If true, the provided ``hook`` will be fired before | |
all existing ``backward`` hooks on this | |
:class:`torch.nn.modules.Module`. Otherwise, the provided | |
``hook`` will be fired after all existing ``backward`` hooks on | |
this :class:`torch.nn.modules.Module`. Note that global | |
``backward`` hooks registered with | |
:func:`register_module_full_backward_hook` will fire before | |
all hooks registered by this method. | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
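Example (an illustrative sketch; ``print_grad_input_norms`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> layer = nn.Linear(2, 2)
>>> def print_grad_input_norms(module, grad_input, grad_output):
...     # returning None keeps ``grad_input`` unchanged
...     print([g.norm().item() for g in grad_input if g is not None])
>>> handle = layer.register_full_backward_hook(print_grad_input_norms)
>>> layer(torch.randn(1, 2, requires_grad=True)).sum().backward()
>>> handle.remove()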
""" | |
if self._is_full_backward_hook is False: | |
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " | |
"single Module. Please use only one of them.") | |
self._is_full_backward_hook = True | |
handle = hooks.RemovableHandle(self._backward_hooks) | |
self._backward_hooks[handle.id] = hook | |
if prepend: | |
self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] | |
return handle | |
def _get_backward_hooks(self): | |
r"""Return the backward hooks for use in the call function. | |
It returns two lists, one with the full backward hooks and one with the non-full | |
backward hooks. | |
""" | |
full_backward_hooks: List[Callable] = [] | |
if (_global_is_full_backward_hook is True): | |
full_backward_hooks += _global_backward_hooks.values() | |
if (self._is_full_backward_hook is True): | |
full_backward_hooks += self._backward_hooks.values() | |
non_full_backward_hooks: List[Callable] = [] | |
if (_global_is_full_backward_hook is False): | |
non_full_backward_hooks += _global_backward_hooks.values() | |
if (self._is_full_backward_hook is False): | |
non_full_backward_hooks += self._backward_hooks.values() | |
return full_backward_hooks, non_full_backward_hooks | |
def _get_backward_pre_hooks(self): | |
backward_pre_hooks: List[Callable] = [] | |
backward_pre_hooks += _global_backward_pre_hooks.values() | |
backward_pre_hooks += self._backward_pre_hooks.values() | |
return backward_pre_hooks | |
def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): | |
if not isinstance(result, torch.Tensor): | |
if not (isinstance(result, tuple) and all(isinstance(r, torch.Tensor) for r in result)): | |
warnings.warn("Using non-full backward hooks on a Module that does not return a " | |
"single Tensor or a tuple of Tensors is deprecated and will be removed " | |
"in future versions. This hook will be missing some of the grad_output. " | |
"Please use register_full_backward_hook to get the documented behavior.") | |
return | |
else: | |
result = (result,) | |
if not isinstance(inputs, torch.Tensor): | |
if not (isinstance(inputs, tuple) and all(isinstance(i, torch.Tensor) for i in inputs)): | |
warnings.warn("Using non-full backward hooks on a Module that does not take as input a " | |
"single Tensor or a tuple of Tensors is deprecated and will be removed " | |
"in future versions. This hook will be missing some of the grad_input. " | |
"Please use register_full_backward_hook to get the documented behavior.") | |
return | |
else: | |
inputs = (inputs,) | |
# At this point we are sure that inputs and result are tuple of Tensors | |
out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} | |
if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): | |
warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " | |
"is deprecated and will be removed in future versions. This hook will be missing " | |
"some grad_output.") | |
elif len(out_grad_fn) > 1: | |
warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " | |
"is deprecated and will be removed in future versions. This hook will be missing " | |
"some grad_output. Please use register_full_backward_hook to get the documented behavior.") | |
else: | |
# At this point the grad_output part of the hook will most likely be correct | |
inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} | |
next_functions = {n[0] for n in grad_fn.next_functions} | |
if inputs_grad_fn != next_functions: | |
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " | |
"is deprecated and will be removed in future versions. This hook will be missing " | |
"some grad_input. Please use register_full_backward_hook to get the documented " | |
"behavior.") | |
def register_forward_pre_hook( | |
self, | |
hook: Union[ | |
Callable[[T, Tuple[Any, ...]], Optional[Any]], | |
Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], | |
], | |
*, | |
prepend: bool = False, | |
with_kwargs: bool = False, | |
) -> RemovableHandle: | |
r"""Register a forward pre-hook on the module. | |
The hook will be called every time before :func:`forward` is invoked. | |
If ``with_kwargs`` is false or not specified, the input contains only | |
the positional arguments given to the module. Keyword arguments won't be | |
passed to the hooks and only to the ``forward``. The hook can modify the | |
input. The user can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned | |
(unless that value is already a tuple). The hook should have the | |
following signature:: | |
hook(module, args) -> None or modified input | |
If ``with_kwargs`` is true, the forward pre-hook will be passed the | |
kwargs given to the forward function. And if the hook modifies the | |
input, both the args and kwargs should be returned. The hook should have | |
the following signature:: | |
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs | |
Args: | |
hook (Callable): The user defined hook to be registered. | |
prepend (bool): If true, the provided ``hook`` will be fired before | |
all existing ``forward_pre`` hooks on this | |
:class:`torch.nn.modules.Module`. Otherwise, the provided | |
``hook`` will be fired after all existing ``forward_pre`` hooks | |
on this :class:`torch.nn.modules.Module`. Note that global | |
``forward_pre`` hooks registered with | |
:func:`register_module_forward_pre_hook` will fire before all | |
hooks registered by this method. | |
Default: ``False`` | |
with_kwargs (bool): If true, the ``hook`` will be passed the kwargs | |
given to the forward function. | |
Default: ``False`` | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
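Example (an illustrative sketch; ``double_input`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> def double_input(module, args):
...     # a returned value replaces the positional arguments (and is wrapped into a tuple)
...     return args[0] * 2
>>> layer = nn.Linear(2, 2)
>>> handle = layer.register_forward_pre_hook(double_input)
>>> _ = layer(torch.randn(1, 2))  # forward sees the doubled input
>>> handle.remove()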
""" | |
handle = hooks.RemovableHandle( | |
self._forward_pre_hooks, | |
extra_dict=self._forward_pre_hooks_with_kwargs | |
) | |
self._forward_pre_hooks[handle.id] = hook | |
if with_kwargs: | |
self._forward_pre_hooks_with_kwargs[handle.id] = True | |
if prepend: | |
self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] | |
return handle | |
def register_forward_hook( | |
self, | |
hook: Union[ | |
Callable[[T, Tuple[Any, ...], Any], Optional[Any]], | |
Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], | |
], | |
*, | |
prepend: bool = False, | |
with_kwargs: bool = False, | |
always_call: bool = False, | |
) -> RemovableHandle: | |
r"""Register a forward hook on the module. | |
The hook will be called every time after :func:`forward` has computed an output. | |
If ``with_kwargs`` is ``False`` or not specified, the input contains only | |
the positional arguments given to the module. Keyword arguments won't be | |
passed to the hooks and only to the ``forward``. The hook can modify the | |
output. It can modify the input inplace but it will not have effect on | |
forward since this is called after :func:`forward` is called. The hook | |
should have the following signature:: | |
hook(module, args, output) -> None or modified output | |
If ``with_kwargs`` is ``True``, the forward hook will be passed the | |
``kwargs`` given to the forward function and be expected to return the | |
output possibly modified. The hook should have the following signature:: | |
hook(module, args, kwargs, output) -> None or modified output | |
Args: | |
hook (Callable): The user defined hook to be registered. | |
prepend (bool): If ``True``, the provided ``hook`` will be fired | |
before all existing ``forward`` hooks on this | |
:class:`torch.nn.modules.Module`. Otherwise, the provided | |
``hook`` will be fired after all existing ``forward`` hooks on | |
this :class:`torch.nn.modules.Module`. Note that global | |
``forward`` hooks registered with | |
:func:`register_module_forward_hook` will fire before all hooks | |
registered by this method. | |
Default: ``False`` | |
with_kwargs (bool): If ``True``, the ``hook`` will be passed the | |
kwargs given to the forward function. | |
Default: ``False`` | |
always_call (bool): If ``True`` the ``hook`` will be run regardless of | |
whether an exception is raised while calling the Module. | |
Default: ``False`` | |
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
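Example (an illustrative sketch; ``capture_activation`` is a hypothetical hook)::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> activations = {}
>>> def capture_activation(module, args, output):
...     # returning None keeps the output unchanged
...     activations['linear'] = output.detach()
>>> layer = nn.Linear(2, 2)
>>> handle = layer.register_forward_hook(capture_activation)
>>> _ = layer(torch.randn(1, 2))
>>> handle.remove()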
""" | |
handle = hooks.RemovableHandle( | |
self._forward_hooks, | |
extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called], | |
) | |
self._forward_hooks[handle.id] = hook | |
if with_kwargs: | |
self._forward_hooks_with_kwargs[handle.id] = True | |
if always_call: | |
self._forward_hooks_always_called[handle.id] = True | |
if prepend: | |
self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] | |
return handle | |
def _slow_forward(self, *input, **kwargs): | |
tracing_state = torch._C._get_tracing_state() | |
if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): | |
return self.forward(*input, **kwargs) | |
recording_scopes = torch.jit._trace._trace_module_map is not None | |
if recording_scopes: | |
# type ignore was added because at this point one knows that | |
# torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] | |
name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 | |
if name: | |
tracing_state.push_scope(name) | |
else: | |
recording_scopes = False | |
try: | |
result = self.forward(*input, **kwargs) | |
finally: | |
if recording_scopes: | |
tracing_state.pop_scope() | |
return result | |
def _wrapped_call_impl(self, *args, **kwargs): | |
if self._compiled_call_impl is not None: | |
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] | |
else: | |
return self._call_impl(*args, **kwargs) | |
def _call_impl(self, *args, **kwargs): | |
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) | |
# If we don't have any hooks, we want to skip the rest of the logic in | |
# this function, and just call forward. | |
if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks | |
or _global_backward_pre_hooks or _global_backward_hooks | |
or _global_forward_hooks or _global_forward_pre_hooks): | |
return forward_call(*args, **kwargs) | |
try: | |
result = None | |
called_always_called_hooks = set() | |
full_backward_hooks, non_full_backward_hooks = [], [] | |
backward_pre_hooks = [] | |
if self._backward_pre_hooks or _global_backward_pre_hooks: | |
backward_pre_hooks = self._get_backward_pre_hooks() | |
if self._backward_hooks or _global_backward_hooks: | |
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() | |
if _global_forward_pre_hooks or self._forward_pre_hooks: | |
for hook_id, hook in ( | |
*_global_forward_pre_hooks.items(), | |
*self._forward_pre_hooks.items(), | |
): | |
if hook_id in self._forward_pre_hooks_with_kwargs: | |
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] | |
if args_kwargs_result is not None: | |
if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: | |
args, kwargs = args_kwargs_result | |
else: | |
raise RuntimeError( | |
"forward pre-hook must return None or a tuple " | |
f"of (new_args, new_kwargs), but got {args_kwargs_result}." | |
) | |
else: | |
args_result = hook(self, args) | |
if args_result is not None: | |
if not isinstance(args_result, tuple): | |
args_result = (args_result,) | |
args = args_result | |
bw_hook = None | |
if full_backward_hooks or backward_pre_hooks: | |
bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) | |
args = bw_hook.setup_input_hook(args) | |
result = forward_call(*args, **kwargs) | |
if _global_forward_hooks or self._forward_hooks: | |
for hook_id, hook in ( | |
*_global_forward_hooks.items(), | |
*self._forward_hooks.items(), | |
): | |
# mark that this always-called hook has run
if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: | |
called_always_called_hooks.add(hook_id) | |
if hook_id in self._forward_hooks_with_kwargs: | |
hook_result = hook(self, args, kwargs, result) | |
else: | |
hook_result = hook(self, args, result) | |
if hook_result is not None: | |
result = hook_result | |
if bw_hook: | |
if not isinstance(result, (torch.Tensor, tuple)): | |
warnings.warn("For backward hooks to be called," | |
" module output should be a Tensor or a tuple of Tensors" | |
f" but received {type(result)}") | |
result = bw_hook.setup_output_hook(result) | |
# Handle the non-full backward hooks | |
if non_full_backward_hooks: | |
var = result | |
while not isinstance(var, torch.Tensor): | |
if isinstance(var, dict): | |
var = next(v for v in var.values() if isinstance(v, torch.Tensor)) | |
else: | |
var = var[0] | |
grad_fn = var.grad_fn | |
if grad_fn is not None: | |
for hook in non_full_backward_hooks: | |
grad_fn.register_hook(_WrappedHook(hook, self)) | |
self._maybe_warn_non_full_backward_hook(args, result, grad_fn) | |
return result | |
except Exception: | |
# run always called hooks if they have not already been run | |
# For now only forward hooks have the always_call option but perhaps | |
# this functionality should be added to full backward hooks as well. | |
for hook_id, hook in _global_forward_hooks.items(): | |
if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] | |
try: | |
hook_result = hook(self, args, result) # type: ignore[possibly-undefined] | |
if hook_result is not None: | |
result = hook_result | |
except Exception as e: | |
warnings.warn("global module forward hook with ``always_call=True`` raised an exception " | |
f"that was silenced as another error was raised in forward: {str(e)}") | |
continue | |
for hook_id, hook in self._forward_hooks.items(): | |
if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] | |
try: | |
if hook_id in self._forward_hooks_with_kwargs: | |
hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] | |
else: | |
hook_result = hook(self, args, result) # type: ignore[possibly-undefined] | |
if hook_result is not None: | |
result = hook_result | |
except Exception as e: | |
warnings.warn("module forward hook with ``always_call=True`` raised an exception " | |
f"that was silenced as another error was raised in forward: {str(e)}") | |
continue | |
# raise exception raised in try block | |
raise | |
__call__ : Callable[..., Any] = _wrapped_call_impl | |
def __getstate__(self): | |
state = self.__dict__.copy() | |
state.pop("_compiled_call_impl", None) | |
return state | |
def __setstate__(self, state): | |
self.__dict__.update(state) | |
# Support loading old checkpoints that don't have the following attrs: | |
if '_forward_pre_hooks' not in self.__dict__: | |
self._forward_pre_hooks = OrderedDict() | |
if '_forward_pre_hooks_with_kwargs' not in self.__dict__: | |
self._forward_pre_hooks_with_kwargs = OrderedDict() | |
if '_forward_hooks_with_kwargs' not in self.__dict__: | |
self._forward_hooks_with_kwargs = OrderedDict() | |
if '_forward_hooks_always_called' not in self.__dict__: | |
self._forward_hooks_always_called = OrderedDict() | |
if '_state_dict_hooks' not in self.__dict__: | |
self._state_dict_hooks = OrderedDict() | |
if '_state_dict_pre_hooks' not in self.__dict__: | |
self._state_dict_pre_hooks = OrderedDict() | |
if '_load_state_dict_pre_hooks' not in self.__dict__: | |
self._load_state_dict_pre_hooks = OrderedDict() | |
if '_load_state_dict_post_hooks' not in self.__dict__: | |
self._load_state_dict_post_hooks = OrderedDict() | |
if '_non_persistent_buffers_set' not in self.__dict__: | |
self._non_persistent_buffers_set = set() | |
if '_is_full_backward_hook' not in self.__dict__: | |
self._is_full_backward_hook = None | |
if '_backward_pre_hooks' not in self.__dict__: | |
self._backward_pre_hooks = OrderedDict() | |
# On the return type: | |
# We choose to return `Any` in the `__getattr__` type signature instead of a more strict `Union[Tensor, Module]`. | |
# This is done for better interop with various type checkers for the end users. | |
# Having a stricter return type doesn't play nicely with `register_buffer()` and forces | |
# people to excessively use type-ignores, asserts, casts, etc. | |
# See full discussion on the problems with returning `Union` here | |
# https://github.com/microsoft/pyright/issues/4213 | |
def __getattr__(self, name: str) -> Any: | |
if '_parameters' in self.__dict__: | |
_parameters = self.__dict__['_parameters'] | |
if name in _parameters: | |
return _parameters[name] | |
if '_buffers' in self.__dict__: | |
_buffers = self.__dict__['_buffers'] | |
if name in _buffers: | |
return _buffers[name] | |
if '_modules' in self.__dict__: | |
modules = self.__dict__['_modules'] | |
if name in modules: | |
return modules[name] | |
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") | |
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: | |
def remove_from(*dicts_or_sets): | |
for d in dicts_or_sets: | |
if name in d: | |
if isinstance(d, dict): | |
del d[name] | |
else: | |
d.discard(name) | |
params = self.__dict__.get('_parameters') | |
if isinstance(value, Parameter): | |
if params is None: | |
raise AttributeError( | |
"cannot assign parameters before Module.__init__() call") | |
remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) | |
self.register_parameter(name, value) | |
elif params is not None and name in params: | |
if value is not None: | |
raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' " | |
"(torch.nn.Parameter or None expected)" | |
) | |
self.register_parameter(name, value) | |
else: | |
modules = self.__dict__.get('_modules') | |
if isinstance(value, Module): | |
if modules is None: | |
raise AttributeError( | |
"cannot assign module before Module.__init__() call") | |
remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) | |
for hook in _global_module_registration_hooks.values(): | |
output = hook(self, name, value) | |
if output is not None: | |
value = output | |
modules[name] = value | |
elif modules is not None and name in modules: | |
if value is not None: | |
raise TypeError(f"cannot assign '{torch.typename(value)}' as child module '{name}' " | |
"(torch.nn.Module or None expected)" | |
) | |
for hook in _global_module_registration_hooks.values(): | |
output = hook(self, name, value) | |
if output is not None: | |
value = output | |
modules[name] = value | |
else: | |
buffers = self.__dict__.get('_buffers') | |
if buffers is not None and name in buffers: | |
if value is not None and not isinstance(value, torch.Tensor): | |
raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " | |
"(torch.Tensor or None expected)" | |
) | |
for hook in _global_buffer_registration_hooks.values(): | |
output = hook(self, name, value) | |
if output is not None: | |
value = output | |
buffers[name] = value | |
else: | |
super().__setattr__(name, value) | |
def __delattr__(self, name): | |
if name in self._parameters: | |
del self._parameters[name] | |
elif name in self._buffers: | |
del self._buffers[name] | |
self._non_persistent_buffers_set.discard(name) | |
elif name in self._modules: | |
del self._modules[name] | |
else: | |
super().__delattr__(name) | |
def _register_state_dict_hook(self, hook): | |
r"""Register a state-dict hook. | |
These hooks will be called with arguments: `self`, `state_dict`,
`prefix`, `local_metadata`, after the `state_dict` of `self` has been generated.
Note that only parameters and buffers of `self` or its children are
guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
in place or return a new one.
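A minimal sketch; ``net`` and the key-filtering logic are hypothetical, and the
hook edits ``state_dict`` in place::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> def drop_private_keys(module, state_dict, prefix, local_metadata):
...     # remove entries this module does not want serialized
...     for key in [k for k in state_dict if k.endswith('_private')]:
...         del state_dict[key]
>>> handle = net._register_state_dict_hook(drop_private_keys)
>>> sd = net.state_dict()
>>> handle.remove()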
""" | |
handle = hooks.RemovableHandle(self._state_dict_hooks) | |
self._state_dict_hooks[handle.id] = hook | |
return handle | |
def register_state_dict_pre_hook(self, hook): | |
r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. | |
These hooks will be called with arguments: ``self``, ``prefix``, | |
and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered | |
hooks can be used to perform pre-processing before the ``state_dict`` | |
call is made. | |
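A minimal sketch; ``net`` is a hypothetical module and the hook merely logs the
prefix it is invoked with::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> def log_prefix(module, prefix, keep_vars):
...     print(f"state_dict() starting with prefix={prefix!r}, keep_vars={keep_vars}")
>>> handle = net.register_state_dict_pre_hook(log_prefix)
>>> sd = net.state_dict()
>>> handle.remove()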
""" | |
handle = hooks.RemovableHandle(self._state_dict_pre_hooks) | |
self._state_dict_pre_hooks[handle.id] = hook | |
return handle | |
def _save_to_state_dict(self, destination, prefix, keep_vars): | |
r"""Save module state to the `destination` dictionary. | |
The `destination` dictionary will contain the state | |
of the module, but not its descendants. This is called on every | |
submodule in :meth:`~torch.nn.Module.state_dict`. | |
In rare cases, subclasses can achieve class-specific behavior by | |
overriding this method with custom logic. | |
Args: | |
destination (dict): a dict where state will be stored | |
prefix (str): the prefix for parameters and buffers used in this | |
module | |
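A minimal sketch of an override; ``MyModule`` and the extra entry are hypothetical,
and the base implementation is still called to save parameters and buffers::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> class MyModule(nn.Module):
...     def _save_to_state_dict(self, destination, prefix, keep_vars):
...         super()._save_to_state_dict(destination, prefix, keep_vars)
...         # store one extra tensor entry under this module's prefix
...         destination[prefix + 'my_counter'] = torch.tensor(0)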
""" | |
for name, param in self._parameters.items(): | |
if param is not None: | |
destination[prefix + name] = param if keep_vars else param.detach() | |
for name, buf in self._buffers.items(): | |
if buf is not None and name not in self._non_persistent_buffers_set: | |
destination[prefix + name] = buf if keep_vars else buf.detach() | |
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX | |
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: | |
destination[extra_state_key] = self.get_extra_state() | |
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns | |
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned. | |
T_destination = TypeVar('T_destination', bound=Dict[str, Any]) | |
@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...
# TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. | |
# Also remove the logic for arg parsing together. | |
def state_dict(self, *args, destination=None, prefix='', keep_vars=False): | |
r"""Return a dictionary containing references to the whole state of the module. | |
Both parameters and persistent buffers (e.g. running averages) are | |
included. Keys are corresponding parameter and buffer names. | |
Parameters and buffers set to ``None`` are not included. | |
.. note:: | |
The returned object is a shallow copy. It contains references | |
to the module's parameters and buffers. | |
.. warning:: | |
Currently ``state_dict()`` also accepts positional arguments for | |
``destination``, ``prefix`` and ``keep_vars`` in order. However, | |
this is being deprecated and keyword arguments will be enforced in | |
future releases. | |
.. warning:: | |
Please avoid the use of argument ``destination`` as it is not | |
designed for end-users. | |
Args: | |
destination (dict, optional): If provided, the state of module will | |
be updated into the dict and the same object is returned. | |
Otherwise, an ``OrderedDict`` will be created and returned. | |
Default: ``None``. | |
prefix (str, optional): a prefix added to parameter and buffer | |
names to compose the keys in state_dict. Default: ``''``. | |
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s | |
returned in the state dict are detached from autograd. If it's | |
set to ``True``, detaching will not be performed. | |
Default: ``False``. | |
Returns: | |
dict: | |
a dictionary containing a whole state of the module | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> module.state_dict().keys() | |
['bias', 'weight'] | |
""" | |
# TODO: Remove `args` and the parsing logic when BC allows. | |
if len(args) > 0: | |
if destination is None: | |
destination = args[0] | |
if len(args) > 1 and prefix == '': | |
prefix = args[1] | |
if len(args) > 2 and keep_vars is False: | |
keep_vars = args[2] | |
# DeprecationWarning is ignored by default | |
warnings.warn( | |
"Positional args are being deprecated, use kwargs instead. Refer to " | |
"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict" | |
" for details.") | |
if destination is None: | |
destination = OrderedDict() | |
destination._metadata = OrderedDict() | |
local_metadata = dict(version=self._version) | |
if hasattr(destination, "_metadata"): | |
destination._metadata[prefix[:-1]] = local_metadata | |
for hook in self._state_dict_pre_hooks.values(): | |
hook(self, prefix, keep_vars) | |
self._save_to_state_dict(destination, prefix, keep_vars) | |
for name, module in self._modules.items(): | |
if module is not None: | |
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) | |
for hook in self._state_dict_hooks.values(): | |
hook_result = hook(self, destination, prefix, local_metadata) | |
if hook_result is not None: | |
destination = hook_result | |
return destination | |
def _register_load_state_dict_pre_hook(self, hook, with_module=False): | |
r"""Register a pre-hook for the :meth:`~torch.nn.Module.load_state_dict` method. | |
These hooks will be called with arguments: `state_dict`, `prefix`, | |
`local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, | |
`error_msgs`, before loading `state_dict` into `self`. These arguments | |
are exactly the same as those of `_load_from_state_dict`. | |
If ``with_module`` is ``True``, then the first argument to the hook is | |
an instance of the module. | |
Arguments: | |
hook (Callable): Callable hook that will be invoked before | |
loading the state dict. | |
with_module (bool, optional): Whether or not to pass the module | |
instance to the hook as the first parameter. | |
""" | |
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) | |
self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) | |
return handle | |
def register_load_state_dict_post_hook(self, hook): | |
r"""Register a post hook to be run after module's ``load_state_dict`` is called. | |
It should have the following signature:: | |
hook(module, incompatible_keys) -> None | |
The ``module`` argument is the current module that this hook is registered | |
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting | |
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` | |
is a ``list`` of ``str`` containing the missing keys and | |
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. | |
The given ``incompatible_keys`` can be modified in place if needed.
Note that the checks performed when calling :func:`load_state_dict` with
``strict=True`` are affected by modifications the hook makes to
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
set of keys will result in an error being raised when ``strict=True``, and
clearing out both missing and unexpected keys will avoid an error.
Returns: | |
:class:`torch.utils.hooks.RemovableHandle`: | |
a handle that can be used to remove the added hook by calling | |
``handle.remove()`` | |
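A minimal sketch; ``net`` is a hypothetical module and the hook only prints
whatever failed to match::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> def report(module, incompatible_keys):
...     if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
...         print("missing:", incompatible_keys.missing_keys)
...         print("unexpected:", incompatible_keys.unexpected_keys)
>>> handle = net.register_load_state_dict_post_hook(report)
>>> _ = net.load_state_dict(net.state_dict())  # keys match, so nothing is printed
>>> handle.remove()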
""" | |
handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) | |
self._load_state_dict_post_hooks[handle.id] = hook | |
return handle | |
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |
missing_keys, unexpected_keys, error_msgs): | |
r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. | |
This is called on every submodule | |
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this | |
module in input :attr:`state_dict` is provided as :attr:`local_metadata`. | |
For state dicts without metadata, :attr:`local_metadata` is empty. | |
Subclasses can achieve class-specific backward compatible loading using | |
the version number at `local_metadata.get("version", None)`. | |
Additionally, :attr:`local_metadata` can also contain the key | |
`assign_to_params_buffers` that indicates whether keys should be | |
assigned their corresponding tensor in the state_dict. | |
.. note:: | |
:attr:`state_dict` is not the same object as the input | |
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So | |
it can be modified. | |
Args: | |
state_dict (dict): a dict containing parameters and | |
persistent buffers. | |
prefix (str): the prefix for parameters and buffers used in this | |
module | |
local_metadata (dict): a dict containing the metadata for this module.
strict (bool): whether to strictly enforce that the keys in | |
:attr:`state_dict` with :attr:`prefix` match the names of | |
parameters and buffers in this module | |
missing_keys (list of str): if ``strict=True``, add missing keys to | |
this list | |
unexpected_keys (list of str): if ``strict=True``, add unexpected | |
keys to this list | |
error_msgs (list of str): error messages should be added to this | |
list, and will be reported together in | |
:meth:`~torch.nn.Module.load_state_dict` | |
""" | |
for hook in self._load_state_dict_pre_hooks.values(): | |
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) | |
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} | |
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) | |
local_state = {k: v for k, v in local_name_params if v is not None} | |
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) | |
use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() | |
for name, param in local_state.items(): | |
key = prefix + name | |
if key in state_dict: | |
input_param = state_dict[key] | |
if not torch.overrides.is_tensor_like(input_param): | |
error_msgs.append(f'While copying the parameter named "{key}", ' | |
'expected torch.Tensor or Tensor-like object from checkpoint but ' | |
f'received {type(input_param)}' | |
) | |
continue | |
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks;
# in such a case, accessing the .shape attribute would error out.
is_param_lazy = torch.nn.parameter.is_lazy(param) | |
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ | |
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: | |
input_param = input_param[0] | |
if not is_param_lazy and input_param.shape != param.shape: | |
# local shape should match the one in checkpoint | |
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' | |
'the shape in current model is {}.' | |
.format(key, input_param.shape, param.shape)) | |
continue | |
if param.is_meta and not input_param.is_meta and not assign_to_params_buffers: | |
warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta ' | |
'parameter in the current model, which is a no-op. (Did you mean to ' | |
'pass `assign=True` to assign items in the state dictionary to their ' | |
'corresponding key in the module instead of copying them in place?)') | |
try: | |
with torch.no_grad(): | |
if use_swap_tensors: | |
new_input_param = param.module_load(input_param, assign=assign_to_params_buffers) | |
if id(new_input_param) == id(input_param) or id(new_input_param) == id(param): | |
raise RuntimeError("module_load returned one of self or other, please .detach() " | |
"the result if returning one of the inputs in module_load") | |
if (isinstance(param, torch.nn.Parameter)): | |
if not isinstance(new_input_param, torch.nn.Parameter): | |
new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad) | |
else: | |
new_input_param.requires_grad_(param.requires_grad) | |
torch.utils.swap_tensors(param, new_input_param) | |
del new_input_param | |
elif assign_to_params_buffers: | |
# Shape checks are already done above | |
if (isinstance(param, torch.nn.Parameter)): | |
if not isinstance(input_param, torch.nn.Parameter): | |
input_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) | |
else: | |
input_param.requires_grad_(param.requires_grad) | |
setattr(self, name, input_param) | |
else: | |
param.copy_(input_param) | |
except Exception as ex: | |
action = "swapping" if use_swap_tensors else "copying" | |
error_msgs.append(f'While {action} the parameter named "{key}", ' | |
f'whose dimensions in the model are {param.size()} and ' | |
f'whose dimensions in the checkpoint are {input_param.size()}, ' | |
f'an exception occurred : {ex.args}.' | |
) | |
elif strict: | |
missing_keys.append(key) | |
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX | |
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: | |
if extra_state_key in state_dict: | |
self.set_extra_state(state_dict[extra_state_key]) | |
elif strict: | |
missing_keys.append(extra_state_key) | |
elif strict and (extra_state_key in state_dict): | |
unexpected_keys.append(extra_state_key) | |
if strict: | |
for key in state_dict.keys(): | |
if key.startswith(prefix) and key != extra_state_key: | |
input_name = key[len(prefix):] | |
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child | |
if input_name not in self._modules and input_name not in local_state: | |
unexpected_keys.append(key) | |
def load_state_dict(self, state_dict: Mapping[str, Any], | |
strict: bool = True, assign: bool = False): | |
r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. | |
If :attr:`strict` is ``True``, then | |
the keys of :attr:`state_dict` must exactly match the keys returned | |
by this module's :meth:`~torch.nn.Module.state_dict` function. | |
.. warning:: | |
If :attr:`assign` is ``True`` the optimizer must be created after | |
the call to :attr:`load_state_dict` unless | |
:func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. | |
Args: | |
state_dict (dict): a dict containing parameters and | |
persistent buffers. | |
strict (bool, optional): whether to strictly enforce that the keys | |
in :attr:`state_dict` match the keys returned by this module's | |
:meth:`~torch.nn.Module.state_dict` function. Default: ``True`` | |
assign (bool, optional): When ``False``, the properties of the tensors
in the current module are preserved, whereas when ``True``, the
properties of the Tensors in the state dict are preserved. The only
exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s | |
for which the value from the module is preserved. | |
Default: ``False`` | |
Returns: | |
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: | |
* **missing_keys** is a list of str containing the missing keys | |
* **unexpected_keys** is a list of str containing the unexpected keys | |
Note: | |
If a parameter or buffer is registered as ``None`` and its corresponding key | |
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a | |
``RuntimeError``. | |
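A minimal sketch; ``net`` and the checkpoint file name are hypothetical::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> state_dict = torch.load('checkpoint.pth')
>>> # with strict=False, mismatched keys are reported instead of raising
>>> missing, unexpected = net.load_state_dict(state_dict, strict=False)
>>> print(missing, unexpected)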
""" | |
if not isinstance(state_dict, Mapping): | |
raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") | |
missing_keys: List[str] = [] | |
unexpected_keys: List[str] = [] | |
error_msgs: List[str] = [] | |
# copy state_dict so _load_from_state_dict can modify it | |
metadata = getattr(state_dict, '_metadata', None) | |
state_dict = OrderedDict(state_dict) | |
if metadata is not None: | |
# mypy isn't aware that "_metadata" exists in state_dict | |
state_dict._metadata = metadata # type: ignore[attr-defined] | |
def load(module, local_state_dict, prefix=''): | |
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |
if assign: | |
local_metadata['assign_to_params_buffers'] = assign | |
module._load_from_state_dict( | |
local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |
for name, child in module._modules.items(): | |
if child is not None: | |
child_prefix = prefix + name + '.' | |
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} | |
load(child, child_state_dict, child_prefix) # noqa: F821 | |
# Note that the hook can modify missing_keys and unexpected_keys. | |
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) | |
for hook in module._load_state_dict_post_hooks.values(): | |
out = hook(module, incompatible_keys) | |
assert out is None, ( | |
"Hooks registered with ``register_load_state_dict_post_hook`` are not" | |
"expected to return new values, if incompatible_keys need to be modified," | |
"it should be done inplace." | |
) | |
load(self, state_dict) | |
del load | |
if strict: | |
if len(unexpected_keys) > 0: | |
error_msgs.insert( | |
0, 'Unexpected key(s) in state_dict: {}. '.format( | |
', '.join(f'"{k}"' for k in unexpected_keys))) | |
if len(missing_keys) > 0: | |
error_msgs.insert( | |
0, 'Missing key(s) in state_dict: {}. '.format( | |
', '.join(f'"{k}"' for k in missing_keys))) | |
if len(error_msgs) > 0: | |
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( | |
self.__class__.__name__, "\n\t".join(error_msgs))) | |
return _IncompatibleKeys(missing_keys, unexpected_keys) | |
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True): | |
r"""Help yield various names + members of modules.""" | |
memo = set() | |
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)] | |
for module_prefix, module in modules: | |
members = get_members_fn(module) | |
for k, v in members: | |
if v is None or v in memo: | |
continue | |
if remove_duplicate: | |
memo.add(v) | |
name = module_prefix + ('.' if module_prefix else '') + k | |
yield name, v | |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]: | |
r"""Return an iterator over module parameters. | |
This is typically passed to an optimizer. | |
Args: | |
recurse (bool): if True, then yields parameters of this module | |
and all submodules. Otherwise, yields only parameters that | |
are direct members of this module. | |
Yields: | |
Parameter: module parameter | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> for param in model.parameters(): | |
>>> print(type(param), param.size()) | |
<class 'torch.Tensor'> (20L,) | |
<class 'torch.Tensor'> (20L, 1L, 5L, 5L) | |
""" | |
for name, param in self.named_parameters(recurse=recurse): | |
yield param | |
def named_parameters( | |
self, | |
prefix: str = '', | |
recurse: bool = True, | |
remove_duplicate: bool = True | |
) -> Iterator[Tuple[str, Parameter]]: | |
r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. | |
Args: | |
prefix (str): prefix to prepend to all parameter names. | |
recurse (bool): if True, then yields parameters of this module | |
and all submodules. Otherwise, yields only parameters that | |
are direct members of this module. | |
remove_duplicate (bool, optional): whether to remove the duplicated | |
parameters in the result. Defaults to True. | |
Yields: | |
(str, Parameter): Tuple containing the name and parameter | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> for name, param in self.named_parameters(): | |
>>> if name in ['bias']: | |
>>> print(param.size()) | |
""" | |
gen = self._named_members( | |
lambda module: module._parameters.items(), | |
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) | |
yield from gen | |
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: | |
r"""Return an iterator over module buffers. | |
Args: | |
recurse (bool): if True, then yields buffers of this module | |
and all submodules. Otherwise, yields only buffers that | |
are direct members of this module. | |
Yields: | |
torch.Tensor: module buffer | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> for buf in model.buffers(): | |
>>> print(type(buf), buf.size()) | |
<class 'torch.Tensor'> (20L,) | |
<class 'torch.Tensor'> (20L, 1L, 5L, 5L) | |
""" | |
for _, buf in self.named_buffers(recurse=recurse): | |
yield buf | |
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: | |
r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. | |
Args: | |
prefix (str): prefix to prepend to all buffer names. | |
recurse (bool, optional): if True, then yields buffers of this module | |
and all submodules. Otherwise, yields only buffers that | |
are direct members of this module. Defaults to True. | |
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. | |
Yields: | |
(str, torch.Tensor): Tuple containing the name and buffer | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> for name, buf in self.named_buffers(): | |
>>> if name in ['running_var']: | |
>>> print(buf.size()) | |
""" | |
gen = self._named_members( | |
lambda module: module._buffers.items(), | |
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) | |
yield from gen | |
def children(self) -> Iterator['Module']: | |
r"""Return an iterator over immediate children modules. | |
Yields: | |
Module: a child module | |
""" | |
for name, module in self.named_children(): | |
yield module | |
def named_children(self) -> Iterator[Tuple[str, 'Module']]: | |
r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. | |
Yields: | |
(str, Module): Tuple containing a name and child module | |
Example:: | |
>>> # xdoctest: +SKIP("undefined vars") | |
>>> for name, module in model.named_children(): | |
>>> if name in ['conv4', 'conv5']: | |
>>> print(module) | |
""" | |
memo = set() | |
for name, module in self._modules.items(): | |
if module is not None and module not in memo: | |
memo.add(module) | |
yield name, module | |
def modules(self) -> Iterator['Module']: | |
r"""Return an iterator over all modules in the network. | |
Yields: | |
Module: a module in the network | |
Note: | |
Duplicate modules are returned only once. In the following | |
example, ``l`` will be returned only once. | |
Example:: | |
>>> l = nn.Linear(2, 2) | |
>>> net = nn.Sequential(l, l) | |
>>> for idx, m in enumerate(net.modules()): | |
... print(idx, '->', m) | |
0 -> Sequential( | |
(0): Linear(in_features=2, out_features=2, bias=True) | |
(1): Linear(in_features=2, out_features=2, bias=True) | |
) | |
1 -> Linear(in_features=2, out_features=2, bias=True) | |
""" | |
for _, module in self.named_modules(): | |
yield module | |
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): | |
r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. | |
Args: | |
memo: a memo to store the set of modules already added to the result | |
prefix: a prefix that will be added to the name of the module | |
remove_duplicate: whether to remove the duplicated module instances in the result | |
or not | |
Yields: | |
(str, Module): Tuple of name and module | |
Note: | |
Duplicate modules are returned only once. In the following | |
example, ``l`` will be returned only once. | |
Example:: | |
>>> l = nn.Linear(2, 2) | |
>>> net = nn.Sequential(l, l) | |
>>> for idx, m in enumerate(net.named_modules()): | |
... print(idx, '->', m) | |
0 -> ('', Sequential( | |
(0): Linear(in_features=2, out_features=2, bias=True) | |
(1): Linear(in_features=2, out_features=2, bias=True) | |
)) | |
1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) | |
""" | |
if memo is None: | |
memo = set() | |
if self not in memo: | |
if remove_duplicate: | |
memo.add(self) | |
yield prefix, self | |
for name, module in self._modules.items(): | |
if module is None: | |
continue | |
submodule_prefix = prefix + ('.' if prefix else '') + name | |
yield from module.named_modules(memo, submodule_prefix, remove_duplicate) | |
def train(self: T, mode: bool = True) -> T: | |
r"""Set the module in training mode. | |
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation | |
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, | |
etc. | |
Args: | |
mode (bool): whether to set training mode (``True``) or evaluation | |
mode (``False``). Default: ``True``. | |
Returns: | |
Module: self | |
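A minimal sketch; ``net`` is a hypothetical module::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> net.train()       # enable training-time behavior, e.g. dropout is active
>>> net.training
True
>>> net.train(False)  # same as net.eval()
>>> net.training
False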
""" | |
if not isinstance(mode, bool): | |
raise ValueError("training mode is expected to be boolean") | |
self.training = mode | |
for module in self.children(): | |
module.train(mode) | |
return self | |
def eval(self: T) -> T: | |
r"""Set the module in evaluation mode. | |
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation | |
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, | |
etc. | |
This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between | |
`.eval()` and several similar mechanisms that may be confused with it. | |
Returns: | |
Module: self | |
""" | |
return self.train(False) | |
def requires_grad_(self: T, requires_grad: bool = True) -> T: | |
r"""Change if autograd should record operations on parameters in this module. | |
This method sets the parameters' :attr:`requires_grad` attributes | |
in-place. | |
This method is helpful for freezing part of the module for finetuning | |
or training parts of a model individually (e.g., GAN training). | |
See :ref:`locally-disable-grad-doc` for a comparison between | |
`.requires_grad_()` and several similar mechanisms that may be confused with it. | |
Args: | |
requires_grad (bool): whether autograd should record operations on | |
parameters in this module. Default: ``True``. | |
Returns: | |
Module: self | |
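A minimal sketch of freezing part of a model; ``net`` and its ``backbone``
submodule are hypothetical::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> net.backbone.requires_grad_(False)  # freeze the backbone for finetuning
>>> all(not p.requires_grad for p in net.backbone.parameters())
True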
""" | |
for p in self.parameters(): | |
p.requires_grad_(requires_grad) | |
return self | |
def zero_grad(self, set_to_none: bool = True) -> None: | |
r"""Reset gradients of all model parameters. | |
See similar function under :class:`torch.optim.Optimizer` for more context. | |
Args: | |
set_to_none (bool): instead of setting to zero, set the grads to None. | |
See :meth:`torch.optim.Optimizer.zero_grad` for details. | |
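A minimal sketch; ``net`` is a hypothetical module that has accumulated gradients::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> net(torch.randn(2, 3)).sum().backward()
>>> net.zero_grad()  # by default each parameter's .grad is set to None
>>> all(p.grad is None for p in net.parameters())
True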
""" | |
if getattr(self, '_is_replica', False): | |
warnings.warn( | |
"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " | |
"The parameters are copied (in a differentiable manner) from the original module. " | |
"This means they are not leaf nodes in autograd and so don't accumulate gradients. " | |
"If you need gradients in your forward method, consider using autograd.grad instead.") | |
for p in self.parameters(): | |
if p.grad is not None: | |
if set_to_none: | |
p.grad = None | |
else: | |
if p.grad.grad_fn is not None: | |
p.grad.detach_() | |
else: | |
p.grad.requires_grad_(False) | |
p.grad.zero_() | |
def share_memory(self: T) -> T: | |
r"""See :meth:`torch.Tensor.share_memory_`.""" | |
return self._apply(lambda t: t.share_memory_()) | |
def _get_name(self): | |
return self.__class__.__name__ | |
def extra_repr(self) -> str: | |
r"""Set the extra representation of the module. | |
To print customized extra information, you should re-implement | |
this method in your own modules. Both single-line and multi-line | |
strings are acceptable. | |
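A minimal sketch of a custom ``extra_repr``; ``MyLinear`` and its attributes are
hypothetical::
>>> # xdoctest: +SKIP("illustrative sketch")
>>> class MyLinear(nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.in_features, self.out_features = in_features, out_features
...     def extra_repr(self) -> str:
...         # shown inside the parentheses when the module is printed
...         return f'in_features={self.in_features}, out_features={self.out_features}'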
""" | |
return '' | |
def __repr__(self): | |
# We treat the extra repr like the sub-module, one item per line | |
extra_lines = [] | |
extra_repr = self.extra_repr() | |
# empty string will be split into list [''] | |
if extra_repr: | |
extra_lines = extra_repr.split('\n') | |
child_lines = [] | |
for key, module in self._modules.items(): | |
mod_str = repr(module) | |
mod_str = _addindent(mod_str, 2) | |
child_lines.append('(' + key + '): ' + mod_str) | |
lines = extra_lines + child_lines | |
main_str = self._get_name() + '(' | |
if lines: | |
# simple one-liner info, which most builtin Modules will use | |
if len(extra_lines) == 1 and not child_lines: | |
main_str += extra_lines[0] | |
else: | |
main_str += '\n ' + '\n '.join(lines) + '\n' | |
main_str += ')' | |
return main_str | |
def __dir__(self): | |
module_attrs = dir(self.__class__) | |
attrs = list(self.__dict__.keys()) | |
parameters = list(self._parameters.keys()) | |
modules = list(self._modules.keys()) | |
buffers = list(self._buffers.keys()) | |
keys = module_attrs + attrs + parameters + modules + buffers | |
# Eliminate attrs that are not legal Python variable names (i.e. those starting with a digit)
keys = [key for key in keys if not key[0].isdigit()] | |
return sorted(keys) | |
def _replicate_for_data_parallel(self): | |
replica = self.__new__(type(self)) | |
replica.__dict__ = self.__dict__.copy() | |
# replicas do not have parameters themselves, the replicas reference the original | |
# module. | |
replica._parameters = OrderedDict() | |
replica._buffers = replica._buffers.copy() | |
replica._modules = replica._modules.copy() | |
replica._is_replica = True # type: ignore[assignment] | |
return replica | |
def compile(self, *args, **kwargs): | |
""" | |
Compile this Module's forward using :func:`torch.compile`. | |
This Module's `__call__` method is compiled and all arguments are passed as-is | |
to :func:`torch.compile`. | |
See :func:`torch.compile` for details on the arguments for this function. | |
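A minimal sketch; ``net`` is a hypothetical module, and the keyword argument is
forwarded unchanged to :func:`torch.compile`::
>>> # xdoctest: +SKIP("illustrative sketch; undefined vars")
>>> net.compile(mode="reduce-overhead")
>>> out = net(torch.randn(8, 3))  # subsequent calls run the compiled forward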
""" | |
self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) | |