Spaces:
Running
Running
File size: 5,302 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs
# Public API of this module: the names exported by a star-import.
__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]
class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
        validate_args (bool, optional): forwarded both to the wrapped
            :class:`Categorical` and to the :class:`Distribution` base class

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Forward ``validate_args`` so that an explicit ``validate_args=False``
        # also switches off parameter checking in the wrapped Categorical
        # instead of silently falling back to the library-wide default.
        self._categorical = Categorical(probs, logits, validate_args=validate_args)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The event is the trailing (category) dimension of the parameters.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution whose batch shape is ``batch_shape``."""
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        # Parameters were already checked (if requested) when ``self`` was
        # built, so skip validation here and copy the flag afterwards.
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped Categorical."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Shape of the wrapped Categorical's parameters."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped Categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample.

        Gumbel noise is added to the logits, the sum is divided by the
        temperature, and the result is normalized with ``logsumexp`` over
        the last dimension.
        """
        shape = self._extended_shape(sample_shape)
        # Clamp the uniforms before the double log below.
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        """Return the log-density of ``value`` (a log-simplex point).

        ``temperature`` may be a Tensor or a plain Python number; a number
        is promoted to a tensor matching the dtype and device of the logits.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        if not isinstance(temperature, torch.Tensor):
            # ``rsample`` already works with a numeric temperature; promote it
            # here so that ``full_like`` / ``.log()`` below work as well.
            temperature = torch.as_tensor(
                temperature, dtype=logits.dtype, device=logits.device
            )
        # lgamma(K) + (K - 1) * log(temperature)
        log_scale = torch.full_like(
            temperature, float(K)
        ).lgamma() - temperature.log().mul(-(K - 1))
        score = logits - value.mul(temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on simplex, and are reparametrizable.

    It is built as an :class:`ExpRelaxedCategorical` base distribution pushed
    through an :class:`ExpTransform`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # The log-space base distribution is exponentiated by the transform.
        super().__init__(
            ExpRelaxedCategorical(
                temperature, probs, logits, validate_args=validate_args
            ),
            ExpTransform(),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``."""
        target = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=target)

    @property
    def temperature(self):
        """Relaxation temperature held by the base distribution."""
        base = self.base_dist
        return base.temperature

    @property
    def logits(self):
        """Logits held by the base distribution."""
        base = self.base_dist
        return base.logits

    @property
    def probs(self):
        """Probabilities held by the base distribution."""
        base = self.base_dist
        return base.probs
|