r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through: the score function estimator (also known as the
likelihood ratio estimator, or REINFORCE) and the pathwise derivative
estimator. REINFORCE is commonly seen as the basis for policy gradient methods
in reinforcement learning, and the pathwise derivative estimator is commonly
seen in the reparameterization trick in variational autoencoders. Whilst the
score function only requires the value of samples :math:`f(x)`, the pathwise
derivative requires the derivative :math:`f'(x)`. The next sections discuss
these two in a reinforcement learning example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254

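As a minimal illustration of why a surrogate is needed (a toy setup, not part
of the package API): a plain sample cuts the autograd graph, while a
reparameterized sample keeps it::

    import torch
    from torch.distributions import Normal

    mu = torch.tensor([0.0], requires_grad=True)
    d = Normal(mu, 1.0)
    d.sample().requires_grad   # False: the graph is cut at the sample
    d.rsample().requires_grad  # True: gradients can flow back to mu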

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()

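Here ``policy_network`` and ``env`` are schematic placeholders. As a
self-contained (hypothetical) sketch of the same surrogate loss, with a toy
linear policy and a hard-coded reward standing in for the environment::

    import torch
    from torch.distributions import Categorical

    policy_network = torch.nn.Linear(4, 2)  # toy policy: 4 features, 2 actions
    state = torch.randn(4)

    probs = torch.softmax(policy_network(state), dim=-1)
    m = Categorical(probs)
    action = m.sample()
    reward = 1.0                         # stands in for env.step(action)
    loss = -m.log_prob(action) * reward  # surrogate loss
    loss.backward()
    # policy_network.weight.grad now holds a REINFORCE gradient estimate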

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()

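As with the REINFORCE example above, ``policy_network`` and ``env`` are
placeholders. A self-contained (hypothetical) sketch showing gradients flowing
through :meth:`~torch.distributions.Distribution.rsample` back to the
parameters::

    import torch
    from torch.distributions import Normal

    net = torch.nn.Linear(4, 2)    # toy net producing (loc, raw scale)
    state = torch.randn(4)
    loc, raw_scale = net(state)
    m = Normal(loc, torch.nn.functional.softplus(raw_scale))  # scale must be > 0

    action = m.rsample()           # differentiable w.r.t. loc and scale
    reward = -(action - 1.0) ** 2  # stands in for a differentiable reward
    loss = -reward
    loss.backward()
    # net.weight.grad is populated via the reparameterized sample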
"""

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .inverse_gamma import InverseGamma
from .kl import _add_kl_info, kl_divergence, register_kl
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from . import transforms
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart

_add_kl_info()
del _add_kl_info

__all__ = [
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "Distribution",
    "Exponential",
    "ExponentialFamily",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "InverseGamma",
    "Kumaraswamy",
    "LKJCholesky",
    "Laplace",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "Poisson",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
    "TransformedDistribution",
    "biject_to",
    "kl_divergence",
    "register_kl",
    "transform_to",
]
__all__.extend(transforms.__all__)