Spaces:
Sleeping
Sleeping
from sympy.core import S, Pow | |
from sympy.core.function import (Derivative, AppliedUndef, diff) | |
from sympy.core.relational import Equality, Eq | |
from sympy.core.symbol import Dummy | |
from sympy.core.sympify import sympify | |
from sympy.logic.boolalg import BooleanAtom | |
from sympy.functions import exp | |
from sympy.series import Order | |
from sympy.simplify.simplify import simplify, posify, besselsimp | |
from sympy.simplify.trigsimp import trigsimp | |
from sympy.simplify.sqrtdenest import sqrtdenest | |
from sympy.solvers import solve | |
from sympy.solvers.deutils import _preprocess, ode_order | |
from sympy.utilities.iterables import iterable, is_sequence | |
def sub_func_doit(eq, func, new):
    r"""
    Replace ``func`` by ``new`` in ``eq`` and evaluate the derivatives that
    are affected by the replacement.

    Every :py:class:`~sympy.core.function.Derivative` found in ``eq`` gets
    its own replacement: a derivative taken directly of ``func`` becomes
    ``new`` differentiated with respect to the same variables, while any
    other derivative has ``func`` swapped for ``new`` inside it and is then
    evaluated one level deep (``doit(deep=False)``).  All replacements are
    applied in a single ``xreplace`` pass.

    Examples
    ========

    >>> from sympy import Derivative, symbols, Function
    >>> from sympy.solvers.ode.subscheck import sub_func_doit
    >>> x, z = symbols('x, z')
    >>> y = Function('y')

    >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
    2

    >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
    ... 1/(x*(z + 1/x)))
    x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
    ...- 1/(x**2*(z + 1/x)**2)
    """
    def _evaluated(deriv):
        # Derivative of ``func`` itself: differentiate the replacement
        # with respect to the very same (variable, count) pairs.
        if deriv.expr == func:
            return new.diff(*deriv.variable_count)
        # Otherwise substitute inside the derivative and evaluate only the
        # outermost level.
        return deriv.xreplace({func: new}).doit(deep=False)

    # The plain ``func -> new`` rule goes in first so that a derivative key
    # equal to ``func`` (possible when ``func`` is itself a Derivative)
    # overrides it, exactly as a sequential fill of the mapping would.
    substitutions = {func: new}
    substitutions.update(
        (deriv, _evaluated(deriv)) for deriv in eq.atoms(Derivative))
    return eq.xreplace(substitutions)
def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
    r"""
    Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.

    This works when ``func`` is one function, like `f(x)` or a list of
    functions like `[f(x), g(x)]` when `ode` is a system of ODEs.  ``sol`` can
    be a single solution or a list of solutions.  Each solution may be an
    :py:class:`~sympy.core.relational.Equality` that the solution satisfies,
    e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
    :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
    will not be necessary to explicitly identify the function, but if the
    function cannot be inferred from the original equation it can be supplied
    through the ``func`` argument.

    If a sequence of solutions is passed, the same sort of container will be
    used to return the result for each solution.

    It tries the following methods, in order, until it finds zero equivalence:

    1. Substitute the solution for `f` in the original equation.  This only
       works if ``ode`` is solved for `f`.  It will attempt to solve it first
       unless ``solve_for_func == False``.
    2. Take `n` derivatives of the solution, where `n` is the order of
       ``ode``, and check to see if that is equal to the solution.  This only
       works on exact ODEs.
    3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
       solving for the derivative of `f` of that order (this will always be
       possible because differentiation is a linear operator).  Then back
       substitute each derivative into ``ode`` in reverse order.

    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results in ``0``, and ``False`` otherwise. The second
    item in the tuple is what the substitution results in.  It should always
    be ``0`` if the first item is ``True``. Sometimes this function will
    return ``False`` even when an expression is identically equal to ``0``.
    This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
    reduce the expression to ``0``.  If an expression returned by this
    function vanishes identically, then ``sol`` really is a solution to
    the ``ode``.

    If this function seems to hang, it is probably because of a hard
    simplification.

    To use this function to test, test the first item of the tuple.

    Parameters
    ==========

    ode : an expression (taken to equal zero), an ``Equality``, or an
        iterable of those; an iterable is handed to ``checksysodesol``.
    sol : a solution or a sequence of solutions (see above).
    func : the function being solved for; inferred when ``None``.  Once
        resolved it is forwarded unchanged to every recursive check, so an
        explicitly supplied function is never re-inferred.
    order : the order of ``ode``; computed with ``ode_order`` when
        ``'auto'``.
    solve_for_func : when true, a solution that is not already of the form
        ``Eq(func, expr)`` is first solved for ``func``.

    Raises
    ======

    ValueError
        If ``func`` cannot be inferred, or is not an undefined function of
        exactly one variable.
    NotImplementedError
        If none of the methods was able to produce a result.

    Examples
    ========

    >>> from sympy import (Eq, Function, checkodesol, symbols,
    ...     Derivative, exp)
    >>> x, C1, C2 = symbols('x,C1,C2')
    >>> f, g = symbols('f g', cls=Function)
    >>> checkodesol(f(x).diff(x), Eq(f(x), C1))
    (True, 0)
    >>> assert checkodesol(f(x).diff(x), C1)[0]
    >>> assert not checkodesol(f(x).diff(x), x)[0]
    >>> checkodesol(f(x).diff(x, 2), x**2)
    (False, 2)

    >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
    >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
    >>> checkodesol(eqs, sol)
    (True, [0, 0])

    """
    # Systems of ODEs are handled by the dedicated checker.
    if iterable(ode):
        return checksysodesol(ode, sol, func=func)

    if not isinstance(ode, Equality):
        ode = Eq(ode, 0)
    if func is None:
        try:
            _, func = _preprocess(ode.lhs)
        except ValueError:
            # Fall back to the solution(s): usable only when exactly one
            # undefined function appears in them.
            funcs = [s.atoms(AppliedUndef) for s in (
                sol if is_sequence(sol, set) else [sol])]
            funcs = set().union(*funcs)
            if len(funcs) != 1:
                raise ValueError(
                    'must pass func arg to checkodesol for this case.')
            func = funcs.pop()
    if not isinstance(func, AppliedUndef) or len(func.args) != 1:
        raise ValueError(
            "func must be a function of one variable, not %s" % func)
    if is_sequence(sol, set):
        # Check each solution on its own and return the results in the same
        # kind of container.  ``func`` is forwarded so that the function
        # resolved above (possibly given explicitly by the caller) is the
        # one used for every element.
        return type(sol)([
            checkodesol(ode, i, func=func, order=order,
                        solve_for_func=solve_for_func)
            for i in sol])

    if not isinstance(sol, Equality):
        sol = Eq(func, sol)
    elif sol.rhs == func:
        # Keep the function on the left-hand side.
        sol = sol.reversed

    if order == 'auto':
        order = ode_order(ode, func)
    solved = sol.lhs == func and not sol.rhs.has(func)
    if solve_for_func and not solved:
        rhs = solve(sol, func)
        if rhs:
            eqs = [Eq(func, t) for t in rhs]
            if len(rhs) == 1:
                eqs = eqs[0]
            # Re-check with the explicit solution(s); forward ``func`` so
            # the recursion does not have to infer it again.
            return checkodesol(ode, eqs, func=func, order=order,
                               solve_for_func=False)

    x = func.args[0]

    # Handle series solutions here
    if sol.has(Order):
        assert sol.lhs == func
        Oterm = sol.rhs.getO()
        solrhs = sol.rhs.removeO()

        Oexpr = Oterm.expr
        assert isinstance(Oexpr, Pow)
        sorder = Oexpr.exp
        assert Oterm == Order(x**sorder)

        odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()

        # Adding an Order term lowered by the ODE order discards the terms
        # of the substituted expression that it absorbs.
        neworder = Order(x**(sorder - order))
        odesubs = odesubs + neworder
        assert odesubs.getO() == neworder
        residual = odesubs.removeO()

        return (residual == 0, residual)

    # ``s`` stays ``True`` until one of the passes below assigns a result;
    # the loop ends as soon as a pass yields a falsy (zero) result or all
    # passes have been tried.
    s = True
    testnum = 0
    while s:
        if testnum == 0:
            # First pass, try substituting a solved solution directly into the
            # ODE. This has the highest chance of succeeding.
            ode_diff = ode.lhs - ode.rhs

            if sol.lhs == func:
                s = sub_func_doit(ode_diff, func, sol.rhs)
                s = besselsimp(s)
            else:
                testnum += 1
                continue
            ss = simplify(s.rewrite(exp))
            if ss:
                # with the new numer_denom in power.py, if we do a simple
                # expansion then testnum == 0 verifies all solutions.
                s = ss.expand(force=True)
            else:
                s = 0
            testnum += 1
        elif testnum == 1:
            # Second pass. If we cannot substitute f, try seeing if the nth
            # derivative is equal, this will only work for odes that are exact,
            # by definition.
            s = simplify(
                trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
                trigsimp(ode.lhs) + trigsimp(ode.rhs))
            # s2 = simplify(
            #     diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
            #     ode.lhs + ode.rhs)
            testnum += 1
        elif testnum == 2:
            # Third pass. Try solving for df/dx and substituting that into the
            # ODE. Thanks to Chris Smith for suggesting this method.  Many of
            # the comments below are his, too.
            # The method:
            # - Take each of 1..n derivatives of the solution.
            # - Solve each nth derivative for d^(n)f/dx^(n)
            #   (the differential of that order)
            # - Back substitute into the ODE in decreasing order
            #   (i.e., n, n-1, ...)
            # - Check the result for zero equivalence
            if sol.lhs == func and not sol.rhs.has(func):
                diffsols = {0: sol.rhs}
            elif sol.rhs == func and not sol.lhs.has(func):
                diffsols = {0: sol.lhs}
            else:
                diffsols = {}
            # From here on ``sol`` is an expression, not an Equality.
            sol = sol.lhs - sol.rhs
            for i in range(1, order + 1):
                # Differentiation is a linear operator, so there should always
                # be 1 solution. Nonetheless, we test just to make sure.
                # We only need to solve once.  After that, we automatically
                # have the solution to the differential in the order we want.
                if i == 1:
                    ds = sol.diff(x)
                    try:
                        sdf = solve(ds, func.diff(x, i))
                        if not sdf:
                            raise NotImplementedError
                    except NotImplementedError:
                        # Bumping testnum marks this pass as failed.
                        testnum += 1
                        break
                    else:
                        diffsols[i] = sdf[0]
                else:
                    # This is what the solution says df/dx should be.
                    diffsols[i] = diffsols[i - 1].diff(x)

            # Make sure the above didn't fail.
            if testnum > 2:
                continue
            else:
                # Substitute it into ODE to check for self consistency.
                lhs, rhs = ode.lhs, ode.rhs
                for i in range(order, -1, -1):
                    if i == 0 and 0 not in diffsols:
                        # We can only substitute f(x) if the solution was
                        # solved for f(x).
                        break
                    lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
                    rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
                    ode_or_bool = Eq(lhs, rhs)
                    ode_or_bool = simplify(ode_or_bool)

                    if isinstance(ode_or_bool, (bool, BooleanAtom)):
                        if ode_or_bool:
                            lhs = rhs = S.Zero
                    else:
                        lhs = ode_or_bool.lhs
                        rhs = ode_or_bool.rhs
                # No sense in overworking simplify -- just prove that the
                # numerator goes to zero
                num = trigsimp((lhs - rhs).as_numer_denom()[0])
                # since solutions are obtained using force=True we test
                # using the same level of assumptions
                ## replace function with dummy so assumptions will work
                _func = Dummy('func')
                num = num.subs(func, _func)
                ## posify the expression
                num, reps = posify(num)
                s = simplify(num).xreplace(reps).xreplace({_func: func})
                testnum += 1
        else:
            break

    if not s:
        return (True, s)
    elif s is True:  # The code above never was able to change s
        raise NotImplementedError("Unable to test if " + str(sol) +
            " is a solution to " + str(ode) + ".")
    else:
        return (False, s)
def checksysodesol(eqs, sols, func=None):
    r"""
    Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
    checks that the result of substitutions for each equation is ``0``. The
    equations and solutions passed can be any iterable.

    This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
    For each function, ``sols`` can have a single solution or a list of solutions.
    In most cases it will not be necessary to explicitly identify the function,
    but if the function cannot be inferred from the original equation it
    can be supplied through the ``func`` argument.

    When a sequence of equations is passed, the same sequence is used to return
    the result for each equation with each function substituted with corresponding
    solutions.

    It tries the following method to find zero equivalence for each equation:

    Substitute the solutions for functions, like `x(t)` and `y(t)` into the
    original equations containing those functions.
    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results for each equation is ``0``, and ``False`` otherwise.
    The second item in the tuple is what the substitution results in.  Each element
    of the ``list`` should always be ``0`` corresponding to each equation if the
    first item is ``True``. Note that sometimes this function may return ``False``,
    but with an expression that is identically equal to ``0``, instead of returning
    ``True``.  This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
    reduce the expression to ``0``.  If an expression returned by each function
    vanishes identically, then ``sols`` really is a solution to ``eqs``.

    If this function seems to hang, it is probably because of a difficult simplification.

    Parameters
    ==========

    eqs : an equation or an iterable of equations; each may be an
        ``Equality`` or an expression taken to equal zero.
    sols : an ``Equality`` or an iterable of ``Equality`` objects, each
        containing exactly one undefined function.
    func : a single function or an iterable of functions to substitute.
        When ``None`` the functions are collected from the derivatives
        appearing in ``eqs``.

    Raises
    ======

    ValueError
        If a function is not an undefined function of one variable, if the
        functions do not all depend on the same variable, if a solution
        does not contain exactly one function, or if the number of
        solutions does not match the number of functions.
    NotImplementedError
        If a solution that is not solved for its function cannot be solved
        for it uniquely.

    Examples
    ========

    >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
    >>> from sympy.solvers.ode.subscheck import checksysodesol
    >>> C1, C2 = symbols('C1:3')
    >>> t = symbols('t')
    >>> x, y = symbols('x, y', cls=Function)
    >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
    >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
    ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])
    >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
    >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
    ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])

    """
    def _sympify(eq):
        return list(map(sympify, eq if iterable(eq) else [eq]))

    # Reduce every equation to an expression that should vanish.
    eqs = [eq.lhs - eq.rhs if isinstance(eq, Equality) else eq
           for eq in _sympify(eqs)]
    # ``sols`` is traversed several times below; materialise it so that
    # one-shot iterables (and a single bare solution) are supported.
    sols = list(sols) if iterable(sols) else [sols]

    if func is None:
        # Collect the undefined functions that occur inside derivatives.
        found = set()
        for eq in eqs:
            for deriv in eq.atoms(Derivative):
                found |= deriv.atoms(AppliedUndef)
        funcs = list(found)
    else:
        # An explicit ``func`` may be one function or several of them.
        funcs = list(func) if iterable(func) else [func]

    for fn in funcs:
        if not (isinstance(fn, AppliedUndef) and len(fn.args) == 1):
            raise ValueError(
                "func must be a function of one variable, not %s" % (fn,))
    if len({fn.args for fn in funcs}) > 1:
        raise ValueError(
            "all functions must depend on the same independent variable")

    for sol in sols:
        if len(sol.atoms(AppliedUndef)) != 1:
            raise ValueError("solutions should have one function only")
    if len(funcs) != len({sol.lhs for sol in sols}):
        raise ValueError("number of solutions provided does not match the number of equations")

    # Map each function to the expression that replaces it.  When several
    # solutions are given for the same function the last one wins.
    dictsol = {}
    for sol in sols:
        target = list(sol.atoms(AppliedUndef))[0]
        if sol.rhs == target:
            sol = sol.reversed
        solved = sol.lhs == target and not sol.rhs.has(target)
        if solved:
            value = sol.rhs
        else:
            candidates = solve(sol, target)
            if not candidates:
                raise NotImplementedError
            if len(candidates) > 1:
                raise NotImplementedError(
                    "solution %s does not determine %s uniquely"
                    % (sol, target))
            # ``solve`` returns a list; the replacement must be the
            # expression itself.
            value = candidates[0]
        dictsol[target] = value

    checkeq = []
    for eq in eqs:
        for fn in funcs:
            eq = sub_func_doit(eq, fn, dictsol[fn])
        ss = simplify(eq)
        if ss != 0:
            eq = ss.expand(force=True)
            if eq != 0:
                # Last attempt: denest square roots and simplify again.
                eq = sqrtdenest(eq).simplify()
        else:
            eq = 0
        checkeq.append(eq)
    # True only when there is at least one residual and all of them are 0.
    if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0:
        return (True, checkeq)
    else:
        return (False, checkeq)