Spaces:
Sleeping
Sleeping
""" | |
Joint Random Variables Module

See Also
========

sympy.stats.rv
sympy.stats.frv
sympy.stats.crv
sympy.stats.drv
""" | |
from math import prod | |
from sympy.core.basic import Basic | |
from sympy.core.function import Lambda | |
from sympy.core.singleton import S | |
from sympy.core.symbol import (Dummy, Symbol) | |
from sympy.core.sympify import sympify | |
from sympy.sets.sets import ProductSet | |
from sympy.tensor.indexed import Indexed | |
from sympy.concrete.products import Product | |
from sympy.concrete.summations import Sum, summation | |
from sympy.core.containers import Tuple | |
from sympy.integrals.integrals import Integral, integrate | |
from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy | |
from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace | |
from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace | |
from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution, | |
ProductDomain, RandomSymbol, random_symbols, | |
SingleDomain, _symbol_converter) | |
from sympy.utilities.iterables import iterable | |
from sympy.utilities.misc import filldedent | |
from sympy.external import import_module | |
# __all__ = ['marginal_distribution'] | |
class JointPSpace(ProductPSpace):
    """
    Represents a joint probability space. Represented using symbols for
    each component and a distribution.

    The instance stores ``(symbol, distribution)`` as its ``args``.
    """

    def __new__(cls, sym, dist):
        # Univariate distributions are handed to their dedicated
        # single-variable probability spaces instead of a joint space.
        if isinstance(dist, SingleContinuousDistribution):
            return SingleContinuousPSpace(sym, dist)
        if isinstance(dist, SingleDiscreteDistribution):
            return SingleDiscretePSpace(sym, dist)
        sym = _symbol_converter(sym)
        return Basic.__new__(cls, sym, dist)

    # NOTE: the accessors below are properties; the rest of this class
    # reads them as attributes (e.g. ``self.domain.set``,
    # ``range(self.component_count)``, ``self.value[i]``).

    @property
    def set(self):
        """The set of this space's domain."""
        return self.domain.set

    @property
    def symbol(self):
        """The symbol naming the joint random variable (``args[0]``)."""
        return self.args[0]

    @property
    def distribution(self):
        """The distribution of the space (``args[1]``)."""
        return self.args[1]

    @property
    def value(self):
        """The ``JointRandomSymbol`` built from this space's symbol."""
        return JointRandomSymbol(self.symbol, self)

    @property
    def component_count(self):
        """
        Number of components of the joint random variable.

        Derived from the distribution's set: the number of factors of a
        ``ProductSet``, the last entry of the first limit of a ``Product``,
        and ``1`` otherwise.
        """
        _set = self.distribution.set
        if isinstance(_set, ProductSet):
            return S(len(_set.args))
        elif isinstance(_set, Product):
            return _set.limits[0][-1]
        return S.One

    @property
    def pdf(self):
        """The distribution evaluated at the indexed components of the symbol."""
        sym = [Indexed(self.symbol, i) for i in range(self.component_count)]
        return self.distribution(*sym)

    @property
    def domain(self):
        """
        Domain of the space.

        A ``SingleDomain`` over the distribution's set when the distribution
        contains no random symbols, otherwise the ``ProductDomain`` of the
        domains of those random symbols.
        """
        rvs = random_symbols(self.distribution)
        if not rvs:
            return SingleDomain(self.symbol, self.distribution.set)
        return ProductDomain(*[rv.pspace.domain for rv in rvs])

    def component_domain(self, index):
        """Return the ``index``-th argument of the space's set."""
        return self.set.args[index]

    def marginal_distribution(self, *indices):
        """
        Marginalise out every component whose index is not in ``indices``.

        Parameters
        ==========

        indices : int
            Indices of the components to keep.

        Returns
        =======

        Lambda
            Function of the kept components, obtained by integrating
            (continuous) or summing (discrete) the distribution over the
            sets of the remaining components.

        Raises
        ======

        ValueError
            If the component count contains a ``Symbol``.
        NotImplementedError
            If the distribution is flagged neither continuous nor discrete.
        """
        count = self.component_count
        if count.atoms(Symbol):
            raise ValueError("Marginal distributions cannot be computed "
                             "for symbolic dimensions. It is a work under progress.")
        orig = [Indexed(self.symbol, i) for i in range(count)]
        # Plain symbols stand in for the indexed components while
        # integrating/summing; they are mapped back at the end.
        all_syms = [Symbol(str(i)) for i in orig]
        replace_dict = dict(zip(all_syms, orig))
        sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices)
        # One (symbol, set) limit for every component that is marginalised out.
        component_sets = self.distribution.set.args
        limits = [(all_syms[i], component_sets[i])
                  for i in range(count) if i not in indices]
        if self.distribution.is_Continuous:
            f = Lambda(sym, integrate(self.distribution(*all_syms), *limits))
        elif self.distribution.is_Discrete:
            f = Lambda(sym, summation(self.distribution(*all_syms), *limits))
        else:
            # Previously fell through to an UnboundLocalError on ``f``.
            raise NotImplementedError(
                "Marginal distributions are only implemented for "
                "continuous and discrete joint distributions.")
        return f.xreplace(replace_dict)

    def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
        """
        Build the expectation of ``expr`` over this space.

        ``expr`` is returned unchanged when none of the indexed components
        of this space appear in ``rvs``; otherwise an unevaluated
        ``Integral`` of ``expr*pdf`` over every component's set is returned.
        ``evaluate`` and ``kwargs`` are accepted but not used here.

        Raises
        ======

        NotImplementedError
            If the unindexed joint random symbol is still present in the
            expression after the substitutions.
        """
        syms = tuple(self.value[i] for i in range(self.component_count))
        rvs = rvs or syms
        if not any(i in rvs for i in syms):
            return expr
        expr = expr*self.pdf
        # Swap random symbols for plain integration variables.
        for rv in rvs:
            if isinstance(rv, Indexed):
                expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])})
            elif isinstance(rv, RandomSymbol):
                expr = expr.xreplace({rv: rv.symbol})
        if self.value in random_symbols(expr):
            raise NotImplementedError(filldedent(
                '\n'
                'Expectations of expression with unindexed joint random symbols\n'
                'cannot be calculated yet.'))
        limits = tuple((Indexed(str(rv.base), rv.args[1]),
                        self.distribution.set.args[rv.args[1]]) for rv in syms)
        return Integral(expr, *limits)

    def where(self, condition):
        """Not implemented for joint probability spaces."""
        raise NotImplementedError()

    def compute_density(self, expr):
        """Not implemented for joint probability spaces."""
        raise NotImplementedError()

    def sample(self, size=(), library='scipy', seed=None):
        """
        Internal sample method

        Returns dictionary mapping RandomSymbol to realization value.
        """
        return {RandomSymbol(self.symbol, self): self.distribution.sample(
            size, library=library, seed=seed)}

    def probability(self, condition):
        """Not implemented for joint probability spaces."""
        raise NotImplementedError()
class SampleJointScipy:
    """Returns the sample from scipy of the given distribution"""

    def __new__(cls, dist, size, seed=None):
        return cls._sample_scipy(dist, size, seed)

    # Must be a classmethod: ``__new__`` invokes it on the class with three
    # positional arguments.
    @classmethod
    def _sample_scipy(cls, dist, size, seed):
        """
        Sample from SciPy.

        Parameters
        ==========

        dist : distribution object
            Dispatched on ``dist.__class__.__name__``.
        size : tuple
            Leading shape of the result; concatenated with the shape of a
            single draw.
        seed : None, int or random-state object
            ``None``/``int`` seed a fresh ``numpy.random.default_rng``; any
            other object is passed to SciPy as ``random_state`` unchanged.

        Returns
        =======

        Array of shape ``size + (number of components,)``, or ``None`` when
        the distribution has no SciPy sampler registered here.
        """
        import numpy
        if seed is None or isinstance(seed, int):
            rand_state = numpy.random.default_rng(seed=seed)
        else:
            rand_state = seed
        from scipy import stats as scipy_stats
        scipy_rv_map = {
            'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs(
                mean=matrix2numpy(dist.mu).flatten(),
                cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state),
            'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs(
                alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state),
            'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs(
                n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state)
        }

        # Shape of a single draw for each supported distribution.
        sample_shape = {
            'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
            'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
            'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
        }

        dist_name = dist.__class__.__name__
        if dist_name not in scipy_rv_map:
            return None

        samples = scipy_rv_map[dist_name](dist, size)
        return samples.reshape(size + sample_shape[dist_name](dist))
class SampleJointNumpy:
    """Returns the sample from numpy of the given distribution"""

    def __new__(cls, dist, size, seed=None):
        return cls._sample_numpy(dist, size, seed)

    # Must be a classmethod: ``__new__`` invokes it on the class with three
    # positional arguments.
    @classmethod
    def _sample_numpy(cls, dist, size, seed):
        """
        Sample from NumPy.

        Parameters
        ==========

        dist : distribution object
            Dispatched on ``dist.__class__.__name__``.
        size : tuple
            Leading shape of the result; ``prod(size)`` draws are taken and
            then reshaped.
        seed : None, int or generator object
            ``None``/``int`` seed a fresh ``numpy.random.default_rng``; any
            other object is used directly as the generator.

        Returns
        =======

        Array of shape ``size + (number of components,)``, or ``None`` when
        the distribution has no NumPy sampler registered here.
        """
        import numpy
        if seed is None or isinstance(seed, int):
            rand_state = numpy.random.default_rng(seed=seed)
        else:
            rand_state = seed
        numpy_rv_map = {
            'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal(
                mean=matrix2numpy(dist.mu, float).flatten(),
                cov=matrix2numpy(dist.sigma, float), size=size),
            'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet(
                alpha=list2numpy(dist.alpha, float).flatten(), size=size),
            'MultinomialDistribution': lambda dist, size: rand_state.multinomial(
                n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size)
        }

        # Shape of a single draw for each supported distribution.
        sample_shape = {
            'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
            'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
            'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
        }

        dist_name = dist.__class__.__name__
        if dist_name not in numpy_rv_map:
            return None

        # Draw a flat batch, then restore the requested leading shape.
        samples = numpy_rv_map[dist_name](dist, prod(size))
        return samples.reshape(size + sample_shape[dist_name](dist))
class SampleJointPymc:
    """Returns the sample from pymc of the given distribution"""

    def __new__(cls, dist, size, seed=None):
        return cls._sample_pymc(dist, size, seed)

    # Must be a classmethod: ``__new__`` invokes it on the class with three
    # positional arguments.
    @classmethod
    def _sample_pymc(cls, dist, size, seed):
        """
        Sample from PyMC.

        Parameters
        ==========

        dist : distribution object
            Dispatched on ``dist.__class__.__name__``.
        size : tuple
            Leading shape of the result; ``prod(size)`` draws are taken and
            then reshaped.
        seed :
            Forwarded to ``pymc.sample`` as ``random_seed``.

        Returns
        =======

        Array of shape ``size + (number of components,)``, or ``None`` when
        the distribution has no PyMC sampler registered here. Unsupported
        distributions are rejected before PyMC is imported.
        """
        # Shape of a single draw for each supported distribution; the keys
        # double as the list of supported distribution names.
        sample_shape = {
            'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
            'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
            'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
        }

        dist_name = dist.__class__.__name__
        if dist_name not in sample_shape:
            return None

        # Prefer the current ``pymc`` package and fall back to ``pymc3``.
        try:
            import pymc
        except ImportError:
            import pymc3 as pymc
        pymc_rv_map = {
            'MultivariateNormalDistribution': lambda dist:
                pymc.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(),
                cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])),
            'MultivariateBetaDistribution': lambda dist:
                pymc.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()),
            'MultinomialDistribution': lambda dist:
                pymc.Multinomial('X', n=int(dist.n),
                p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p)))
        }

        import logging
        logging.getLogger("pymc3").setLevel(logging.ERROR)
        with pymc.Model():
            # Registers the variable named 'X' on the active model.
            pymc_rv_map[dist_name](dist)
            samples = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X']
        return samples.reshape(size + sample_shape[dist_name](dist))
# Maps the ``library`` name accepted by ``JointDistribution.sample`` to the
# sampler class used for it; both PyMC spellings share one sampler.
_get_sample_class_jrv = {
    'scipy': SampleJointScipy,
    'pymc3': SampleJointPymc,
    'pymc': SampleJointPymc,
    'numpy': SampleJointNumpy,
}
class JointDistribution(Distribution, NamedArgsMixin):
    """
    Represented by the random variables part of the joint distribution.
    Contains methods for PDF, CDF, sampling, marginal densities, etc.
    """

    _argnames = ('pdf', )

    def __new__(cls, *args):
        # Sympify every argument; plain lists are stored as immutable
        # matrices.
        args = [ImmutableMatrix(arg) if isinstance(arg, list) else arg
                for arg in map(sympify, args)]
        return Basic.__new__(cls, *args)

    @property
    def domain(self):
        """``ProductDomain`` built from ``self.symbols``."""
        # NOTE(review): ``symbols`` is not defined in this class; presumably
        # supplied by a subclass or base class - confirm.
        return ProductDomain(self.symbols)

    @property
    def pdf(self):
        """Second argument of ``self.density``."""
        # NOTE(review): ``density`` is not defined in this class; presumably
        # supplied by a subclass or base class - confirm.
        return self.density.args[1]

    def cdf(self, other):
        """
        Build the unevaluated cumulative distribution expression.

        Parameters
        ==========

        other : dict
            Maps each random variable to the upper limit at which the CDF
            is requested. The i-th key is paired with the i-th set of the
            domain, so the insertion order is assumed to follow the order
            of the components - TODO confirm against callers.

        Returns
        =======

        Nested ``Integral``/``Sum`` of the pdf, one level per entry of
        ``other``, from the infimum of the variable's set up to the
        requested value.

        Raises
        ======

        ValueError
            If ``other`` is not a ``dict``.
        """
        if not isinstance(other, dict):
            raise ValueError("%s should be of type dict, got %s"%(other, type(other)))
        # ``dict.keys()`` is a view and cannot be indexed; materialise it.
        rvs = list(other.keys())
        _set = self.domain.set.sets
        # NOTE(review): the pdf receives a single tuple here rather than
        # unpacked arguments - kept as in the original.
        expr = self.pdf(tuple(i.args[0] for i in self.symbols))
        # Accumulate so that every variable is integrated/summed out; the
        # earlier code rebuilt ``density`` from the bare pdf on each pass,
        # which kept only the last variable.
        density = expr
        for i, rv in enumerate(rvs):
            if rv.is_Continuous:
                density = Integral(density, (rv, _set[i].inf, other[rv]))
            elif rv.is_Discrete:
                density = Sum(density, (rv, _set[i].inf, other[rv]))
        return density

    def sample(self, size=(), library='scipy', seed=None):
        """
        A random realization from the distribution

        Parameters
        ==========

        size : tuple
            Forwarded to the library-specific sampler.
        library : str
            One of ``'scipy'``, ``'numpy'``, ``'pymc3'`` or ``'pymc'``.
        seed :
            Forwarded to the library-specific sampler.

        Raises
        ======

        NotImplementedError
            If ``library`` is not supported, or the sampler returns ``None``
            for this distribution.
        ValueError
            If ``library`` cannot be imported.
        """
        libraries = ('scipy', 'numpy', 'pymc3', 'pymc')
        if library not in libraries:
            raise NotImplementedError("Sampling from %s is not supported yet."
                                      % str(library))
        if not import_module(library):
            raise ValueError("Failed to import %s" % library)

        samps = _get_sample_class_jrv[library](self, size, seed=seed)

        if samps is not None:
            return samps
        raise NotImplementedError(
            "Sampling for %s is not currently implemented from %s"
            % (self.__class__.__name__, library)
        )

    def __call__(self, *args):
        """Evaluate the pdf at ``args``."""
        return self.pdf(*args)
class JointRandomSymbol(RandomSymbol):
    """
    Representation of random symbols with joint probability distributions
    to allow indexing.
    """

    def __getitem__(self, key):
        """
        Return ``Indexed(self, key)`` for a symbol living in a ``JointPSpace``.

        Returns ``None`` when the symbol's probability space is not a
        ``JointPSpace``. Raises ``ValueError`` when ``key`` is provably not
        below the number of components of the space.
        """
        if not isinstance(self.pspace, JointPSpace):
            return None
        # The comparison may stay symbolic; only a definite ``True`` is an
        # out-of-range key.
        out_of_range = (self.pspace.component_count <= key) == True
        if out_of_range:
            raise ValueError("Index keys for %s can only up to %s." %
                             (self.name, self.pspace.component_count - 1))
        return Indexed(self, key)
class MarginalDistribution(Distribution):
    """
    Represents the marginal distribution of a joint probability space.

    Initialised using a probability distribution and random variables(or
    their indexed components) which should be a part of the resultant
    distribution.

    The instance stores ``(dist, rvs)`` as its ``args``, with ``rvs`` held
    in a ``Tuple``.
    """

    def __new__(cls, dist, *rvs):
        # Accept both ``MarginalDistribution(d, a, b)`` and
        # ``MarginalDistribution(d, [a, b])``.
        if len(rvs) == 1 and iterable(rvs[0]):
            rvs = tuple(rvs[0])
        if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs):
            raise ValueError(filldedent(
                'Marginal distribution can be\n'
                'intitialised only in terms of random variables or indexed random\n'
                'variables'))
        rvs = Tuple.fromiter(rvs)
        # Nothing to marginalise: a non-joint distribution without random
        # symbols is returned unchanged.
        if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0:
            return dist
        return Basic.__new__(cls, dist, rvs)

    def check(self):
        """No argument validation is performed."""
        pass

    @property
    def set(self):
        """``ProductSet`` of the sets of the kept ``RandomSymbol`` arguments."""
        rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)]
        return ProductSet(*[rv.pspace.set for rv in rvs])

    @property
    def symbols(self):
        """Set of the probability-space symbols of the kept random variables."""
        rvs = self.args[1]
        return {rv.pspace.symbol for rv in rvs}

    def pdf(self, *x):
        """
        Evaluate the marginal density at ``x``.

        Every random symbol of the underlying expression that is not among
        the kept random variables is marginalised out; the result is
        wrapped in a ``Lambda`` over the remaining symbols and applied to
        ``x``.
        """
        expr, rvs = self.args[0], self.args[1]
        marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
        if isinstance(expr, JointDistribution):
            count = len(expr.domain.args)
            # A separate name keeps the ``x`` arguments intact for the final
            # call; ``count`` is an int, so iterate over ``range(count)``.
            base = Dummy('x', real=True)
            syms = tuple(Indexed(base, i) for i in range(count))
            # NOTE(review): the pdf receives a single tuple here rather than
            # unpacked arguments - kept as in the original.
            expr = expr.pdf(syms)
        else:
            syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs)
        return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)

    def compute_pdf(self, expr, rvs):
        """
        Marginalise every entry of ``rvs`` out of ``expr`` in turn.

        For a ``RandomSymbol`` the expression is first multiplied by the
        pdf of its probability space.
        """
        for rv in rvs:
            lpdf = rv.pspace.pdf if isinstance(rv, RandomSymbol) else 1
            expr = self.marginalise_out(expr*lpdf, rv)
        return expr

    def marginalise_out(self, expr, rv):
        """
        Integrate or sum ``rv`` out of ``expr`` (unevaluated).

        ``rv`` is replaced by its probability-space symbol, then wrapped in
        an ``Integral`` (continuous space) or ``Sum`` (discrete space) over
        the variable's domain.
        """
        if isinstance(rv, RandomSymbol):
            dom = rv.pspace.set
        elif isinstance(rv, Indexed):
            # NOTE(review): nested ``component_domain`` call kept verbatim;
            # it looks unintended - confirm the intended lookup.
            dom = rv.base.component_domain(
                rv.pspace.component_domain(rv.args[1]))
        expr = expr.xreplace({rv: rv.pspace.symbol})
        if rv.pspace.is_Continuous:
            #TODO: Modify to support integration
            #for all kinds of sets.
            expr = Integral(expr, (rv.pspace.symbol, dom))
        elif rv.pspace.is_Discrete:
            #incorporate this into `Sum`/`summation`
            if dom in (S.Integers, S.Naturals, S.Naturals0):
                dom = (dom.inf, dom.sup)
            expr = Sum(expr, (rv.pspace.symbol, dom))
        return expr

    def __call__(self, *args):
        """Evaluate the marginal pdf at ``args``."""
        return self.pdf(*args)