import warnings
from typing import Any, Dict, Optional, Tuple

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from torch.types import _size

__all__ = ["Distribution"]


class Distribution:
    r"""

    Distribution is the abstract base class for probability distributions.

    """

    has_rsample = False
    has_enumerate_support = False
    _validate_args = __debug__

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """

        Sets whether validation is enabled or disabled.



        The default behavior mimics Python's ``assert`` statement: validation

        is on by default, but is disabled if Python is run in optimized mode

        (via ``python -O``). Validation may be expensive, so you may want to

        disable it once a model is working.



        Args:

            value (bool): Whether to enable validation.

        """
        if value not in [True, False]:
            raise ValueError(f"Expected a boolean value, but got {value!r}")
        Distribution._validate_args = value
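
    # Usage sketch: validation can be toggled globally (via the static method
    # defined just above), e.g. disabled once a model is known to be working:
    #
    #   >>> from torch.distributions import Distribution
    #   >>> Distribution.set_default_validate_args(False)  # skip arg checks
    #   >>> Distribution.set_default_validate_args(True)   # re-enable them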

    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ):
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not valid.all():
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()
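
    # Validation sketch (assuming torch.distributions.Normal): constructing a
    # distribution with a parameter that violates its constraint raises at
    # __init__ time when validation is on; the message is elided here:
    #
    #   >>> from torch.distributions import Normal
    #   >>> Normal(0.0, -1.0)  # scale must be positive
    #   ValueError: Expected parameter scale ... to satisfy the constraint ...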

    def expand(self, batch_shape: torch.Size, _instance=None):
        """

        Returns a new distribution instance (or populates an existing instance

        provided by a derived class) with batch dimensions expanded to

        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on

        the distribution's parameters. As such, this does not allocate new

        memory for the expanded distribution instance. Additionally,

        this does not repeat any args checking or parameter broadcasting in

        `__init__.py`, when an instance is first created.



        Args:

            batch_shape (torch.Size): the desired expanded size.

            _instance: new instance provided by subclasses that

                need to override `.expand`.



        Returns:

            New distribution instance with batch dimensions expanded to

            `batch_size`.

        """
        raise NotImplementedError
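
    # Sketch with a concrete subclass (Normal overrides .expand): expanding
    # the batch dims reuses parameter storage rather than copying it.
    #
    #   >>> from torch.distributions import Normal
    #   >>> d = Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,)
    #   >>> d.expand(torch.Size([2, 3])).batch_shape
    #   torch.Size([2, 3])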

    @property
    def batch_shape(self) -> torch.Size:
        """

        Returns the shape over which parameters are batched.

        """
        return self._batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """

        Returns the shape of a single sample (without batching).

        """
        return self._event_shape
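
    # For example (assuming Normal and MultivariateNormal): three independent
    # univariate Normals have batch_shape (3,) and event_shape (), while a
    # single MultivariateNormal over R^3 has batch_shape () and event_shape (3,).
    #
    #   >>> from torch.distributions import Normal, MultivariateNormal
    #   >>> Normal(torch.zeros(3), torch.ones(3)).batch_shape
    #   torch.Size([3])
    #   >>> MultivariateNormal(torch.zeros(3), torch.eye(3)).event_shape
    #   torch.Size([3])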

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """

        Returns a dictionary from argument names to

        :class:`~torch.distributions.constraints.Constraint` objects that

        should be satisfied by each argument of this distribution. Args that

        are not tensors need not appear in this dict.

        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[Any]:
        """

        Returns a :class:`~torch.distributions.constraints.Constraint` object

        representing this distribution's support.

        """
        raise NotImplementedError

    @property
    def mean(self) -> torch.Tensor:
        """

        Returns the mean of the distribution.

        """
        raise NotImplementedError

    @property
    def mode(self) -> torch.Tensor:
        """

        Returns the mode of the distribution.

        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> torch.Tensor:
        """

        Returns the variance of the distribution.

        """
        raise NotImplementedError

    @property
    def stddev(self) -> torch.Tensor:
        """

        Returns the standard deviation of the distribution.

        """
        return self.variance.sqrt()

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """

        Generates a sample_shape shaped sample or sample_shape shaped batch of

        samples if the distribution parameters are batched.

        """
        with torch.no_grad():
            return self.rsample(sample_shape)
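
    # Shape sketch: samples have shape sample_shape + batch_shape + event_shape
    # (see _extended_shape below), e.g. for a batched Normal:
    #
    #   >>> from torch.distributions import Normal
    #   >>> Normal(torch.zeros(3), torch.ones(3)).sample((2,)).shape
    #   torch.Size([2, 3])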

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """

        Generates a sample_shape shaped reparameterized sample or sample_shape

        shaped batch of reparameterized samples if the distribution parameters

        are batched.

        """
        raise NotImplementedError

    def sample_n(self, n: int) -> torch.Tensor:
        """

        Generates n samples or n batches of samples if the distribution

        parameters are batched.

        """
        warnings.warn(
            "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning
        )
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """

        Returns the log of the probability density/mass function evaluated at

        `value`.



        Args:

            value (Tensor):

        """
        raise NotImplementedError
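
    # For example (assuming Normal): scoring a scalar value against a batched
    # distribution broadcasts over the batch dimensions.
    #
    #   >>> from torch.distributions import Normal
    #   >>> Normal(torch.zeros(3), torch.ones(3)).log_prob(torch.tensor(0.0)).shape
    #   torch.Size([3])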

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """

        Returns the cumulative density/mass function evaluated at

        `value`.



        Args:

            value (Tensor):

        """
        raise NotImplementedError

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """

        Returns the inverse cumulative density/mass function evaluated at

        `value`.



        Args:

            value (Tensor):

        """
        raise NotImplementedError
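
    # For subclasses that implement both, icdf inverts cdf. A sketch with the
    # standard Normal, whose median is 0:
    #
    #   >>> from torch.distributions import Normal
    #   >>> d = Normal(0.0, 1.0)
    #   >>> d.cdf(torch.tensor(0.0))
    #   tensor(0.5000)
    #   >>> d.icdf(torch.tensor(0.5))
    #   tensor(0.)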

    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
        """

        Returns tensor containing all values supported by a discrete

        distribution. The result will enumerate over dimension 0, so the shape

        of the result will be `(cardinality,) + batch_shape + event_shape`

        (where `event_shape = ()` for univariate distributions).



        Note that this enumerates over all batched tensors in lock-step

        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens

        along dim 0, but with the remaining batch dimensions being

        singleton dimensions, `[[0], [1], ..`.



        To iterate over the full Cartesian product use

        `itertools.product(m.enumerate_support())`.



        Args:

            expand (bool): whether to expand the support over the

                batch dims to match the distribution's `batch_shape`.



        Returns:

            Tensor iterating over dimension 0.

        """
        raise NotImplementedError
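
    # Example with a discrete subclass (Bernoulli implements this):
    #
    #   >>> from torch.distributions import Bernoulli
    #   >>> d = Bernoulli(torch.tensor([0.3, 0.7]))  # batch_shape (2,)
    #   >>> d.enumerate_support()                    # (cardinality,) + batch_shape
    #   tensor([[0., 0.],
    #           [1., 1.]])
    #   >>> d.enumerate_support(expand=False).shape  # singleton batch dims
    #   torch.Size([2, 1])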

    def entropy(self) -> torch.Tensor:
        """

        Returns entropy of distribution, batched over batch_shape.



        Returns:

            Tensor of shape batch_shape.

        """
        raise NotImplementedError

    def perplexity(self) -> torch.Tensor:
        """

        Returns perplexity of distribution, batched over batch_shape.



        Returns:

            Tensor of shape batch_shape.

        """
        return torch.exp(self.entropy())
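
    # Sanity check: a uniform Categorical over k outcomes has entropy log(k),
    # so its perplexity exp(log(k)) recovers k (up to floating-point error):
    #
    #   >>> from torch.distributions import Categorical
    #   >>> Categorical(torch.ones(4)).perplexity()
    #   tensor(4.0000)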

    def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]:
        """

        Returns the size of the sample returned by the distribution, given

        a `sample_shape`. Note, that the batch and event shapes of a distribution

        instance are fixed at the time of construction. If this is empty, the

        returned shape is upcast to (1,).



        Args:

            sample_shape (torch.Size): the size of the sample to be drawn.

        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return torch.Size(sample_shape + self._batch_shape + self._event_shape)
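
    # Shape arithmetic sketch: for a distribution with batch_shape (3,) and
    # event_shape (4,), _extended_shape((2,)) returns
    # torch.Size([2, 3, 4]) == (2,) + (3,) + (4,).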

    def _validate_sample(self, value: torch.Tensor) -> None:
        """

        Argument validation for distribution methods such as `log_prob`,

        `cdf` and `icdf`. The rightmost dimensions of a value to be

        scored via these methods must agree with the distribution's batch

        and event shapes.



        Args:

            value (Tensor): the tensor whose log probability is to be

                computed by the `log_prob` method.

        Raises

            ValueError: when the rightmost dimensions of `value` do not match the

                distribution's batch and event shapes.

        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not valid.all():
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ", ".join(
            [
                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
                for p in param_names
            ]
        )
        return self.__class__.__name__ + "(" + args_string + ")"
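

# Repr sketch (assuming Normal): scalar parameters print their value, batched
# parameters print their shape, per the numel() check above.
#
#   >>> from torch.distributions import Normal
#   >>> Normal(0.0, 1.0)
#   Normal(loc: 0.0, scale: 1.0)
#   >>> Normal(torch.zeros(3), torch.ones(3))
#   Normal(loc: torch.Size([3]), scale: torch.Size([3]))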