"""

This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).



Original copyright notice:



# Copyright: Contributors to the Pyro project.

# SPDX-License-Identifier: Apache-2.0

"""

import math

import torch
from torch.distributions import Beta, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all

__all__ = ["LKJCholesky"]


class LKJCholesky(Distribution):
    r"""

    LKJ distribution for lower Cholesky factor of correlation matrices.

    The distribution is controlled by ``concentration`` parameter :math:`\eta`

    to make the probability of the correlation matrix :math:`M` generated from

    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,

    when ``concentration == 1``, we have a uniform distribution over Cholesky

    factors of correlation matrices::



        L ~ LKJCholesky(dim, concentration)

        X = L @ L' ~ LKJCorr(dim, concentration)



    Note that this distribution samples the

    Cholesky factor of correlation matrices and not the correlation matrices

    themselves and thereby differs slightly from the derivations in [1] for

    the `LKJCorr` distribution. For sampling, this uses the Onion method from

    [1] Section 3.



    Example::



        >>> # xdoctest: +IGNORE_WANT("non-deterministic")

        >>> l = LKJCholesky(3, 0.5)

        >>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix

        tensor([[ 1.0000,  0.0000,  0.0000],

                [ 0.3516,  0.9361,  0.0000],

                [-0.1899,  0.4748,  0.8593]])



    Args:
        dim (int): dimension of the matrices
        concentration (float or Tensor): concentration/shape parameter of the
            distribution (often referred to as eta)

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method` (2009),
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
    Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
    """
    arg_constraints = {"concentration": constraints.positive}
    support = constraints.corr_cholesky

    def __init__(self, dim, concentration=1.0, validate_args=None):
        if dim < 2:
            raise ValueError(
                f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
            )
        self.dim = dim
        (self.concentration,) = broadcast_all(concentration)
        batch_shape = self.concentration.size()
        event_shape = torch.Size((dim, dim))
        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
        offset = torch.arange(
            self.dim - 1,
            dtype=self.concentration.dtype,
            device=self.concentration.device,
        )
        offset = torch.cat([offset.new_zeros((1,)), offset])
        beta_conc1 = offset + 0.5
        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
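        # Beta(beta_conc1, beta_conc0) broadcasts to batch shape `batch_shape + (dim,)`,
        # so a single `sample` call draws one beta variate per row of the factor.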
        self._beta = Beta(beta_conc1, beta_conc0)
        super().__init__(batch_shape, event_shape, validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LKJCholesky, _instance)
        batch_shape = torch.Size(batch_shape)
        new.dim = self.dim
        new.concentration = self.concentration.expand(batch_shape)
        new._beta = self._beta.expand(batch_shape + (self.dim,))
        super(LKJCholesky, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
        # - This vectorizes the for loop and also works for heterogeneous eta.
        # - Same algorithm generalizes to n=1.
        # - The procedure is simplified since we are sampling the Cholesky factor of
        #   the correlation matrix instead of the correlation matrix itself. As such,
        #   we only need to generate `w`.
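        # For row i >= 1, y[..., i, 0] ~ Beta(.) is the squared norm of the
        # off-diagonal entries of that row, u_hypersphere[..., i, :] is their
        # direction on the unit sphere, and the diagonal entry is sqrt(1 - y).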
        y = self._beta.sample(sample_shape).unsqueeze(-1)
        u_normal = torch.randn(
            self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
        ).tril(-1)
        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
        # Replace NaNs in first row
        u_hypersphere[..., 0, :].fill_(0.0)
        w = torch.sqrt(y) * u_hypersphere
        # Fill diagonal elements; clamp for numerical stability
        eps = torch.finfo(w.dtype).tiny
        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
        w += torch.diag_embed(diag_elems)
        return w

    def log_prob(self, value):
        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
        # The probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # Additionally, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i
        if self._validate_args:
            self._validate_sample(value)
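        # The first diagonal element of a correlation Cholesky factor is always 1,
        # so it is excluded from the sum below.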
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1, device=self.concentration.device)
        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Compute normalization constant (page 1999 of [1])
        dm1 = self.dim - 1
        alpha = self.concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term
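

# A minimal usage sketch for the LKJCholesky class above: it assumes only this
# module's public API and standard PyTorch operations. It draws a Cholesky
# factor L, forms the implied correlation matrix L @ L.T, and evaluates the
# log-density of the draw.
if __name__ == "__main__":
    torch.manual_seed(0)
    dist = LKJCholesky(dim=3, concentration=0.5)
    L = dist.sample()
    corr = L @ L.transpose(-2, -1)  # correlation matrix implied by the factor
    # Sanity check: the diagonal of a correlation matrix is all ones.
    assert torch.allclose(
        corr.diagonal(dim1=-2, dim2=-1), torch.ones(3), atol=1e-6
    )
    print("Cholesky factor:\n", L)
    print("correlation matrix:\n", corr)
    print("log_prob:", dist.log_prob(L).item())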