Spaces:
Sleeping
Sleeping
from sympy.concrete.summations import Sum | |
from sympy.core.basic import Basic | |
from sympy.core.containers import Tuple | |
from sympy.core.function import Lambda | |
from sympy.core.numbers import (Rational, nan, oo, pi) | |
from sympy.core.relational import Eq | |
from sympy.core.singleton import S | |
from sympy.core.symbol import (Symbol, symbols) | |
from sympy.functions.combinatorial.factorials import (FallingFactorial, binomial) | |
from sympy.functions.elementary.exponential import (exp, log) | |
from sympy.functions.elementary.trigonometric import (cos, sin) | |
from sympy.functions.special.delta_functions import DiracDelta | |
from sympy.integrals.integrals import integrate | |
from sympy.logic.boolalg import (And, Or) | |
from sympy.matrices.dense import Matrix | |
from sympy.sets.sets import Interval | |
from sympy.tensor.indexed import Indexed | |
from sympy.stats import (Die, Normal, Exponential, FiniteRV, P, E, H, variance, | |
density, given, independent, dependent, where, pspace, GaussianUnitaryEnsemble, | |
random_symbols, sample, Geometric, factorial_moment, Binomial, Hypergeometric, | |
DiscreteUniform, Poisson, characteristic_function, moment_generating_function, | |
BernoulliProcess, Variance, Expectation, Probability, Covariance, covariance, cmoment, | |
moment, median) | |
from sympy.stats.rv import (IndependentProductPSpace, rs_swap, Density, NamedArgsMixin, | |
RandomSymbol, sample_iter, PSpace, is_random, RandomIndexedSymbol, RandomMatrixSymbol) | |
from sympy.testing.pytest import raises, skip, XFAIL, warns_deprecated_sympy | |
from sympy.external import import_module | |
from sympy.core.numbers import comp | |
from sympy.stats.frv_types import BernoulliDistribution | |
from sympy.core.symbol import Dummy | |
from sympy.functions.elementary.piecewise import Piecewise | |
def test_where():
    """Check the domain returned by ``where`` for finite and continuous RVs."""
    die_x, die_y = Die('X'), Die('Y')
    z = Normal('Z', 0, 1)

    # Z**2 <= 1 restricts a standard normal to the interval [-1, 1].
    assert where(z**2 <= 1).set == Interval(-1, 1)
    assert where(z**2 <= 1).as_boolean() == \
        Interval(-1, 1).as_relational(z.symbol)

    # Two dice with X > Y and Y > 4 leave only the outcome X=6, Y=5.
    only_outcome = And(Eq(die_x.symbol, 6), Eq(die_y.symbol, 5))
    assert where(And(die_x > die_y, die_y > 4)).as_boolean() == only_outcome

    # X < 3 keeps two faces of the die, one of which is 1.
    assert len(where(die_x < 3).set) == 2
    assert 1 in where(die_x < 3).set

    nx, _ny = Normal('X', 0, 1), Normal('Y', 0, 1)
    half_unit = And(nx**2 <= 1, nx >= 0)
    assert where(half_unit).set == Interval(0, 1)

    conditioned = given(nx, half_unit)
    assert conditioned.pspace.domain.set == Interval(0, 1)
    assert conditioned.pspace.domain.as_boolean() == And(
        0 <= nx.symbol, nx.symbol**2 <= 1, -oo < nx.symbol, nx.symbol < oo)

    # A non-boolean condition is rejected.
    raises(TypeError, lambda: given(nx, nx + 3))
def test_random_symbols():
    """``random_symbols`` collects only the RandomSymbols of an expression."""
    a, b = Normal('X', 0, 1), Normal('Y', 0, 1)
    cases = [
        (2*a + 1, {a}),
        (2*a + b, {a, b}),
        # b.symbol is a plain Symbol, so it must not be reported.
        (2*a + b.symbol, {a}),
        (2, set()),
    ]
    for expr, expected in cases:
        assert set(random_symbols(expr)) == expected
def test_characteristic_function():
    """Characteristic functions of normal, discrete-uniform and Poisson RVs."""
    # Imports I from sympy
    from sympy.core.numbers import I

    normal = Normal('X', 0, 1)
    uniform = DiscreteUniform('Y', [1, 2, 7])
    poisson = Poisson('Z', 2)

    t = symbols('_t')
    # Pairs of (random variable, expected characteristic function). The
    # comparison uses dummy_eq so the identity of the bound variable is ignored.
    expected = [
        (normal, Lambda(t, exp(-t**2/2))),
        (uniform, Lambda(t, exp(7*t*I)/3 + exp(2*t*I)/3 + exp(t*I)/3)),
        (poisson, Lambda(t, exp(2 * exp(t*I) - 2))),
    ]
    for rv, cf in expected:
        assert characteristic_function(rv).dummy_eq(cf)
def test_moment_generating_function():
    """Moment generating functions of normal, discrete-uniform and Poisson RVs."""
    normal = Normal('X', 0, 1)
    uniform = DiscreteUniform('Y', [1, 2, 7])
    poisson = Poisson('Z', 2)

    t = symbols('_t')
    # Pairs of (random variable, expected MGF). The comparison uses dummy_eq
    # so the identity of the bound variable is ignored.
    expected = [
        (normal, Lambda(t, exp(t**2/2))),
        (uniform, Lambda(t, (exp(7*t)/3 + exp(2*t)/3 + exp(t)/3))),
        (poisson, Lambda(t, exp(2 * exp(t) - 2))),
    ]
    for rv, mgf in expected:
        assert moment_generating_function(rv).dummy_eq(mgf)
def test_sample_iter():
    """``sample_iter`` returns a genuine iterator for several kinds of RV."""
    from collections.abc import Iterator

    X = Normal('X', 0, 1)
    Y = DiscreteUniform('Y', [1, 2, 7])
    Z = Poisson('Z', 2)

    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    expr = X**2 + 3
    iterator = sample_iter(expr)

    expr2 = Y**2 + 5*Y + 4
    iterator2 = sample_iter(expr2)

    expr3 = Z**3 + 4
    iterator3 = sample_iter(expr3)

    def is_iterator(obj):
        # Use the stdlib protocol check (``__iter__`` + ``__next__``) rather
        # than a hand-rolled hasattr test. The old Python 2 ``next`` branch is
        # dropped. An iterator must also return itself from ``iter``.
        return isinstance(obj, Iterator) and iter(obj) is obj

    assert is_iterator(iterator)
    assert is_iterator(iterator2)
    assert is_iterator(iterator3)
def test_pspace():
    """``pspace`` of single and combined RVs; non-random input is rejected."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    x = Symbol('x')

    # Neither a number nor a relation on a plain Symbol has a probability space.
    raises(ValueError, lambda: pspace(5 + 3))
    raises(ValueError, lambda: pspace(x < 1))

    assert pspace(X) == X.pspace
    assert pspace(2*X + 1) == X.pspace

    # Two independent RVs combine into a product space.
    expected_product = IndependentProductPSpace(Y.pspace, X.pspace)
    assert pspace(2*X + Y) == expected_product
def test_rs_swap():
    """``rs_swap`` builds a substitution from old to new random symbols."""
    old_x = Normal('x', 0, 1)
    old_y = Exponential('y', 1)
    new_x = Normal('x', 0, 2)
    new_y = Normal('y', 0, 3)

    # The replacements are passed in the opposite order. The expected result
    # shows the pairing follows the symbol names, not the argument positions.
    mapping = rs_swap((old_x, old_y), (new_y, new_x))
    assert (2*old_x + old_y).subs(mapping) == 2*new_x + new_y
def test_RandomSymbol():
    """RVs with the same name share a symbol but are distinct objects."""
    first = Normal('x', 0, 1)
    second = Normal('x', 0, 2)

    assert first.symbol == second.symbol
    assert first != second
    assert first.name == first.symbol.name

    # Construction must work for a Python keyword and for a SymPy class name.
    Normal('lambda', 0, 1)
    Normal('Lambda', 0, 1)
def test_RandomSymbol_diff():
    """Differentiating with respect to a RandomSymbol yields a truthy result."""
    rv = Normal('x', 0, 1)
    derivative = (2*rv).diff(rv)
    assert derivative
def test_random_symbol_no_pspace():
    """A RandomSymbol created without a space reports a bare ``PSpace()``."""
    bare = RandomSymbol(Symbol('x'))
    assert bare.pspace == PSpace()
def test_overlap():
    """Two different RVs sharing the name ``x`` cannot be combined in ``P``."""
    narrow = Normal('x', 0, 1)
    wide = Normal('x', 0, 2)
    raises(ValueError, lambda: P(narrow > wide))
def test_IndependentProductPSpace():
    """The product space of independent RVs ignores the order of its factors."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    for spaces in ((X.pspace, Y.pspace), (Y.pspace, X.pspace)):
        assert pspace(X + Y) == IndependentProductPSpace(*spaces)
def test_E():
    """The expectation of a constant is the constant itself."""
    constant = 5
    assert E(constant) == constant
def test_H():
    """Entropy of normal, die and geometric RVs, with and without conditions."""
    normal = Normal('X', 0, 1)
    die = Die('D', sides=4)
    geometric = Geometric('G', 0.5)

    # Entropy of a standard normal conditioned on X > 0.
    assert H(normal, normal > 0) == log(pi)/2 + S.Half - log(2)/2
    # D > 2 on a four-sided die leaves two equally likely faces.
    assert H(die, die > 2) == log(2)
    # Geometric entropy is only checked numerically, to two decimals.
    assert comp(H(geometric).evalf().round(2), 1.39)
def test_Sample():
    """Sampling-based sample / P / E / variance / density evaluation."""
    X = Die('X', 6)
    Y = Normal('Y', 0, 1)
    z = Symbol('z', integer=True)

    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    assert sample(X) in [1, 2, 3, 4, 5, 6]
    assert isinstance(sample(X + Y), float)

    # Monte Carlo estimates with very few samples: only check they are numbers.
    assert P(X + Y > 0, Y < 0, numsamples=10).is_number
    for expr in (X + Y, X**2 + Y, (X + Y)**2):
        assert E(expr, numsamples=10).is_number
    assert variance(X + Y, numsamples=10).is_number

    raises(TypeError, lambda: P(Y > z, numsamples=5))

    # sin(Y) <= 1 holds for every sample, so the estimate is exactly 1.
    assert P(sin(Y) <= 1, numsamples=10) == 1.0
    assert P(sin(Y) <= 1, cos(Y) < 1, numsamples=10) == 1.0

    # Sampled densities only contain values from the (conditioned) support.
    assert all(i in range(1, 7) for i in density(X, numsamples=10))
    assert all(i in range(4, 7) for i in density(X, X > 3, numsamples=10))

    numpy = import_module('numpy')
    if not numpy:
        skip('Numpy is not installed. Abort tests')

    # Issue #21563: sample returns NumPy scalars or arrays.
    assert isinstance(sample(X), (numpy.int32, numpy.int64))
    assert isinstance(sample(Y), numpy.float64)
    assert isinstance(sample(X, size=2), numpy.ndarray)

    # The ``numsamples`` keyword of ``sample`` is deprecated.
    with warns_deprecated_sympy():
        sample(X, numsamples=2)
def test_samplingE():
    """A sampled conditional expectation of an infinite Sum is a number."""
    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    Y = Normal('Y', 0, 1)
    z = Symbol('z', integer=True)
    series = Sum(1/z**Y, (z, 1, oo))
    assert E(series, Y > 2, numsamples=3).is_number
def test_given():
    """Conditioning on True or on an independent event leaves an RV unchanged."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    on_true = given(X, True)
    on_independent = given(X, Y > 2)
    assert X == on_true == on_independent
def test_factorial_moment():
    """Numeric and symbolic factorial moments of three discrete distributions."""
    poisson = Poisson('X', 2)
    binom = Binomial('Y', 2, S.Half)
    hyper = Hypergeometric('Z', 4, 2, 2)
    assert factorial_moment(poisson, 2) == 4
    assert factorial_moment(binom, 2) == S.Half
    assert factorial_moment(hyper, 2) == Rational(1, 3)

    # With a symbolic order l, the result is a weighted sum of FallingFactorials.
    _, y, _, l = symbols('x y z l')
    binom = Binomial('Y', 2, y)
    hyper = Hypergeometric('Z', 10, 2, 3)
    assert factorial_moment(binom, l) == (
        y**2*FallingFactorial(2, l)
        + 2*y*(1 - y)*FallingFactorial(1, l)
        + (1 - y)**2*FallingFactorial(0, l))
    assert factorial_moment(hyper, l) == (
        7*FallingFactorial(0, l)/15
        + 7*FallingFactorial(1, l)/15
        + FallingFactorial(2, l)/15)
def test_dependence():
    """independent/dependent for dice, normals and a conditioned pair."""
    die_x, die_y = Die('X'), Die('Y')
    assert independent(die_x, 2*die_y)
    assert not dependent(die_x, 2*die_y)

    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    assert independent(X, Y)
    assert dependent(X, 2*X)

    # Conditioning on X + Y == 3 couples the two variables.
    cond_x, cond_y = given(Tuple(X, Y), Eq(X + Y, 3))
    assert dependent(cond_x, cond_y)
def test_dependent_finite():
    """Dependence checks for finite (die) random variables."""
    X, Y = Die('X'), Die('Y')
    # Dependence testing requires symbolic conditions which currently break
    # finite random variables
    assert dependent(X, Y + X)

    # Conditioning on X + Y > 5 creates a dependency between the dice.
    cond_x, cond_y = given(Tuple(X, Y), X + Y > 5)
    assert dependent(cond_x, cond_y)
def test_normality():
    """The conditional density of X - Y given X + Y == z integrates to one."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    x = Symbol('x', real=True)
    z = Symbol('z', real=True)
    conditional = density(X - Y, Eq(X + Y, z))
    total_mass = integrate(conditional(x), (x, -oo, oo))
    assert total_mass == 1
def test_Density():
    """Evaluating an unevaluated Density equals calling ``density`` directly."""
    die = Die('X', 6)
    assert Density(die).doit() == density(die)
def test_NamedArgsMixin():
    """NamedArgsMixin exposes positional args under the names in ``_argnames``."""
    class Foo(Basic, NamedArgsMixin):
        _argnames = 'foo', 'bar'

    obj = Foo(S(1), S(2))
    assert obj.foo == 1
    assert obj.bar == 2
    # A name not listed in _argnames is not resolvable.
    raises(AttributeError, lambda: obj.baz)

    class Bar(Basic, NamedArgsMixin):
        pass

    # Bar declares no _argnames, so ``foo`` is not resolvable either.
    raises(AttributeError, lambda: Bar(S(1), S(2)).foo)
def test_density_constant():
    """The density of a constant is zero away from it and a DiracDelta at it."""
    assert density(3)(2) == 0
    point_mass = density(3)(3)
    assert point_mass == DiracDelta(0)
def test_cmoment_constant():
    """Central moments of numeric and symbolic constants."""
    assert variance(3) == 0
    for order in (3, 4):
        assert cmoment(3, order) == 0

    x = Symbol('x')
    assert variance(x) == 0
    assert cmoment(x, 15) == 0
    # The zeroth central moment is 1.
    assert cmoment(x, 0) == 1
def test_moment_constant():
    """The raw moment of order n of a constant c is c**n."""
    for order, expected in ((0, 1), (1, 3), (2, 9)):
        assert moment(3, order) == expected

    x = Symbol('x')
    assert moment(x, 2) == x**2
def test_median_constant():
    """The median of a constant is the constant itself."""
    for value in (3, Symbol('x')):
        assert median(value) == value
def test_real():
    """A Normal random symbol is known to be real."""
    assert Normal('x', 0, 1).is_real
def test_issue_10052():
    """P with infinite bounds, impossible conditions, and invalid arguments."""
    X = Exponential('X', 3)
    # (arguments to P, expected probability)
    cases = [
        ((X < oo,), 1),
        ((X > oo,), 0),
        ((X < 2, X > oo), 0),
        ((X < oo, X > oo), 0),
        ((X < oo, X > 2), 1),
        ((X < 3, X == 2), 0),
    ]
    for args, expected in cases:
        assert P(*args) == expected

    # A non-relational event or condition raises ValueError.
    raises(ValueError, lambda: P(1))
    raises(ValueError, lambda: P(X < 1, 2))
def test_issue_11934():
    """FiniteRV from a dict: expectation, and probability outside its support."""
    # Named ``pmf`` so it does not shadow the module-level ``density`` function.
    pmf = {0: .5, 1: .5}
    X = FiniteRV('X', pmf)
    assert E(X) == 0.5
    assert P(X >= 2) == 0
def test_issue_8129():
    """Probabilities of an RV compared with itself, shifted or not."""
    X = Exponential('X', 4)
    for event, expected in ((X >= X, 1), (X > X, 0), (X > X + 1, 0)):
        assert P(event) == expected
def test_issue_12237():
    """P conditioned on a random variable returns a distribution or a value."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    given_self = P(X > 0, X)
    given_other = P(Y < 0, X)
    given_sum = P(X + Y > 0, X)

    # NOTE(review): this only checks that the call is reproducible. It does not
    # pin the actual value.
    assert given_sum == P(X + Y > 0, X)
    # Conditioning X > 0 on X itself yields a Bernoulli(1/2) distribution.
    assert given_self == BernoulliDistribution(S.Half, S.Zero, S.One)
    # Conditioning on X does not affect the independent Y.
    assert given_other == S.Half
def test_is_random():
    """is_random detects RVs in expressions, relations, matrices and processes."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    a, b = symbols('a, b')
    G = GaussianUnitaryEnsemble('U', 2)
    B = BernoulliProcess('B', 0.9)

    deterministic = [a, a + b, a * b, Matrix([a**2, b**2])]
    for expr in deterministic:
        assert not is_random(expr)

    stochastic = [
        X,
        X**2 + Y,
        Y + b**2,
        Y > 5,
        B[3] < 1,
        G,
        X * Y * B[1],
        Matrix([[X, B[2]], [G, Y]]),
        Eq(X, 4),
    ]
    for expr in stochastic:
        assert is_random(expr)
def test_issue_12283():
    """Space-less random symbols leave stats functions unevaluated."""
    x = symbols('x')
    X = RandomSymbol(x)
    Y = RandomSymbol('Y')
    Z = RandomMatrixSymbol('Z', 2, 1)
    W = RandomMatrixSymbol('W', 2, 1)
    RI = RandomIndexedSymbol(Indexed('RI', 3))

    for rv in (Z, RI, X):
        assert pspace(rv) == PSpace()

    # Without a concrete space, each operator returns its unevaluated form.
    assert E(X) == Expectation(X)
    assert P(Y > 3) == Probability(Y > 3)
    for rv in (X, RI):
        assert variance(rv) == Variance(rv)
    assert covariance(X, Y) == Covariance(X, Y)
    assert covariance(W, Z) == Covariance(W, Z)
def test_issue_6810():
    """P of Eq, Or and And events for a die and a normal RV."""
    X = Die('X', 6)
    Y = Normal('Y', 0, 1)
    assert P(Eq(X, 2)) == Rational(1, 6)
    # A single point of a continuous distribution has probability zero.
    assert P(Eq(Y, 0)) == 0
    # Every face satisfies X > 2 or X < 3.
    assert P(Or(X > 2, X < 3)) == 1
    # X > 3 implies X > 2, so only faces 4, 5 and 6 count.
    assert P(And(X > 3, X > 2)) == Rational(1, 2)
def test_issue_20286():
    """Entropy of a symbolic Binomial is an unevaluated Sum over its support."""
    n, p = symbols('n p')
    B = Binomial('B', n, p)
    k = Dummy('k', integer=True)

    # Binomial probability mass at k. The summand is -pmf*log(pmf) inside
    # the support 0 <= k <= n, and nan elsewhere.
    pmf = p**k*(1 - p)**(-k + n)*binomial(n, k)
    summand = Piecewise(
        (-pmf*log(pmf), (k >= 0) & (k <= n)),
        (nan, True))
    expected = Sum(summand, (k, 0, n))
    assert expected.dummy_eq(H(B))