import math

import torch
import torch.jit
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property

__all__ = ["VonMises"]

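# ``_eval_poly`` evaluates a polynomial in ``y`` by Horner's method, with the
# coefficients listed from lowest to highest degree. A minimal illustration,
# assuming an interactive session with this module imported (shown as a
# comment so nothing executes at import time):
#
#     >>> y = torch.tensor(2.0)
#     >>> _eval_poly(y, [1.0, 2.0, 3.0])  # 1 + 2*y + 3*y**2
#     tensor(17.)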
def _eval_poly(y, coef):
    coef = list(coef)
    result = coef.pop()
    while coef:
        result = coef.pop() + y * result
    return result

_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]

_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]

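# The coefficient sets above implement the standard piecewise polynomial
# approximations of the modified Bessel functions I0 and I1: for |x| < 3.75
# the polynomial is evaluated in y = (x / 3.75)**2, while for larger x it is
# evaluated in y = 3.75 / x together with an exp(x) / sqrt(x) prefactor
# (handled in log space by ``_log_modified_bessel_fn`` below).
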
def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.
    """
    assert order == 0 or order == 1

    # compute small solution
    y = x / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x
    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    result = torch.where(x < 3.75, small, large)
    return result

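# Illustrative sanity check, assuming an interactive session (shown as a
# comment so nothing executes at import time): the approximation should
# closely track torch.special.i0 for positive inputs.
#
#     >>> x = torch.tensor([0.5, 2.0, 10.0])
#     >>> approx = _log_modified_bessel_fn(x, order=0).exp()
#     >>> torch.allclose(approx, torch.special.i0(x), rtol=1e-4)
#     True
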
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        # Candidate angle cosine from the Best-Fisher proposal distribution.
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Quick acceptance test, with the exact log-domain test as a fallback.
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # u3 chooses the sign of the accepted angle.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap the result back into [-pi, pi).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi

class VonMises(Distribution):
    """
    A circular von Mises distribution.

    This implementation uses polar coordinates. The ``loc`` and ``value`` args
    can be any real number (to facilitate unconstrained optimization), but are
    interpreted as angles modulo 2 pi.

    Example::
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()
        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.concentration * torch.cos(value - self.loc)
        log_prob = (
            log_prob
            - math.log(2 * math.pi)
            - _log_modified_bessel_fn(self.concentration, order=0)
        )
        return log_prob

    @lazy_property
    def _loc(self):
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self):
        return self.concentration.to(torch.double)

    @lazy_property
    def _proposal_r(self):
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        _proposal_r = (1 + rho**2) / (2 * rho)
        # second order Taylor expansion around 0 for small kappa
        _proposal_r_taylor = 1 / kappa + kappa
        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
        von Mises distribution." Applied Statistics (1979): 152-157.

        Sampling is always done in double precision internally to avoid a hang
        in _rejection_sample() for small values of the concentration, which
        starts to happen for single precision around 1e-4 (see issue #88443).
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
        return _rejection_sample(
            self._loc, self._concentration, self._proposal_r, x
        ).to(self.loc.dtype)

    def expand(self, batch_shape):
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            validate_args = self.__dict__.get("_validate_args")
            loc = self.loc.expand(batch_shape)
            concentration = self.concentration.expand(batch_shape)
            return type(self)(loc, concentration, validate_args=validate_args)

    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one.
        """
        return (
            1
            - (
                _log_modified_bessel_fn(self.concentration, order=1)
                - _log_modified_bessel_fn(self.concentration, order=0)
            ).exp()
        )
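
# Illustrative usage, assuming a separate script or interactive session (shown
# as a comment so nothing executes at import time). The Monte Carlo estimate
# of the circular variance, 1 - E[cos(x - loc)], should approach ``variance``:
#
#     >>> m = VonMises(torch.tensor(0.0), torch.tensor(4.0))
#     >>> samples = m.sample((100000,))
#     >>> empirical = 1 - torch.cos(samples - m.loc).mean()
#     >>> torch.isclose(empirical, m.variance, atol=1e-2)
#     tensor(True)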