Spaces:
Running
Running
File size: 8,875 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
from typing import Dict
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost
__all__ = ["TransformedDistribution"]
class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution.  Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`

    Args:
        base_distribution: the distribution whose samples are pushed through
            the transforms.
        transforms: a single :class:`Transform` or a ``list`` of them, applied
            in order (first element first) when sampling.
        validate_args: forwarded unchanged to ``Distribution.__init__``.

    Raises:
        ValueError: if ``transforms`` is neither a Transform nor a list of
            Transforms, or if the base distribution has fewer dimensions than
            the composed transform's domain requires.
    """

    # No constrained parameters of its own; parameters live on the base
    # distribution and on the transforms.
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None) -> None:
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list of Transforms"
                )
            # Keep a shallow copy: batch/event shapes and the reshaped base
            # distribution are computed once below, so a caller mutating the
            # list it passed in must not silently change this instance.
            self.transforms = list(transforms)
        else:
            raise ValueError(
                f"transforms must be a Transform or list, but was {transforms}"
            )

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                "base_distribution needs to have shape with size at least "
                f"{transform.domain.event_dim}, but got {base_shape}."
            )
        # Round-trip the shape through the transform; if the result differs
        # from the original base shape, expand the base distribution to match.
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
        # When the transform's domain has more event dims than the base
        # distribution, reinterpret that many batch dims as event dims.
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        # Internal invariant (not user-input validation): the forward shape
        # must have room for all event dims.
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution with its batch shape set to ``batch_shape``.

        The requested output shape is pulled back through the transforms (last
        to first) to find the batch shape the base distribution must be
        expanded to. The transform list itself is shared with ``self``.
        """
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
        # Skip validation while constructing, then copy the flag from self.
        super(TransformedDistribution, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        """
        Constraint describing the support of this distribution.

        With no transforms this is the base distribution's support; otherwise
        it is the codomain of the last transform, wrapped with
        ``constraints.independent`` when this distribution has more event
        dims than that codomain.
        """
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(
                support, len(self.event_shape) - support.event_dim
            )
        return support

    @property
    def has_rsample(self):
        """Whether reparameterized sampling is available (as on the base distribution)."""
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        base distribution and applies `transform()` for every transform in the
        list.

        The whole computation runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        # event_dim tracks the number of event dims of the intermediate value
        # as we walk back through the transforms.
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            # Sum the jacobian term over any event dims the transform itself
            # does not already reduce.
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        # Likewise sum the base log-density over event dims the base
        # distribution does not already treat as event dims.
        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        # Fast path: a plain int sign of +1 means nothing needs flipping.
        if isinstance(sign, int) and sign == 1:
            return value
        # sign == +1 leaves value unchanged; sign == -1 maps value to 1 - value.
        return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the cdf of the base distribution.
        """
        for transform in reversed(self.transforms):
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by evaluating
        the icdf of the base distribution and then applying the transform(s).
        """
        value = self._monotonize_cdf(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value
|