Spaces:
Sleeping
Sleeping
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they have different guarantees
on bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform, but it is needed for
    algorithms like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""
| import numbers | |
| from torch.distributions import constraints, transforms | |
# Names exported by ``from ... import *``: the registry class and its two
# global instances.
__all__ = ["ConstraintRegistry", "biject_to", "transform_to"]
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Entries are keyed by constraint *class*: registering a constraint instance
    registers its type, and lookups dispatch on ``type(constraint)`` exactly,
    so subclasses of a registered class are not matched.
    """

    def __init__(self):
        # Maps a Constraint subclass to a factory ``constraint -> Transform``.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object. If
                omitted, a decorator is returned that performs the registration
                when applied to the factory.

        Returns:
            ``factory`` itself, so the method can be used as a decorator. When
            ``factory`` is omitted, returns a decorator that registers its
            argument and returns it.

        Raises:
            TypeError: if ``constraint`` is neither a
                :class:`~torch.distributions.constraints.Constraint` subclass nor
                an instance of one (raised when the factory is supplied).
        """
        # Support use as decorator: defer registration until the factory is known.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances by keying on their class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        # Registering the same class again silently replaces the earlier factory.
        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object built by
            the factory registered for ``type(constraint)``.

        Raises:
            NotImplementedError: if no factory has been registered for
                ``type(constraint)``.
        """
        # Look up by exact Constraint subclass (no MRO walk).
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
# The two global registries, built in this order: ``biject_to`` (transforms
# that must be bijective) and ``transform_to`` (transforms that need not be).
biject_to, transform_to = ConstraintRegistry(), ConstraintRegistry()
################################################################################
# Registration Table
################################################################################
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Map ``constraints.real`` to the identity transform.

    The unconstrained space is already the real line, so the module-level
    ``transforms.identity_transform`` object is returned unchanged. Registered
    in both ``biject_to`` and ``transform_to``.
    """
    return transforms.identity_transform
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Bijection for an independent (reinterpreted-batch) constraint.

    Recursively looks up the bijection for ``constraint.base_constraint`` in
    ``biject_to`` and wraps it in an ``IndependentTransform`` carrying the same
    ``reinterpreted_batch_ndims``.
    """
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Transform for an independent (reinterpreted-batch) constraint.

    Recursively looks up the transform for ``constraint.base_constraint`` in
    ``transform_to`` (not necessarily bijective) and wraps it in an
    ``IndependentTransform`` carrying the same ``reinterpreted_batch_ndims``.
    """
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Map the reals onto the positive half-line with ``exp``.

    Registered for ``positive`` and ``nonnegative`` in both registries.
    NOTE(review): registration is keyed by the constraint's class; if these
    singletons share a class with ``greater_than``/``greater_than_eq``, the
    later registration of ``_transform_to_greater_than`` overrides this entry.
    Confirm against ``torch.distributions.constraints``.
    """
    return transforms.ExpTransform()
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Map the reals onto the region above ``constraint.lower_bound``.

    Composes ``ExpTransform`` with ``AffineTransform(lower_bound, 1)``, i.e.
    ``x -> lower_bound + exp(x)``. Registered for both the strict and the
    ``_eq`` variants in both registries.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Map the reals onto the region below ``constraint.upper_bound``.

    Composes ``ExpTransform`` with ``AffineTransform(upper_bound, -1)``, i.e.
    ``x -> upper_bound - exp(x)``. Registered in both registries.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """Map the reals onto ``[lower_bound, upper_bound]`` via a scaled sigmoid.

    In general this returns ``ComposeTransform([SigmoidTransform(),
    AffineTransform(lower, upper - lower)])``, i.e.
    ``x -> lower + (upper - lower) * sigmoid(x)``. When both bounds are plain
    Python numbers equal to 0 and 1, the affine step is skipped and a bare
    ``SigmoidTransform`` is returned. Tensor bounds always take the general
    path. Registered for ``interval`` and ``half_open_interval`` in both
    registries.
    """
    # Handle the special case of the unit interval.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Bijection onto the simplex via stick breaking.

    ``StickBreakingTransform`` maps from a space with one fewer dimension, which
    makes it bijective (unlike the softmax used by ``transform_to``).
    """
    return transforms.StickBreakingTransform()
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Map onto the simplex with a softmax (exponentiate and normalize).

    Cheap and mostly coordinate-wise, but overparameterized and therefore not
    bijective; ``biject_to`` uses stick breaking instead.
    """
    return transforms.SoftmaxTransform()
# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Map onto lower-triangular Cholesky factors.

    Only registered with ``transform_to``; no bijection is registered for this
    constraint yet (see TODO above).
    """
    return transforms.LowerCholeskyTransform()
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Map onto (semi)definite matrices with ``PositiveDefiniteTransform``.

    Registered for both ``positive_definite`` and ``positive_semidefinite`` in
    ``transform_to`` only.
    """
    return transforms.PositiveDefiniteTransform()
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Map onto Cholesky factors of correlation matrices.

    Registered in both ``biject_to`` and ``transform_to``.
    """
    return transforms.CorrCholeskyTransform()
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Bijection for a concatenation of constraints.

    Looks up a bijection for each piece in ``constraint.cseq`` via ``biject_to``
    and combines them in a ``CatTransform`` with the constraint's ``dim`` and
    ``lengths``.
    """
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Transform for a concatenation of constraints.

    Looks up a (not necessarily bijective) transform for each piece in
    ``constraint.cseq`` via ``transform_to`` and combines them in a
    ``CatTransform`` with the constraint's ``dim`` and ``lengths``.
    """
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Bijection for a stack of constraints.

    Looks up a bijection for each entry in ``constraint.cseq`` via ``biject_to``
    and combines them in a ``StackTransform`` along ``constraint.dim``.
    """
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Transform for a stack of constraints.

    Looks up a (not necessarily bijective) transform for each entry in
    ``constraint.cseq`` via ``transform_to`` and combines them in a
    ``StackTransform`` along ``constraint.dim``.
    """
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )