Spaces:
Sleeping
Sleeping
File size: 5,535 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import torch
from torch import inf
from torch.distributions import Categorical, constraints
from torch.distributions.binomial import Binomial
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
__all__ = ["Multinomial"]
class Multinomial(Distribution):
r"""
Creates a Multinomial distribution parameterized by :attr:`total_count` and
either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
:attr:`probs` indexes over categories. All other dimensions index over batches.
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
called (see example below)
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
will return this normalized value.
The `logits` argument will be interpreted as unnormalized log probabilities
and can therefore be any real number. It will likewise be normalized so that
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
will return this normalized value.
- :meth:`sample` requires a single shared `total_count` for all
parameters and samples.
- :meth:`log_prob` allows different `total_count` for each parameter and
sample.
Example::
>>> # xdoctest: +SKIP("FIXME: found invalid values")
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample() # equal probability of 0, 1, 2, 3
tensor([ 21., 24., 30., 25.])
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
Args:
total_count (int): number of trials
probs (Tensor): event probabilities
logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int
@property
def mean(self):
return self.probs * self.total_count
@property
def variance(self):
return self.total_count * self.probs * (1 - self.probs)
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, int):
raise NotImplementedError("inhomogeneous total_count is not supported")
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
self._binomial = Binomial(total_count=total_count, probs=self.probs)
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Multinomial, _instance)
batch_shape = torch.Size(batch_shape)
new.total_count = self.total_count
new._categorical = self._categorical.expand(batch_shape)
super(Multinomial, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
return constraints.multinomial(self.total_count)
@property
def logits(self):
return self._categorical.logits
@property
def probs(self):
return self._categorical.probs
@property
def param_shape(self):
return self._categorical.param_shape
def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
samples = self._categorical.sample(
torch.Size((self.total_count,)) + sample_shape
)
# samples.shape is (total_count, sample_shape, batch_shape), need to change it to
# (sample_shape, batch_shape, total_count)
shifted_idx = list(range(samples.dim()))
shifted_idx.append(shifted_idx.pop(0))
samples = samples.permute(*shifted_idx)
counts = samples.new(self._extended_shape(sample_shape)).zero_()
counts.scatter_add_(-1, samples, torch.ones_like(samples))
return counts.type_as(self.probs)
def entropy(self):
n = torch.tensor(self.total_count)
cat_entropy = self._categorical.entropy()
term1 = n * cat_entropy - torch.lgamma(n + 1)
support = self._binomial.enumerate_support(expand=False)[1:]
binomial_probs = torch.exp(self._binomial.log_prob(support))
weights = torch.lgamma(support + 1)
term2 = (binomial_probs * weights).sum([0, -1])
return term1 + term2
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
logits = logits.clone(memory_format=torch.contiguous_format)
log_factorial_n = torch.lgamma(value.sum(-1) + 1)
log_factorial_xs = torch.lgamma(value + 1).sum(-1)
logits[(value == 0) & (logits == -inf)] = 0
log_powers = (logits * value).sum(-1)
return log_factorial_n - log_factorial_xs + log_powers
|