import torch
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch import Tensor
import collections
import copyreg
from copy import deepcopy
from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence
__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations',
'type_before_parametrizations', 'transfer_parametrizations_and_params']
_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
@contextmanager
def cached():
r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`.
The value of the parametrized objects is computed and cached the first time
they are required when this context manager is active. The cached values are
discarded when leaving the context manager.
This is useful when using a parametrized parameter more than once in the forward pass.
An example of this is when parametrizing the recurrent kernel of an RNN or when
sharing weights.
The simplest way to activate the cache is by wrapping the forward pass of the neural network
.. code-block:: python
import torch.nn.utils.parametrize as P
...
with P.cached():
output = model(inputs)
in training and evaluation. One may also wrap the parts of the modules that use the
parametrized tensors several times. For example, the loop of an RNN with a
parametrized recurrent kernel:
.. code-block:: python
with P.cached():
for x in xs:
out_rnn = self.rnn_cell(x, out_rnn)
"""
global _cache
global _cache_enabled
_cache_enabled += 1
try:
yield
finally:
_cache_enabled -= 1
if not _cache_enabled:
_cache = {}
def _register_parameter_or_buffer(module, name, X):
if isinstance(X, Parameter):
module.register_parameter(name, X)
else:
module.register_buffer(name, X)
class ParametrizationList(ModuleList):
r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.
It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
has been parametrized with :func:`register_parametrization`.
If the first registered parametrization has a ``right_inverse`` that returns one tensor or
does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
it will hold the tensor under the name ``original``.
If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
``original0``, ``original1``, ...
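For example, after registering a parametrization on ``weight``, the container and the
original tensor can be inspected as follows (a minimal sketch; ``torch.nn.Identity`` stands
in for any shape- and dtype-preserving parametrization):
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", nn.Identity())
print(type(m.parametrizations["weight"]).__name__)  # ParametrizationList
print(m.parametrizations["weight"].original.shape)  # torch.Size([3, 3])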
.. warning::
This class is used internally by :func:`register_parametrization`. It is documented
here for completeness. It shall not be instantiated by the user.
Args:
modules (sequence): sequence of modules representing the parametrizations
original (Parameter or Tensor): parameter or buffer that is parametrized
unsafe (bool): a boolean flag that denotes whether the parametrization
may change the dtype and shape of the tensor. Default: `False`
Warning: the parametrization is not checked for consistency upon registration.
Enable this flag at your own risk.
"""
original: Tensor
unsafe: bool
def __init__(
self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False
) -> None:
# We require this because we need to treat the first parametrization differently
# This should never throw, unless this class is used from the outside
if len(modules) == 0:
raise ValueError("ParametrizationList requires one or more modules.")
super().__init__(modules)
self.unsafe = unsafe
# In plain words:
# module.weight must keep its dtype and shape.
# Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
# this should be of the same dtype as the original tensor
#
# We check that the following invariants hold:
# X = module.weight
# Y = param.right_inverse(X)
# assert isinstance(Y, Tensor) or
# (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
# Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
# # Consistency checks
# assert X.dtype == Z.dtype and X.shape == Z.shape
# # If it has one input, this allows us to use set_ to move data
# # to/from the original tensor without changing its id (which is what the
# # optimizer uses to track parameters)
# if isinstance(Y, Tensor)
# assert X.dtype == Y.dtype
# Below we use original = X, new = Y
original_shape = original.shape
original_dtype = original.dtype
# Compute new
with torch.no_grad():
new = original
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
try:
new = module.right_inverse(new)
except NotImplementedError:
pass
# else, or if it throws, we assume that right_inverse is the identity
if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
f"Got {type(new).__name__}")
# Set the number of original tensors
self.is_tensor = isinstance(new, Tensor)
self.ntensors = 1 if self.is_tensor else len(new)
# Register the tensor(s)
if self.is_tensor:
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
f"right_inverse(original).dtype: {new.dtype}"
)
# Update the original tensor in place (with set_) so that the user does not need to
# re-register the parameter manually in the optimizer
with torch.no_grad():
original.set_(new) # type: ignore[call-overload]
_register_parameter_or_buffer(self, "original", original)
else:
for i, originali in enumerate(new):
if not isinstance(originali, Tensor):
raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors "
"(list, tuple...). "
f"Got element {i} of the sequence with type {type(originali).__name__}.")
# If the original tensor was a Parameter that required grad, we expect the user to
# add the new parameters to the optimizer after registering the parametrization
# (this is documented)
if isinstance(original, Parameter):
originali = Parameter(originali)
originali.requires_grad_(original.requires_grad)
_register_parameter_or_buffer(self, f"original{i}", originali)
if not self.unsafe:
# Consistency checks:
# Since f : A -> B, right_inverse : B -> A, Z and original should live in B
# Z = forward(right_inverse(original))
Z = self()
if not isinstance(Z, Tensor):
raise ValueError(
f"A parametrization must return a tensor. Got {type(Z).__name__}."
)
if Z.dtype != original_dtype:
raise ValueError(
"Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
f"unparametrized dtype: {original_dtype}\n"
f"parametrized dtype: {Z.dtype}"
)
if Z.shape != original_shape:
raise ValueError(
"Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
f"unparametrized shape: {original_shape}\n"
f"parametrized shape: {Z.shape}"
)
def right_inverse(self, value: Tensor) -> None:
r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order.
Then, store the result in ``self.original`` if ``right_inverse`` outputs one tensor
or in ``self.original0``, ``self.original1``, ... if it outputs several.
Args:
value (Tensor): Value with which to initialize the tensor
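This is what runs when one assigns to a parametrized tensor. A minimal sketch, using a
hypothetical ``Scale`` parametrization that implements ``right_inverse``:
.. code-block:: python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P
class Scale(nn.Module):
    def forward(self, X):
        return 2 * X
    def right_inverse(self, X):
        return X / 2
m = nn.Linear(2, 2)
P.register_parametrization(m, "weight", Scale())
m.weight = torch.ones(2, 2)  # calls this method under the hood
assert torch.allclose(m.parametrizations["weight"].original, torch.full((2, 2), 0.5))
assert torch.allclose(m.weight, torch.ones(2, 2))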
"""
# All the exceptions in this function should almost never throw.
# They could throw if, for example, right_inverse function returns a different
# dtype when given a different input, which should most likely be caused by a
# bug in the user's code
with torch.no_grad():
# See https://github.com/pytorch/pytorch/issues/53103
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
value = module.right_inverse(value)
else:
raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
"right_inverse.")
if self.is_tensor:
# These exceptions should only throw when a right_inverse function does not
# return the same dtype for every input, which should most likely be caused by a bug
if not isinstance(value, Tensor):
raise ValueError(
f"`right_inverse` should return a tensor. Got {type(value).__name__}"
)
if value.dtype != self.original.dtype:
raise ValueError(
f"The tensor returned by `right_inverse` has dtype {value.dtype} "
f"while `original` has dtype {self.original.dtype}"
)
# We know that the result is going to have the same dtype
self.original.set_(value) # type: ignore[call-overload]
else:
if not isinstance(value, collections.abc.Sequence):
raise ValueError(
"'right_inverse' must return a sequence of tensors. "
f"Got {type(value).__name__}."
)
if len(value) != self.ntensors:
raise ValueError(
"'right_inverse' must return a sequence of tensors of length "
f"{self.ntensors}. Got a sequence of length {len(value)}."
)
for i, tensor in enumerate(value):
original_i = getattr(self, f"original{i}")
if not isinstance(tensor, Tensor):
raise ValueError(
f"`right_inverse` must return a sequence of tensors. "
f"Got element {i} of type {type(tensor).__name__}"
)
if original_i.dtype != tensor.dtype:
raise ValueError(
f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
f"while `original{i}` has dtype {original_i.dtype}"
)
original_i.set_(tensor)
def forward(self) -> Tensor:
if torch.jit.is_scripting():
raise RuntimeError('Parametrization does not work with scripting.')
# Unpack the originals for the first parametrization
if self.is_tensor:
x = self[0](self.original)
else:
originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
x = self[0](*originals)
# It's not possible to call self[1:] here, so we have to be a bit more cryptic
# Also we want to skip all non-integer keys
curr_idx = 1
while hasattr(self, str(curr_idx)):
x = self[curr_idx](x)
curr_idx += 1
return x
def _inject_new_class(module: Module) -> None:
r"""Set up a module to be parametrized.
This works by substituting the class of the module with a class
that extends it, so that a property can be injected.
Args:
module (nn.Module): module into which to inject the property
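The substitution can be observed through the class of the module after registration
(a sketch; ``torch.nn.Identity`` stands in for any parametrization, and the parametrized
class is created dynamically as a subclass of the original one):
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
m = nn.Linear(2, 2)
P.register_parametrization(m, "weight", nn.Identity())
print(type(m).__name__)      # ParametrizedLinear
print(type(m).__bases__[0])  # <class 'torch.nn.modules.linear.Linear'>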
"""
cls = module.__class__
def default_deepcopy(self, memo):
# Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
obj = memo.get(id(self), None)
if obj is not None:
return obj
replica = self.__new__(self.__class__)
memo[id(self)] = replica
replica.__dict__ = deepcopy(self.__dict__, memo)
# Also save all slots if they exist.
slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined]
for slot in slots_to_save:
if hasattr(self, slot):
setattr(replica, slot, deepcopy(getattr(self, slot), memo))
return replica
def getstate(self):
raise RuntimeError(
"Serialization of parametrized modules is only "
"supported through state_dict(). See:\n"
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
)
dct = {"__getstate__": getstate}
# We don't allow serialization of parametrized modules but should still allow deepcopying.
# Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
if not hasattr(cls, "__deepcopy__"):
dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment]
param_cls = type(
f"Parametrized{cls.__name__}",
(cls,),
dct,
)
module.__class__ = param_cls
def _inject_property(module: Module, tensor_name: str) -> None:
r"""Injects a property into module[tensor_name].
It assumes that the class in the module has already been modified from its
original one using _inject_new_class and that the tensor under :attr:`tensor_name`
has already been moved out
Args:
module (nn.Module): module into which to inject the property
tensor_name (str): name of the property to create
"""
# We check the precondition.
# This should never fire if register_parametrization is correctly implemented
assert not hasattr(module, tensor_name)
@torch.jit.unused
def get_cached_parametrization(parametrization) -> Tensor:
global _cache
key = (id(module), tensor_name)
tensor = _cache.get(key)
if tensor is None:
tensor = parametrization()
_cache[key] = tensor
return tensor
def get_parametrized(self) -> Tensor:
if torch.jit.is_scripting():
raise RuntimeError('Parametrization does not work with scripting.')
parametrization = self.parametrizations[tensor_name]
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
raise RuntimeError('Caching is not implemented for scripting. '
'Either disable caching or avoid scripting.')
elif torch._C._get_tracing_state() is not None:
# Tracing
raise RuntimeError('Cannot trace a model while caching parametrizations.')
else:
return get_cached_parametrization(parametrization)
else:
# If caching is not active, this function just evaluates the parametrization
return parametrization()
def set_original(self, value: Tensor) -> None:
if torch.jit.is_scripting():
raise RuntimeError('Parametrization does not work with scripting.')
self.parametrizations[tensor_name].right_inverse(value)
setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
def register_parametrization(
module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False,
) -> Module:
r"""Register a parametrization to a tensor in a module.
Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
the module will return the parametrized version ``parametrization(module.weight)``.
If the original tensor requires a gradient, the backward pass will differentiate
through :attr:`parametrization`, and the optimizer will update the tensor accordingly.
The first time that a module registers a parametrization, this function will add an attribute
``parametrizations`` to the module of type :class:`~ParametrizationList`.
The list of parametrizations on the tensor ``weight`` will be accessible under
``module.parametrizations.weight``.
The original tensor will be accessible under
``module.parametrizations.weight.original``.
Parametrizations may be concatenated by registering several parametrizations
on the same attribute.
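For instance (an illustrative sketch with a hypothetical ``Double`` parametrization), the
registered parametrizations are applied in registration order:
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
class Double(nn.Module):
    def forward(self, X):
        return 2 * X
m = nn.Linear(2, 2)
P.register_parametrization(m, "weight", Double())
P.register_parametrization(m, "weight", Double())
# m.weight is now Double(Double(original)), i.e. 4 * original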
The training mode of a registered parametrization is updated on registration
to match the training mode of the host module.
Parametrized parameters and buffers have an inbuilt caching system that can be activated
using the context manager :func:`cached`.
A :attr:`parametrization` may optionally implement a method with signature
.. code-block:: python
def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
This method is called on the unparametrized tensor when the first parametrization
is registered to compute the initial value of the original tensor.
If this method is not implemented, the original tensor will be just the unparametrized tensor.
If all the parametrizations registered on a tensor implement `right_inverse` it is possible
to initialize a parametrized tensor by assigning to it, as shown in the example below.
It is possible for the first parametrization to depend on several inputs.
This may be implemented by returning a tuple of tensors from ``right_inverse``
(see the example implementation of a ``RankOne`` parametrization below).
In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
with names ``original0``, ``original1``,...
.. note::
If ``unsafe=False`` (default), both the ``forward`` and ``right_inverse`` methods will be called
once to perform a number of consistency checks.
If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
and nothing will be called otherwise.
.. note::
In most situations, ``right_inverse`` will be a function such that
``forward(right_inverse(X)) == X`` (see
`right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
Sometimes, when the parametrization is not surjective, it may be reasonable
to relax this.
.. warning::
If a parametrization depends on several inputs, :func:`~register_parametrization`
will register a number of new parameters. If such parametrization is registered
after the optimizer is created, these new parameters will need to be added manually
to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.
Args:
module (nn.Module): module on which to register the parametrization
tensor_name (str): name of the parameter or buffer on which to register
the parametrization
parametrization (nn.Module): the parametrization to register
Keyword args:
unsafe (bool): a boolean flag that denotes whether the parametrization
may change the dtype and shape of the tensor. Default: `False`
Warning: the parametrization is not checked for consistency upon registration.
Enable this flag at your own risk.
Raises:
ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`
Examples:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>> def forward(self, X):
>>> return X.triu() + X.triu(1).T # Return a symmetric matrix
>>>
>>> def right_inverse(self, A):
>>> return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T # A is now symmetric
>>> m.weight = A # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>> def forward(self, x, y):
>>> # Form a rank 1 matrix multiplying two vectors
>>> return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>> def right_inverse(self, Z):
>>> # Project Z onto the rank 1 matrices
>>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>> # Return rescaled singular vectors
>>> s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1
"""
parametrization.train(module.training)
if is_parametrized(module, tensor_name):
# Correctness checks.
# If A is the space of tensors with shape and dtype equal to module.weight
# we check that parametrization.forward and parametrization.right_inverse are
# functions from A to A
if not unsafe:
Y = getattr(module, tensor_name)
X = parametrization(Y)
if not isinstance(X, Tensor):
raise ValueError(
f"A parametrization must return a tensor. Got {type(X).__name__}."
)
if X.dtype != Y.dtype:
raise ValueError(
"Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.dtype: {Y.dtype}\n"
f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
)
if X.shape != Y.shape:
raise ValueError(
"Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.shape: {Y.shape}\n"
f"parametrization(module.{tensor_name}).shape: {X.shape}"
)
if hasattr(parametrization, "right_inverse"):
try:
Z = parametrization.right_inverse(X) # type: ignore[operator]
except NotImplementedError:
pass
else:
if not isinstance(Z, Tensor):
raise ValueError(
f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
)
if Z.dtype != Y.dtype:
raise ValueError(
"The tensor returned by parametrization.right_inverse must have the same dtype "
f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.dtype: {Y.dtype}\n"
f"returned dtype: {Z.dtype}"
)
if Z.shape != Y.shape:
raise ValueError(
"The tensor returned by parametrization.right_inverse must have the same shape "
f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.shape: {Y.shape}\n"
f"returned shape: {Z.shape}"
)
# else right_inverse is assumed to be the identity
# add the new parametrization to the parametrization list
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
module.parametrizations[tensor_name].append(parametrization)
# If unsafe was True in previous parametrization, keep it enabled
module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr]
elif tensor_name in module._buffers or tensor_name in module._parameters:
# Set the parametrization mechanism
# Fetch the original buffer or parameter
original = getattr(module, tensor_name)
# We create this early to check for possible errors
parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe)
# Delete the previous parameter or buffer
delattr(module, tensor_name)
# If this is the first parametrization registered on the module,
# we prepare the module to inject the property
if not is_parametrized(module):
# Change the class
_inject_new_class(module)
# Inject a ``ModuleDict`` into the instance under module.parametrizations
module.parametrizations = ModuleDict()
# Add a property into the class
_inject_property(module, tensor_name)
# Add a ParametrizationList
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
module.parametrizations[tensor_name] = parametrizations
else:
raise ValueError(
f"Module '{module}' does not have a parameter, a buffer, or a "
f"parametrized element with name '{tensor_name}'"
)
return module
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
r"""Determine if a module has a parametrization.
Args:
module (nn.Module): module to query
tensor_name (str, optional): name of the parameter in the module
Default: ``None``
Returns:
``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
or if it has any parametrization when :attr:`tensor_name` is ``None``;
otherwise ``False``
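A minimal sketch (``torch.nn.Identity`` stands in for any parametrization):
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
m = nn.Linear(3, 3)
assert not P.is_parametrized(m)
P.register_parametrization(m, "weight", nn.Identity())
assert P.is_parametrized(m)
assert P.is_parametrized(m, "weight")
assert not P.is_parametrized(m, "bias")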
"""
parametrizations = getattr(module, "parametrizations", None)
if parametrizations is None or not isinstance(parametrizations, ModuleDict):
return False
if tensor_name is None:
# Check that there is at least one parametrized buffer or Parameter
return len(parametrizations) > 0
else:
return tensor_name in parametrizations
def remove_parametrizations(
module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
r"""Remove the parametrizations on a tensor in a module.
- If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
its current output. In this case, the parametrization shall not change the ``dtype``
of the tensor.
- If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
This is only possible when the parametrization depends on just one tensor.
Args:
module (nn.Module): module from which to remove the parametrization
tensor_name (str): name of the parametrization to be removed
leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
Default: ``True``
Returns:
Module: module
Raises:
ValueError: if ``module[tensor_name]`` is not parametrized
ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
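A minimal sketch of the default behavior (``torch.nn.Identity`` stands in for any parametrization):
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", nn.Identity())
P.remove_parametrizations(m, "weight", leave_parametrized=True)
assert not P.is_parametrized(m)
assert isinstance(m.weight, nn.Parameter)  # plain parameter again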
"""
if not is_parametrized(module, tensor_name):
raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}")
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
if parametrizations.is_tensor:
original = parametrizations.original
if leave_parametrized:
with torch.no_grad():
t = getattr(module, tensor_name)
# We know they have the same dtype because we have checked this when registering the
# parametrizations. As such, we can use set_
# We do this so that the parameter does not change its id()
# This way the user does not need to update the optimizer
with torch.no_grad():
if type(original) is torch.Tensor:
original.set_(t)
else:
try:
original.set_(t)
except RuntimeError as e:
# TODO: Fix this for tensor subclasses that are parameters:
# RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
"for a parameter that is an instance of a tensor subclass requires "
"set_() to be implemented correctly for the tensor subclass. Either "
"set leave_parametrized=False or provide a working implementation for "
"set_() in the tensor subclass.") from e
else:
if leave_parametrized:
# We cannot use no_grad because we need to know whether one or more
# original tensors required grad
t = getattr(module, tensor_name)
# We'll have to trust the user to add it to the optimizer
original = Parameter(t) if t.requires_grad else t
else:
raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
"that is parametrized in terms of a sequence of tensors.")
# Delete the property that manages the parametrization
delattr(module.__class__, tensor_name)
# Delete the ParametrizationList
del module.parametrizations[tensor_name]
# Restore the parameter / buffer into the main class
_register_parameter_or_buffer(module, tensor_name, original)
# Roll back the parametrized class if no other buffer or parameter
# is currently parametrized in this class
if not is_parametrized(module):
delattr(module, "parametrizations")
# Restore class
orig_cls = module.__class__.__bases__[0]
module.__class__ = orig_cls
return module
def type_before_parametrizations(module: Module) -> type:
r"""Return the module type before parametrizations were applied and if not, then it returns the module type.
Args:
module (nn.Module): module to get type of
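A minimal sketch (``torch.nn.Identity`` stands in for any parametrization):
.. code-block:: python
import torch.nn as nn
import torch.nn.utils.parametrize as P
m = nn.Linear(2, 2)
P.register_parametrization(m, "weight", nn.Identity())
assert type(m) is not nn.Linear                        # dynamically created subclass
assert P.type_before_parametrizations(m) is nn.Linear  # original type is recovered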
"""
if is_parametrized(module):
return module.__class__.__bases__[0]
else:
return type(module)
def transfer_parametrizations_and_params(
from_module: Module, to_module: Module, tensor_name: Optional[str] = None
) -> Module:
r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`.
If :attr:`tensor_name` is specified, only the specified parameter is transferred; otherwise
all parametrized parameters are transferred. If those parameters do not exist in :attr:`to_module`, they are created.
Does nothing if :attr:`from_module` is not parametrized.
Args:
from_module (nn.Module): module to transfer from
to_module (nn.Module): module to transfer to
tensor_name (str, optional): parameter to transfer
Returns:
Module: to_module
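A minimal sketch (``torch.nn.Identity`` stands in for any parametrization):
.. code-block:: python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P
src, dst = nn.Linear(3, 3), nn.Linear(3, 3)
P.register_parametrization(src, "weight", nn.Identity())
P.transfer_parametrizations_and_params(src, dst)
assert P.is_parametrized(dst, "weight")
assert torch.allclose(src.weight, dst.weight)  # dst now uses the transferred original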
"""
if is_parametrized(from_module):
assert isinstance(from_module.parametrizations, ModuleDict) # for mypy
# get list of all params or the single param to transfer
parameters_to_transfer: Union[list, ModuleDict] = (
from_module.parametrizations if tensor_name is None else [tensor_name]
)
assert hasattr(parameters_to_transfer, "__iter__") # for mypy
for parameter_name in parameters_to_transfer:
# initialize the to-be-transferred param in to_module if it doesn't exist already
if not hasattr(to_module, parameter_name):
setattr(
to_module,
parameter_name,
Parameter(getattr(from_module, parameter_name)),
)
# apply the param's parametrizations to to_module
for param_func in from_module.parametrizations[parameter_name]:
register_parametrization(to_module, parameter_name, param_func)
assert isinstance(to_module.parametrizations, ModuleDict) # for mypy
# make values match, original values can be stored in either original or
# original0, original1..., need to check both cases
if hasattr(from_module.parametrizations[parameter_name], "original"):
to_module.parametrizations[parameter_name].original = \
from_module.parametrizations[parameter_name].original
else:
num = 0
orig_num = "original" + str(num)
# loop through each original# until all values have been set
while hasattr(from_module.parametrizations[parameter_name], orig_num):
setattr(
to_module.parametrizations[parameter_name],
orig_num,
getattr(from_module.parametrizations[parameter_name], orig_num),
)
num = num + 1
orig_num = "original" + str(num)
return to_module