Spaces:
Sleeping
Sleeping
import inspect | |
import torch | |
def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.

    This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction. When a target device is given, it must be passed to
    this function as a keyword argument (``device=...``), not positionally, because this
    function injects its own ``device`` keyword when constructing the module.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor. ``device`` is intercepted:
            the module is first constructed on the ``meta`` device and then materialized
            on ``device``. If ``device`` is omitted or ``None``, ``'cpu'`` is used.

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a subclass of :class:`torch.nn.Module`, or if
            its constructor signature has no ``device`` parameter.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
                  4.5915e-41]], requires_grad=True)
    """
    if not issubclass(module_cls, torch.nn.Module):
        raise RuntimeError(f'Expected a Module; got {module_cls}')
    if 'device' not in inspect.signature(module_cls).parameters:
        raise RuntimeError('Module must support a \'device\' arg to skip initialization')

    final_device = kwargs.pop('device', None)
    if final_device is None:
        # An explicit ``device=None`` must not reach ``to_empty``: it rebuilds each
        # tensor with ``torch.empty_like(t, device=None)``, which keeps ``t``'s own
        # device and would leave the returned module stuck on ``meta``.
        # Treat it exactly like an omitted device.
        final_device = 'cpu'

    # Construct on the meta device: meta tensors have shape/dtype but no storage,
    # so whatever initialization the constructor performs is effectively skipped.
    kwargs['device'] = 'meta'
    # ``to_empty`` then allocates real, uninitialized storage on the target device.
    return module_cls(*args, **kwargs).to_empty(device=final_device)