Spaces:
Sleeping
Sleeping
File size: 10,593 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they make different guarantees
about bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform, but it is needed for
    algorithms like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""
import numbers
from torch.distributions import constraints, transforms
# Public names of this module: the registry class plus the two global
# registry instances created from it further down in this file.
__all__ = ["ConstraintRegistry", "biject_to", "transform_to"]
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Each entry maps a :class:`~torch.distributions.constraints.Constraint`
    subclass to a factory callable. Calling the registry with a constraint
    instance finds the factory registered for that instance's exact type and
    returns whatever the factory builds from the instance.
    """

    def __init__(self):
        # Keys are Constraint subclasses (instances are converted to their
        # class in ``register``); values are the registered factories.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class (it is registered under
                its class).
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object. When
                omitted, a decorator is returned that performs the registration.

        Returns:
            ``factory`` itself, or, when ``factory`` is omitted, a decorator
            that registers and returns the function it wraps.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor an instance of one (raised when the factory is supplied).
        """
        if factory is None:
            # Decorator form: defer registration until the function arrives.
            def decorator(factory):
                return self.register(constraint, factory)

            return decorator

        # Singleton instances are stored under their class.
        if isinstance(constraint, constraints.Constraint):
            key = type(constraint)
        else:
            key = constraint

        if isinstance(key, type) and issubclass(key, constraints.Constraint):
            self._registry[key] = factory
            return factory

        raise TypeError(
            f"Expected constraint to be either a Constraint subclass or instance, but got {key}"
        )

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)  # unconstrained

        The lookup uses the exact type of ``constraint``; a factory registered
        for a base class is not used for its subclasses.

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object (the
            result of calling the registered factory with ``constraint``).

        Raises:
            NotImplementedError: if no factory has been registered for
                ``type(constraint)``.
        """
        constraint_type = type(constraint)
        try:
            build = self._registry[constraint_type]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise NotImplementedError(
                f"Cannot transform {constraint_type.__name__} constraints"
            ) from None
        return build(constraint)
# The two global registries described in the module docstring: ``biject_to``
# is meant to return bijective transforms, while ``transform_to`` may return
# cheaper, not-necessarily-bijective ones (compare the simplex entries below).
biject_to, transform_to = ConstraintRegistry(), ConstraintRegistry()
################################################################################
# Registration Table
################################################################################
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Factory for ``constraints.real`` in both registries.

    The unconstrained space already is the target space, so the identity is
    used. The module-level ``transforms.identity_transform`` object is
    returned (an attribute reference, not a new instance per call), and
    ``constraint`` is not inspected.
    """
    identity = transforms.identity_transform
    return identity
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Bijective factory for ``constraints.independent``.

    Recursively looks up the bijection for ``constraint.base_constraint`` in
    ``biject_to``, which raises ``NotImplementedError`` if none is registered.
    It then wraps that bijection in ``IndependentTransform`` with
    ``constraint.reinterpreted_batch_ndims``.
    """
    inner = biject_to(constraint.base_constraint)
    event_ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, event_ndims)
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """``transform_to`` factory for ``constraints.independent``.

    Mirrors the ``biject_to`` version, but resolves the base constraint through
    ``transform_to``, so the inner transform need not be bijective.
    """
    inner = transform_to(constraint.base_constraint)
    event_ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, event_ndims)
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Factory for ``positive`` and ``nonnegative`` in both registries: ``y = exp(x)``.

    ``exp`` never outputs 0, so for ``nonnegative`` the boundary value itself
    is never produced. ``constraint`` is not inspected.
    """
    exp_transform = transforms.ExpTransform()
    return exp_transform
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Factory for lower-bounded constraints: ``y = lower_bound + exp(x)``.

    Shared by ``greater_than`` and ``greater_than_eq``. Because ``exp(x) > 0``,
    the bound itself is never produced, even for the ``_eq`` variant.
    """
    exp_step = transforms.ExpTransform()
    shift_step = transforms.AffineTransform(constraint.lower_bound, 1)
    return transforms.ComposeTransform([exp_step, shift_step])
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Factory for ``less_than``: ``y = upper_bound - exp(x)``.

    The affine step uses scale ``-1`` to flip the positive ``exp`` output
    below ``constraint.upper_bound``.
    """
    exp_step = transforms.ExpTransform()
    flip_step = transforms.AffineTransform(constraint.upper_bound, -1)
    return transforms.ComposeTransform([exp_step, flip_step])
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """Factory for bounded intervals: ``y = lower + (upper - lower) * sigmoid(x)``.

    When both bounds are plain Python numbers equal to 0 and 1, a bare
    ``SigmoidTransform`` is returned and the affine step is skipped. Bounds
    that are not ``numbers.Number`` instances (e.g. tensors) always take the
    general path.
    """
    lower = constraint.lower_bound
    upper = constraint.upper_bound

    starts_at_zero = isinstance(lower, numbers.Number) and lower == 0
    ends_at_one = isinstance(upper, numbers.Number) and upper == 1
    if starts_at_zero and ends_at_one:
        return transforms.SigmoidTransform()

    width = upper - lower
    steps = [transforms.SigmoidTransform(), transforms.AffineTransform(lower, width)]
    return transforms.ComposeTransform(steps)
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Bijective factory for ``constraints.simplex``.

    Uses the stick-breaking construction, a bijection, in contrast to the
    softmax registered for ``transform_to``. ``constraint`` is not inspected.
    """
    stick_breaking = transforms.StickBreakingTransform()
    return stick_breaking
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """``transform_to`` factory for ``constraints.simplex``: softmax.

    Exponentiates and normalizes; it is not bijective, which is why
    ``biject_to`` has a separate entry. ``constraint`` is not inspected.
    """
    softmax = transforms.SoftmaxTransform()
    return softmax
# TODO: define a bijection for LowerCholeskyTransform; until then only
# ``transform_to`` has an entry for ``constraints.lower_cholesky``.
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """``transform_to`` factory for ``constraints.lower_cholesky``.

    ``constraint`` is not inspected; a fresh ``LowerCholeskyTransform`` is
    built on every call.
    """
    lower_cholesky = transforms.LowerCholeskyTransform()
    return lower_cholesky
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """``transform_to`` factory for positive (semi)definite matrices.

    The same ``PositiveDefiniteTransform`` serves both constraints. No
    ``biject_to`` entry is registered for either of them in this file.
    """
    positive_definite = transforms.PositiveDefiniteTransform()
    return positive_definite
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Factory for ``constraints.corr_cholesky`` in both registries.

    ``constraint`` is not inspected; a fresh ``CorrCholeskyTransform`` is
    built on every call.
    """
    corr_cholesky = transforms.CorrCholeskyTransform()
    return corr_cholesky
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Bijective factory for ``constraints.cat``.

    Resolves a bijection for each sub-constraint in ``constraint.cseq`` (in
    order) through ``biject_to``. The pieces are combined with ``CatTransform``
    along ``constraint.dim`` using ``constraint.lengths``.
    """
    pieces = list(map(biject_to, constraint.cseq))
    return transforms.CatTransform(pieces, constraint.dim, constraint.lengths)
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """``transform_to`` factory for ``constraints.cat``.

    Same layout as the ``biject_to`` version, but each sub-constraint in
    ``constraint.cseq`` is resolved through ``transform_to``.
    """
    pieces = list(map(transform_to, constraint.cseq))
    return transforms.CatTransform(pieces, constraint.dim, constraint.lengths)
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Bijective factory for ``constraints.stack``.

    Resolves a bijection for each sub-constraint in ``constraint.cseq`` (in
    order) through ``biject_to``. The pieces are combined with
    ``StackTransform`` along ``constraint.dim``.
    """
    pieces = list(map(biject_to, constraint.cseq))
    return transforms.StackTransform(pieces, constraint.dim)
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """``transform_to`` factory for ``constraints.stack``.

    Same layout as the ``biject_to`` version, but each sub-constraint in
    ``constraint.cseq`` is resolved through ``transform_to``.
    """
    pieces = list(map(transform_to, constraint.cseq))
    return transforms.StackTransform(pieces, constraint.dim)
|