import warnings
from typing import Any, Dict, Optional, Tuple

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from torch.types import _size

__all__ = ["Distribution"]

class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.
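
    Example (illustrative; uses the concrete :class:`~torch.distributions.normal.Normal`
    subclass, since ``Distribution`` itself is abstract)::

        >>> from torch.distributions import Normal
        >>> m = Normal(torch.tensor(0.0), torch.tensor(1.0))
        >>> x = m.sample()  # draws a single (random) sample
        >>> m.log_prob(torch.tensor(0.0))
        tensor(-0.9189)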
""" | |
has_rsample = False | |
has_enumerate_support = False | |
_validate_args = __debug__ | |
    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """
        Sets whether validation is enabled or disabled.

        The default behavior mimics Python's ``assert`` statement: validation
        is on by default, but is disabled if Python is run in optimized mode
        (via ``python -O``). Validation may be expensive, so you may want to
        disable it once a model is working.

        Args:
            value (bool): Whether to enable validation.
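
        Example (disables validation globally for all distributions)::

            >>> torch.distributions.Distribution.set_default_validate_args(False)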
""" | |
if value not in [True, False]: | |
raise ValueError | |
Distribution._validate_args = value | |
    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ):
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not valid.all():
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()

    def expand(self, batch_shape: torch.Size, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting done
        in `__init__` when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
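
        Example (illustrative; assumes the concrete :class:`~torch.distributions.normal.Normal`
        subclass, which implements `.expand`)::

            >>> from torch.distributions import Normal
            >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
            >>> m.expand(torch.Size([3, 2])).batch_shape
            torch.Size([3, 2])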
""" | |
raise NotImplementedError | |
    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape over which parameters are batched.
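
        Example (illustrative, with :class:`~torch.distributions.normal.Normal`,
        whose events are scalars, so `event_shape` is empty)::

            >>> from torch.distributions import Normal
            >>> m = Normal(torch.zeros(3, 2), torch.ones(3, 2))
            >>> m.batch_shape
            torch.Size([3, 2])
            >>> m.event_shape
            torch.Size([])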
""" | |
return self._batch_shape | |
def event_shape(self) -> torch.Size: | |
""" | |
Returns the shape of a single sample (without batching). | |
""" | |
return self._event_shape | |
    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[Any]:
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.
        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> torch.Tensor:
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self) -> torch.Tensor:
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.
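
        Example (illustrative; only shapes are shown, since the drawn values
        are random)::

            >>> from torch.distributions import Normal
            >>> m = Normal(torch.tensor(0.0), torch.tensor(1.0))
            >>> m.sample((5,)).shape
            torch.Size([5])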
""" | |
with torch.no_grad(): | |
return self.rsample(sample_shape) | |
    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    def sample_n(self, n: int) -> torch.Tensor:
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        warnings.warn(
            "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning
        )
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor): the value at which to evaluate the log probability.
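
        Example (illustrative; the standard normal log-density at zero is
        ``-0.5 * log(2 * pi)``, about ``-0.9189``)::

            >>> from torch.distributions import Normal
            >>> Normal(torch.tensor(0.0), torch.tensor(1.0)).log_prob(torch.tensor(0.0))
            tensor(-0.9189)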
""" | |
raise NotImplementedError | |
    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the cumulative distribution function evaluated at
        `value`.

        Args:
            value (Tensor): the value at which to evaluate the CDF.
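
        Example (illustrative; the standard normal CDF at zero is exactly 0.5)::

            >>> from torch.distributions import Normal
            >>> Normal(torch.tensor(0.0), torch.tensor(1.0)).cdf(torch.tensor(0.0))
            tensor(0.5000)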
""" | |
raise NotImplementedError | |
def icdf(self, value: torch.Tensor) -> torch.Tensor: | |
""" | |
Returns the inverse cumulative density/mass function evaluated at | |
`value`. | |
Args: | |
value (Tensor): | |
""" | |
raise NotImplementedError | |
    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
        """
        Returns a tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ...]`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
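
        Example (illustrative; :class:`~torch.distributions.categorical.Categorical`
        is a discrete distribution that implements this method)::

            >>> from torch.distributions import Categorical
            >>> m = Categorical(torch.ones(3))
            >>> m.enumerate_support()
            tensor([0, 1, 2])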
""" | |
raise NotImplementedError | |
    def entropy(self) -> torch.Tensor:
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
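
        Example (illustrative; the standard normal entropy is
        ``0.5 * log(2 * pi * e)``, about ``1.4189``)::

            >>> from torch.distributions import Normal
            >>> Normal(torch.tensor(0.0), torch.tensor(1.0)).entropy()
            tensor(1.4189)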
""" | |
raise NotImplementedError | |
def perplexity(self) -> torch.Tensor: | |
""" | |
Returns perplexity of distribution, batched over batch_shape. | |
Returns: | |
Tensor of shape batch_shape. | |
""" | |
return torch.exp(self.entropy()) | |
    def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`: the concatenation `sample_shape + batch_shape +
        event_shape`. Note that the batch and event shapes of a distribution
        instance are fixed at the time of construction.

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
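
        Example (illustrative; sample, batch and event shapes concatenate
        left to right)::

            >>> from torch.distributions import Normal
            >>> m = Normal(torch.zeros(3, 2), torch.ones(3, 2))
            >>> m._extended_shape(torch.Size((5,)))
            torch.Size([5, 3, 2])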
""" | |
if not isinstance(sample_shape, torch.Size): | |
sample_shape = torch.Size(sample_shape) | |
return torch.Size(sample_shape + self._batch_shape + self._event_shape) | |
    def _validate_sample(self, value: torch.Tensor) -> None:
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when the rightmost dimensions of `value` do not match the
                distribution's batch and event shapes.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not valid.all():
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        # Subclasses that define a custom __init__ must also override .expand,
        # since the generic expansion path cannot re-run their initialization.
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        # Show scalar parameter values directly; summarize larger tensors by their size.
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ", ".join(
            [
                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
                for p in param_names
            ]
        )
        return self.__class__.__name__ + "(" + args_string + ")"