r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they have different guarantees
on bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""
import numbers | |
from torch.distributions import constraints, transforms | |
# Public API of this module; everything else is a private registration helper.
__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored per :class:`~torch.distributions.constraints.Constraint`
    subclass. Lookup in :meth:`__call__` uses the exact type of the constraint
    object, so a subclass of a registered class is not matched automatically.
    """

    def __init__(self):
        # Maps a Constraint subclass (the class object itself) to a factory.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Registering a class that is already present replaces its previous
        factory.

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.

        Returns:
            ``factory`` itself when it is given. Otherwise, a decorator that
            registers the decorated callable and returns it unchanged.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance. In decorator form this is raised
                when the decorator is applied.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances: the registry is keyed by class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)  # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            Whatever the factory registered for ``type(constraint)`` returns
            when called with ``constraint``; factories are expected to return a
            :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by the exact Constraint subclass; base classes are not searched.
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
# Global registry of bijective transforms from ``constraints.real`` onto a
# constrained space (see the module docstring for the guarantees).
biject_to = ConstraintRegistry()
# Global registry of not-necessarily-bijective transforms from
# ``constraints.real`` onto a constrained space.
transform_to = ConstraintRegistry()
################################################################################
# Registration Table
################################################################################
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """
    Transform for ``constraints.real``: the identity, since the target space
    is already the real line. ``constraint`` is unused.
    """
    return transforms.identity_transform
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """
    Bijection for an ``independent`` constraint.

    Looks up the bijection for ``constraint.base_constraint`` and wraps it in
    an :class:`~torch.distributions.transforms.IndependentTransform`, passing
    ``constraint.reinterpreted_batch_ndims`` through. If the base constraint has
    no registered factory, the ``NotImplementedError`` from ``biject_to``
    propagates.
    """
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """
    Transform for an ``independent`` constraint.

    Looks up the ``transform_to`` transform for ``constraint.base_constraint``
    and wraps it in an
    :class:`~torch.distributions.transforms.IndependentTransform`, passing
    ``constraint.reinterpreted_batch_ndims`` through. If the base constraint has
    no registered factory, the ``NotImplementedError`` from ``transform_to``
    propagates.
    """
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """
    Transform onto positive values: ``y = exp(x)``. ``constraint`` is unused.

    NOTE(review): registries are keyed by the constraint's class. If another
    factory is registered afterwards for the same class that
    ``constraints.positive``/``constraints.nonnegative`` are instances of, that
    factory replaces this one. Confirm which factory is intended to win.
    """
    return transforms.ExpTransform()
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """
    Transform onto values above ``constraint.lower_bound``:
    ``y = lower_bound + 1 * exp(x)``.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            # AffineTransform(loc, scale): shift the positive output by the bound.
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """
    Transform onto values below ``constraint.upper_bound``:
    ``y = upper_bound - exp(x)``.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            # Negative scale flips the positive output below the upper bound.
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Transform onto ``[lower_bound, upper_bound]`` via a scaled sigmoid:
    ``y = lower_bound + (upper_bound - lower_bound) * sigmoid(x)``.

    When both bounds are ``numbers.Number`` values equal to 0 and 1, a bare
    :class:`~torch.distributions.transforms.SigmoidTransform` is returned
    instead of composing it with an identity affine map. Bounds that are not
    ``numbers.Number`` instances (e.g. tensors) always take the general path.
    """
    # Handle the special case of the unit interval.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """
    Bijection onto the simplex via
    :class:`~torch.distributions.transforms.StickBreakingTransform`.
    ``constraint`` is unused.
    """
    return transforms.StickBreakingTransform()
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """
    Non-bijective transform onto the simplex via
    :class:`~torch.distributions.transforms.SoftmaxTransform`.
    ``constraint`` is unused.
    """
    return transforms.SoftmaxTransform()
# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """
    Transform onto ``constraints.lower_cholesky`` via
    :class:`~torch.distributions.transforms.LowerCholeskyTransform`.
    Registered with ``transform_to`` only. ``constraint`` is unused.
    """
    return transforms.LowerCholeskyTransform()
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """
    Transform onto (semi)definite matrices via
    :class:`~torch.distributions.transforms.PositiveDefiniteTransform`.
    Registered with ``transform_to`` only. ``constraint`` is unused.
    """
    return transforms.PositiveDefiniteTransform()
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """
    Transform onto ``constraints.corr_cholesky`` via
    :class:`~torch.distributions.transforms.CorrCholeskyTransform`.
    ``constraint`` is unused.
    """
    return transforms.CorrCholeskyTransform()
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """
    Bijection for a ``cat`` constraint.

    Looks up one bijection per component constraint in ``constraint.cseq`` and
    combines them with :class:`~torch.distributions.transforms.CatTransform`,
    passing ``constraint.dim`` and ``constraint.lengths`` through. Raises
    ``NotImplementedError`` if any component has no registered factory.
    """
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """
    Transform for a ``cat`` constraint.

    Looks up one ``transform_to`` transform per component constraint in
    ``constraint.cseq`` and combines them with
    :class:`~torch.distributions.transforms.CatTransform`, passing
    ``constraint.dim`` and ``constraint.lengths`` through. Raises
    ``NotImplementedError`` if any component has no registered factory.
    """
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """
    Bijection for a ``stack`` constraint.

    Looks up one bijection per component constraint in ``constraint.cseq`` and
    combines them with :class:`~torch.distributions.transforms.StackTransform`
    along ``constraint.dim``. Raises ``NotImplementedError`` if any component
    has no registered factory.
    """
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """
    Transform for a ``stack`` constraint.

    Looks up one ``transform_to`` transform per component constraint in
    ``constraint.cseq`` and combines them with
    :class:`~torch.distributions.transforms.StackTransform` along
    ``constraint.dim``. Raises ``NotImplementedError`` if any component has no
    registered factory.
    """
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )