Spaces:
Running
Running
"""Real and complex elements. """ | |
from sympy.external.gmpy import MPQ | |
from sympy.polys.domains.domainelement import DomainElement | |
from sympy.utilities import public | |
from mpmath.ctx_mp_python import PythonMPContext, _mpf, _mpc, _constant | |
from mpmath.libmp import (MPZ_ONE, fzero, fone, finf, fninf, fnan, | |
round_nearest, mpf_mul, repr_dps, int_types, | |
from_int, from_float, from_str, to_rational) | |
class RealElement(_mpf, DomainElement): | |
"""An element of a real domain. """ | |
__slots__ = ('__mpf__',) | |
def _set_mpf(self, val): | |
self.__mpf__ = val | |
_mpf_ = property(lambda self: self.__mpf__, _set_mpf) | |
def parent(self): | |
return self.context._parent | |
class ComplexElement(_mpc, DomainElement): | |
"""An element of a complex domain. """ | |
__slots__ = ('__mpc__',) | |
def _set_mpc(self, val): | |
self.__mpc__ = val | |
_mpc_ = property(lambda self: self.__mpc__, _set_mpc) | |
def parent(self): | |
return self.context._parent | |
new = object.__new__ | |
class MPContext(PythonMPContext): | |
def __init__(ctx, prec=53, dps=None, tol=None, real=False): | |
ctx._prec_rounding = [prec, round_nearest] | |
if dps is None: | |
ctx._set_prec(prec) | |
else: | |
ctx._set_dps(dps) | |
ctx.mpf = RealElement | |
ctx.mpc = ComplexElement | |
ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding] | |
ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding] | |
if real: | |
ctx.mpf.context = ctx | |
else: | |
ctx.mpc.context = ctx | |
ctx.constant = _constant | |
ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding] | |
ctx.constant.context = ctx | |
ctx.types = [ctx.mpf, ctx.mpc, ctx.constant] | |
ctx.trap_complex = True | |
ctx.pretty = True | |
if tol is None: | |
ctx.tol = ctx._make_tol() | |
elif tol is False: | |
ctx.tol = fzero | |
else: | |
ctx.tol = ctx._convert_tol(tol) | |
ctx.tolerance = ctx.make_mpf(ctx.tol) | |
if not ctx.tolerance: | |
ctx.max_denom = 1000000 | |
else: | |
ctx.max_denom = int(1/ctx.tolerance) | |
ctx.zero = ctx.make_mpf(fzero) | |
ctx.one = ctx.make_mpf(fone) | |
ctx.j = ctx.make_mpc((fzero, fone)) | |
ctx.inf = ctx.make_mpf(finf) | |
ctx.ninf = ctx.make_mpf(fninf) | |
ctx.nan = ctx.make_mpf(fnan) | |
def _make_tol(ctx): | |
hundred = (0, 25, 2, 5) | |
eps = (0, MPZ_ONE, 1-ctx.prec, 1) | |
return mpf_mul(hundred, eps) | |
def make_tol(ctx): | |
return ctx.make_mpf(ctx._make_tol()) | |
def _convert_tol(ctx, tol): | |
if isinstance(tol, int_types): | |
return from_int(tol) | |
if isinstance(tol, float): | |
return from_float(tol) | |
if hasattr(tol, "_mpf_"): | |
return tol._mpf_ | |
prec, rounding = ctx._prec_rounding | |
if isinstance(tol, str): | |
return from_str(tol, prec, rounding) | |
raise ValueError("expected a real number, got %s" % tol) | |
def _convert_fallback(ctx, x, strings): | |
raise TypeError("cannot create mpf from " + repr(x)) | |
def _repr_digits(ctx): | |
return repr_dps(ctx._prec) | |
def _str_digits(ctx): | |
return ctx._dps | |
def to_rational(ctx, s, limit=True): | |
p, q = to_rational(s._mpf_) | |
# Needed for GROUND_TYPES=flint if gmpy2 is installed because mpmath's | |
# to_rational() function returns a gmpy2.mpz instance and if MPQ is | |
# flint.fmpq then MPQ(p, q) will fail. | |
p = int(p) | |
if not limit or q <= ctx.max_denom: | |
return p, q | |
p0, q0, p1, q1 = 0, 1, 1, 0 | |
n, d = p, q | |
while True: | |
a = n//d | |
q2 = q0 + a*q1 | |
if q2 > ctx.max_denom: | |
break | |
p0, q0, p1, q1 = p1, q1, p0 + a*p1, q2 | |
n, d = d, n - a*d | |
k = (ctx.max_denom - q0)//q1 | |
number = MPQ(p, q) | |
bound1 = MPQ(p0 + k*p1, q0 + k*q1) | |
bound2 = MPQ(p1, q1) | |
if not bound2 or not bound1: | |
return p, q | |
elif abs(bound2 - number) <= abs(bound1 - number): | |
return bound2.numerator, bound2.denominator | |
else: | |
return bound1.numerator, bound1.denominator | |
def almosteq(ctx, s, t, rel_eps=None, abs_eps=None): | |
t = ctx.convert(t) | |
if abs_eps is None and rel_eps is None: | |
rel_eps = abs_eps = ctx.tolerance or ctx.make_tol() | |
if abs_eps is None: | |
abs_eps = ctx.convert(rel_eps) | |
elif rel_eps is None: | |
rel_eps = ctx.convert(abs_eps) | |
diff = abs(s-t) | |
if diff <= abs_eps: | |
return True | |
abss = abs(s) | |
abst = abs(t) | |
if abss < abst: | |
err = diff/abst | |
else: | |
err = diff/abss | |
return err <= rel_eps | |