| from collections import defaultdict | |
| from sympy.core import sympify, S, Mul, Derivative, Pow | |
| from sympy.core.add import _unevaluated_Add, Add | |
| from sympy.core.assumptions import assumptions | |
| from sympy.core.exprtools import Factors, gcd_terms | |
| from sympy.core.function import _mexpand, expand_mul, expand_power_base | |
| from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort | |
| from sympy.core.numbers import Rational, zoo, nan | |
| from sympy.core.parameters import global_parameters | |
| from sympy.core.sorting import ordered, default_sort_key | |
| from sympy.core.symbol import Dummy, Wild, symbols | |
| from sympy.functions import exp, sqrt, log | |
| from sympy.functions.elementary.complexes import Abs | |
| from sympy.polys import gcd | |
| from sympy.simplify.sqrtdenest import sqrtdenest | |
| from sympy.utilities.iterables import iterable, sift | |
| def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True): | |
| """ | |
| Collect additive terms of an expression. | |
| Explanation | |
| =========== | |
| This function collects additive terms of an expression with respect | |
| to a list of expressions, up to powers with rational exponents. The | |
| term "symbol" here means an arbitrary expression, which can contain | |
| powers, products, sums, etc. In other words, a symbol is a pattern that | |
| will be searched for in the expression's terms. | |
| The input expression is not expanded by :func:`collect`, so the user is | |
| expected to provide an expression in an appropriate form. This makes | |
| :func:`collect` more predictable as there is no magic happening behind the | |
| scenes. However, it is important to note that powers of products are | |
| converted to products of powers using the :func:`~.expand_power_base` | |
| function. | |
| There are two possible types of output. If the ``evaluate`` flag is | |
| set, this function returns an expression with collected terms; | |
| otherwise it returns a dictionary with expressions up to rational powers | |
| as keys and collected coefficients as values. | |
| Examples | |
| ======== | |
| >>> from sympy import S, collect, expand, factor, Wild | |
| >>> from sympy.abc import a, b, c, x, y | |
| This function can collect symbolic coefficients in polynomials or | |
| rational expressions. It will find all integer or rational | |
| powers of the collection variable:: | |
| >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x) | |
| c + x**2*(a + b) + x*(a - b) | |
| The same result can be achieved in dictionary form:: | |
| >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False) | |
| >>> d[x**2] | |
| a + b | |
| >>> d[x] | |
| a - b | |
| >>> d[S.One] | |
| c | |
| You can also work with multivariate polynomials. However, remember that | |
| this function is greedy, so it will care only about a single symbol at a | |
| time, in specification order:: | |
| >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y]) | |
| x**2*(y + 1) + x*y + y*(a + 1) | |
| Also more complicated expressions can be used as patterns:: | |
| >>> from sympy import sin, log | |
| >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x)) | |
| (a + b)*sin(2*x) | |
| >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x)) | |
| x*(a + b)*log(x) | |
| You can use wildcards in the pattern:: | |
| >>> w = Wild('w1') | |
| >>> collect(a*x**y - b*x**y, w**y) | |
| x**y*(a - b) | |
| It is also possible to work with symbolic powers, although the behavior | |
| is more complicated, because in this case the power's base and the | |
| symbolic part of the exponent are treated as a single symbol:: | |
| >>> collect(a*x**c + b*x**c, x) | |
| a*x**c + b*x**c | |
| >>> collect(a*x**c + b*x**c, x**c) | |
| x**c*(a + b) | |
| However, if you incorporate rationals into the exponents, then you will | |
| get the well-known behavior:: | |
| >>> collect(a*x**(2*c) + b*x**(2*c), x**c) | |
| x**(2*c)*(a + b) | |
| Note also that all previously stated facts about the :func:`collect` | |
| function apply to the exponential function, so you can get:: | |
| >>> from sympy import exp | |
| >>> collect(a*exp(2*x) + b*exp(2*x), exp(x)) | |
| (a + b)*exp(2*x) | |
| If you are interested only in collecting specific powers of some symbols | |
| then set the ``exact`` flag to True:: | |
| >>> collect(a*x**7 + b*x**7, x, exact=True) | |
| a*x**7 + b*x**7 | |
| >>> collect(a*x**7 + b*x**7, x**7, exact=True) | |
| x**7*(a + b) | |
| If you want to collect on any object containing symbols, set | |
| ``exact`` to None: | |
| >>> collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None) | |
| x*exp(x) + 3*x + (y + 2)*sin(x) | |
| >>> collect(a*x*y + x*y + b*x + x, [x, y], exact=None) | |
| x*y*(a + 1) + x*(b + 1) | |
| You can also apply this function to differential equations, where | |
| derivatives of arbitrary order can be collected. Note that if you | |
| collect with respect to a function or a derivative of a function, all | |
| derivatives of that function will also be collected. Use | |
| ``exact=True`` to prevent this from happening:: | |
| >>> from sympy import Derivative as D, collect, Function | |
| >>> f = Function('f') (x) | |
| >>> collect(a*D(f,x) + b*D(f,x), D(f,x)) | |
| (a + b)*Derivative(f(x), x) | |
| >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f) | |
| (a + b)*Derivative(f(x), (x, 2)) | |
| >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True) | |
| a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2)) | |
| >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f) | |
| (a + b)*f(x) + (a + b)*Derivative(f(x), x) | |
| Or you can even match both derivative order and exponent at the same time:: | |
| >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x)) | |
| (a + b)*Derivative(f(x), (x, 2))**2 | |
| Finally, you can apply a function to each of the collected coefficients. | |
| For example, you can factorize symbolic coefficients of a polynomial:: | |
| >>> f = expand((x + a + 1)**3) | |
| >>> collect(f, x, factor) | |
| x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3 | |
| .. note:: Arguments are expected to be in expanded form, so you might have | |
| to call :func:`~.expand` prior to calling this function. | |
| See Also | |
| ======== | |
| collect_const, collect_sqrt, rcollect | |
| """ | |
| expr = sympify(expr) | |
| syms = [sympify(i) for i in (syms if iterable(syms) else [syms])] | |
| # replace syms[i] if it is not x, -x or has Wild symbols | |
| cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool( | |
| x.atoms(Wild)) | |
| _, nonsyms = sift(syms, cond, binary=True) | |
| if nonsyms: | |
| reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms])) | |
| syms = [reps.get(s, s) for s in syms] | |
| rv = collect(expr.subs(reps), syms, | |
| func=func, evaluate=evaluate, exact=exact, | |
| distribute_order_term=distribute_order_term) | |
| urep = {v: k for k, v in reps.items()} | |
| if not isinstance(rv, dict): | |
| return rv.xreplace(urep) | |
| else: | |
| return {urep.get(k, k).xreplace(urep): v.xreplace(urep) | |
| for k, v in rv.items()} | |
| # see if other expressions should be considered | |
| if exact is None: | |
| _syms = set() | |
| for i in Add.make_args(expr): | |
| if not i.has_free(*syms) or i in syms: | |
| continue | |
| if not i.is_Mul and i not in syms: | |
| _syms.add(i) | |
| else: | |
| # identify compound generators | |
| g = i._new_rawargs(*i.as_coeff_mul(*syms)[1]) | |
| if g not in syms: | |
| _syms.add(g) | |
| simple = all(i.is_Pow and i.base in syms for i in _syms) | |
| syms = syms + list(ordered(_syms)) | |
| if not simple: | |
| return collect(expr, syms, | |
| func=func, evaluate=evaluate, exact=False, | |
| distribute_order_term=distribute_order_term) | |
| if evaluate is None: | |
| evaluate = global_parameters.evaluate | |
| def make_expression(terms): | |
| product = [] | |
| for term, rat, sym, deriv in terms: | |
| if deriv is not None: | |
| var, order = deriv | |
| while order > 0: | |
| term, order = Derivative(term, var), order - 1 | |
| if sym is None: | |
| if rat is S.One: | |
| product.append(term) | |
| else: | |
| product.append(Pow(term, rat)) | |
| else: | |
| product.append(Pow(term, rat*sym)) | |
| return Mul(*product) | |
| def parse_derivative(deriv): | |
| # scan derivatives tower in the input expression and return | |
| # underlying function and maximal differentiation order | |
| expr, sym, order = deriv.expr, deriv.variables[0], 1 | |
| for s in deriv.variables[1:]: | |
| if s == sym: | |
| order += 1 | |
| else: | |
| raise NotImplementedError( | |
| 'Improve MV Derivative support in collect') | |
| while isinstance(expr, Derivative): | |
| s0 = expr.variables[0] | |
| for s in expr.variables: | |
| if s != s0: | |
| raise NotImplementedError( | |
| 'Improve MV Derivative support in collect') | |
| if s0 == sym: | |
| expr, order = expr.expr, order + len(expr.variables) | |
| else: | |
| break | |
| return expr, (sym, Rational(order)) | |
| def parse_term(expr): | |
| """Parses expression expr and outputs tuple (sexpr, rat_expo, | |
| sym_expo, deriv) | |
| where: | |
| - sexpr is the base expression | |
| - rat_expo is the rational exponent that sexpr is raised to | |
| - sym_expo is the symbolic exponent that sexpr is raised to | |
| - deriv contains the derivatives of the expression | |
| For example, the output of x would be (x, 1, None, None) and | |
| the output of 2**x would be (2, 1, x, None). | |
| """ | |
| rat_expo, sym_expo = S.One, None | |
| sexpr, deriv = expr, None | |
| if expr.is_Pow: | |
| if isinstance(expr.base, Derivative): | |
| sexpr, deriv = parse_derivative(expr.base) | |
| else: | |
| sexpr = expr.base | |
| if expr.base == S.Exp1: | |
| arg = expr.exp | |
| if arg.is_Rational: | |
| sexpr, rat_expo = S.Exp1, arg | |
| elif arg.is_Mul: | |
| coeff, tail = arg.as_coeff_Mul(rational=True) | |
| sexpr, rat_expo = exp(tail), coeff | |
| elif expr.exp.is_Number: | |
| rat_expo = expr.exp | |
| else: | |
| coeff, tail = expr.exp.as_coeff_Mul() | |
| if coeff.is_Number: | |
| rat_expo, sym_expo = coeff, tail | |
| else: | |
| sym_expo = expr.exp | |
| elif isinstance(expr, exp): | |
| arg = expr.exp | |
| if arg.is_Rational: | |
| sexpr, rat_expo = S.Exp1, arg | |
| elif arg.is_Mul: | |
| coeff, tail = arg.as_coeff_Mul(rational=True) | |
| sexpr, rat_expo = exp(tail), coeff | |
| elif isinstance(expr, Derivative): | |
| sexpr, deriv = parse_derivative(expr) | |
| return sexpr, rat_expo, sym_expo, deriv | |
| def parse_expression(terms, pattern): | |
| """Parse terms searching for a pattern. | |
| Terms is a list of tuples as returned by parse_term; | |
| Pattern is an expression treated as a product of factors. | |
| """ | |
| pattern = Mul.make_args(pattern) | |
| if len(terms) < len(pattern): | |
| # pattern is longer than matched product | |
| # so no chance for positive parsing result | |
| return None | |
| else: | |
| pattern = [parse_term(elem) for elem in pattern] | |
| terms = terms[:] # need a copy | |
| elems, common_expo, has_deriv = [], None, False | |
| for elem, e_rat, e_sym, e_ord in pattern: | |
| if elem.is_Number and e_rat == 1 and e_sym is None: | |
| # a constant is a match for everything | |
| continue | |
| for j in range(len(terms)): | |
| if terms[j] is None: | |
| continue | |
| term, t_rat, t_sym, t_ord = terms[j] | |
| # keeping track of whether one of the terms had | |
| # a derivative or not as this will require rebuilding | |
| # the expression later | |
| if t_ord is not None: | |
| has_deriv = True | |
| if (term.match(elem) is not None and | |
| (t_sym == e_sym or t_sym is not None and | |
| e_sym is not None and | |
| t_sym.match(e_sym) is not None)): | |
| if exact is False: | |
| # we don't have to be exact so find common exponent | |
| # for both expression's term and pattern's element | |
| expo = t_rat / e_rat | |
| if common_expo is None: | |
| # first time | |
| common_expo = expo | |
| else: | |
| # common exponent was negotiated before so | |
| # there is no chance for a pattern match unless | |
| # common and current exponents are equal | |
| if common_expo != expo: | |
| common_expo = 1 | |
| else: | |
| # we ought to be exact so all fields of | |
| # interest must match in every detail | |
| if e_rat != t_rat or e_ord != t_ord: | |
| continue | |
| # found common term so remove it from the expression | |
| # and try to match next element in the pattern | |
| elems.append(terms[j]) | |
| terms[j] = None | |
| break | |
| else: | |
| # pattern element not found | |
| return None | |
| return [_f for _f in terms if _f], elems, common_expo, has_deriv | |
| if evaluate: | |
| if expr.is_Add: | |
| o = expr.getO() or 0 | |
| expr = expr.func(*[ | |
| collect(a, syms, func, True, exact, distribute_order_term) | |
| for a in expr.args if a != o]) + o | |
| elif expr.is_Mul: | |
| return expr.func(*[ | |
| collect(term, syms, func, True, exact, distribute_order_term) | |
| for term in expr.args]) | |
| elif expr.is_Pow: | |
| b = collect( | |
| expr.base, syms, func, True, exact, distribute_order_term) | |
| return Pow(b, expr.exp) | |
| syms = [expand_power_base(i, deep=False) for i in syms] | |
| order_term = None | |
| if distribute_order_term: | |
| order_term = expr.getO() | |
| if order_term is not None: | |
| if order_term.has(*syms): | |
| order_term = None | |
| else: | |
| expr = expr.removeO() | |
| summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)] | |
| collected, disliked = defaultdict(list), S.Zero | |
| for product in summa: | |
| c, nc = product.args_cnc(split_1=False) | |
| args = list(ordered(c)) + nc | |
| terms = [parse_term(i) for i in args] | |
| small_first = True | |
| for symbol in syms: | |
| if isinstance(symbol, Derivative) and small_first: | |
| terms = list(reversed(terms)) | |
| small_first = not small_first | |
| result = parse_expression(terms, symbol) | |
| if result is not None: | |
| if not symbol.is_commutative: | |
| raise AttributeError("Can not collect noncommutative symbol") | |
| terms, elems, common_expo, has_deriv = result | |
| # when there was derivative in current pattern we | |
| # will need to rebuild its expression from scratch | |
| if not has_deriv: | |
| margs = [] | |
| for elem in elems: | |
| if elem[2] is None: | |
| e = elem[1] | |
| else: | |
| e = elem[1]*elem[2] | |
| margs.append(Pow(elem[0], e)) | |
| index = Mul(*margs) | |
| else: | |
| index = make_expression(elems) | |
| terms = expand_power_base(make_expression(terms), deep=False) | |
| index = expand_power_base(index, deep=False) | |
| collected[index].append(terms) | |
| break | |
| else: | |
| # none of the patterns matched | |
| disliked += product | |
| # add terms now for each key | |
| collected = {k: Add(*v) for k, v in collected.items()} | |
| if disliked is not S.Zero: | |
| collected[S.One] = disliked | |
| if order_term is not None: | |
| for key, val in collected.items(): | |
| collected[key] = val + order_term | |
| if func is not None: | |
| collected = { | |
| key: func(val) for key, val in collected.items()} | |
| if evaluate: | |
| return Add(*[key*val for key, val in collected.items()]) | |
| else: | |
| return collected | |
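The bookkeeping at the end of ``collect`` is a grouping step: each additive term is split into a matched power of the symbol and a leftover coefficient, and the leftovers are accumulated under the matched power. The standalone sketch below illustrates only that grouping idea with plain Python data. The `(coefficient, exponent)` term representation and the helper name `group_by_power` are hypothetical and are not part of SymPy.

```python
from collections import defaultdict


def group_by_power(terms):
    """Group (coefficient, exponent) pairs by exponent.

    Hypothetical helper: each pair stands for coefficient * x**exponent,
    with the coefficient kept as an opaque string.
    """
    collected = defaultdict(list)
    for coeff, exponent in terms:
        # Same pattern as ``collected[index].append(terms)`` in collect().
        collected[exponent].append(coeff)
    return dict(collected)


# a*x**2 + b*x**2 + a*x - b*x + c, written as (coefficient, exponent) pairs
terms = [("a", 2), ("b", 2), ("a", 1), ("-b", 1), ("c", 0)]
groups = group_by_power(terms)
print(groups)  # {2: ['a', 'b'], 1: ['a', '-b'], 0: ['c']}
```

Summing each list and multiplying by its key would give the evaluated form, which corresponds to the `c + x**2*(a + b) + x*(a - b)` result shown in the docstring.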
| def rcollect(expr, *vars): | |
| """ | |
| Recursively collect sums in an expression. | |
| Examples | |
| ======== | |
| >>> from sympy.simplify import rcollect | |
| >>> from sympy.abc import x, y | |
| >>> expr = (x**2*y + x*y + x + y)/(x + y) | |
| >>> rcollect(expr, y) | |
| (x + y*(x**2 + x + 1))/(x + y) | |
| See Also | |
| ======== | |
| collect, collect_const, collect_sqrt | |
| """ | |
| if expr.is_Atom or not expr.has(*vars): | |
| return expr | |
| else: | |
| expr = expr.__class__(*[rcollect(arg, *vars) for arg in expr.args]) | |
| if expr.is_Add: | |
| return collect(expr, vars) | |
| else: | |
| return expr | |
| def collect_sqrt(expr, evaluate=None): | |
| """Return expr with terms having common square roots collected together. | |
| If ``evaluate`` is False a count indicating the number of sqrt-containing | |
| terms will be returned and, if non-zero, the terms of the Add will be | |
| returned, else the expression itself will be returned as a single term. | |
| If ``evaluate`` is True, the expression with any collected terms will be | |
| returned. | |
| Note: since I = sqrt(-1), it is collected, too. | |
| Examples | |
| ======== | |
| >>> from sympy import sqrt | |
| >>> from sympy.simplify.radsimp import collect_sqrt | |
| >>> from sympy.abc import a, b | |
| >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]] | |
| >>> collect_sqrt(a*r2 + b*r2) | |
| sqrt(2)*(a + b) | |
| >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3) | |
| sqrt(2)*(a + b) + sqrt(3)*(a + b) | |
| >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5) | |
| sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b) | |
| If ``evaluate`` is False then the arguments will be sorted and | |
| returned as a tuple, together with a count of the number of | |
| sqrt-containing terms: | |
| >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False) | |
| ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3) | |
| >>> collect_sqrt(a*sqrt(2) + b, evaluate=False) | |
| ((b, sqrt(2)*a), 1) | |
| >>> collect_sqrt(a + b, evaluate=False) | |
| ((a + b,), 0) | |
| See Also | |
| ======== | |
| collect, collect_const, rcollect | |
| """ | |
| if evaluate is None: | |
| evaluate = global_parameters.evaluate | |
| # this step will help to standardize any complex arguments | |
| # of sqrts | |
| coeff, expr = expr.as_content_primitive() | |
| vars = set() | |
| for a in Add.make_args(expr): | |
| for m in a.args_cnc()[0]: | |
| if m.is_number and ( | |
| m.is_Pow and m.exp.is_Rational and m.exp.q == 2 or | |
| m is S.ImaginaryUnit): | |
| vars.add(m) | |
| # we only want radicals, so exclude Number handling; in this case | |
| # d will be evaluated | |
| d = collect_const(expr, *vars, Numbers=False) | |
| hit = expr != d | |
| if not evaluate: | |
| nrad = 0 | |
| # make the evaluated args canonical | |
| args = list(ordered(Add.make_args(d))) | |
| for i, m in enumerate(args): | |
| c, nc = m.args_cnc() | |
| for ci in c: | |
| # XXX should this be restricted to ci.is_number as above? | |
| if ci.is_Pow and ci.exp.is_Rational and ci.exp.q == 2 or \ | |
| ci is S.ImaginaryUnit: | |
| nrad += 1 | |
| break | |
| args[i] *= coeff | |
| if not (hit or nrad): | |
| args = [Add(*args)] | |
| return tuple(args), nrad | |
| return coeff*d | |
| def collect_abs(expr): | |
| """Return ``expr`` with arguments of multiple Abs in a term collected | |
| under a single instance. | |
| Examples | |
| ======== | |
| >>> from sympy.simplify.radsimp import collect_abs | |
| >>> from sympy.abc import x | |
| >>> collect_abs(abs(x + 1)/abs(x**2 - 1)) | |
| Abs((x + 1)/(x**2 - 1)) | |
| >>> collect_abs(abs(1/x)) | |
| Abs(1/x) | |
| """ | |
| def _abs(mul): | |
| c, nc = mul.args_cnc() | |
| a = [] | |
| o = [] | |
| for i in c: | |
| if isinstance(i, Abs): | |
| a.append(i.args[0]) | |
| elif isinstance(i, Pow) and isinstance(i.base, Abs) and i.exp.is_real: | |
| a.append(i.base.args[0]**i.exp) | |
| else: | |
| o.append(i) | |
| if len(a) < 2 and not any(i.exp.is_negative for i in a if isinstance(i, Pow)): | |
| return mul | |
| absarg = Mul(*a) | |
| A = Abs(absarg) | |
| args = [A] | |
| args.extend(o) | |
| if not A.has(Abs): | |
| args.extend(nc) | |
| return Mul(*args) | |
| if not isinstance(A, Abs): | |
| # reevaluate and make it unevaluated | |
| A = Abs(absarg, evaluate=False) | |
| args[0] = A | |
| _mulsort(args) | |
| args.extend(nc) # nc always go last | |
| return Mul._from_args(args, is_commutative=not nc) | |
| return expr.replace( | |
| lambda x: isinstance(x, Mul), | |
| lambda x: _abs(x)).replace( | |
| lambda x: isinstance(x, Pow), | |
| lambda x: _abs(x)) | |
| def collect_const(expr, *vars, Numbers=True): | |
| """A non-greedy collection of terms with similar number coefficients in | |
| an Add expr. If ``vars`` is given then only those constants will be | |
| targeted. Although any Number can also be targeted, if this is not | |
| desired set ``Numbers=False`` and no Float or Rational will be collected. | |
| Parameters | |
| ========== | |
| expr : SymPy expression | |
| This parameter defines the expression from which | |
| terms with similar coefficients are to be collected. A non-Add | |
| expression is returned as it is. | |
| vars : variable length collection of Numbers, optional | |
| Specifies the constants to target for collection. Can be multiple in | |
| number. | |
| Numbers : bool | |
| Specifies whether to target all instances of the | |
| :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then | |
| no Float or Rational will be collected. | |
| Returns | |
| ======= | |
| expr : Expr | |
| Returns an expression with similar coefficient terms collected. | |
| Examples | |
| ======== | |
| >>> from sympy import sqrt | |
| >>> from sympy.abc import s, x, y, z | |
| >>> from sympy.simplify.radsimp import collect_const | |
| >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2))) | |
| sqrt(3)*(sqrt(2) + 2) | |
| >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7)) | |
| (sqrt(3) + sqrt(7))*(s + 1) | |
| >>> s = sqrt(2) + 2 | |
| >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7)) | |
| (sqrt(2) + 3)*(sqrt(3) + sqrt(7)) | |
| >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3)) | |
| sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2) | |
| The collection is sign-sensitive, giving higher precedence to the | |
| unsigned values: | |
| >>> collect_const(x - y - z) | |
| x - (y + z) | |
| >>> collect_const(-y - z) | |
| -(y + z) | |
| >>> collect_const(2*x - 2*y - 2*z, 2) | |
| 2*(x - y - z) | |
| >>> collect_const(2*x - 2*y - 2*z, -2) | |
| 2*x - 2*(y + z) | |
| See Also | |
| ======== | |
| collect, collect_sqrt, rcollect | |
| """ | |
| if not expr.is_Add: | |
| return expr | |
| recurse = False | |
| if not vars: | |
| recurse = True | |
| vars = set() | |
| for a in expr.args: | |
| for m in Mul.make_args(a): | |
| if m.is_number: | |
| vars.add(m) | |
| else: | |
| vars = sympify(vars) | |
| if not Numbers: | |
| vars = [v for v in vars if not v.is_Number] | |
| vars = list(ordered(vars)) | |
| for v in vars: | |
| terms = defaultdict(list) | |
| Fv = Factors(v) | |
| for m in Add.make_args(expr): | |
| f = Factors(m) | |
| q, r = f.div(Fv) | |
| if r.is_one: | |
| # only accept this as a true factor if | |
| # it didn't change an exponent from an Integer | |
| # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2) | |
| # -- we aren't looking for this sort of change | |
| fwas = f.factors.copy() | |
| fnow = q.factors | |
| if not any(k in fwas and fwas[k].is_Integer and not | |
| fnow[k].is_Integer for k in fnow): | |
| terms[v].append(q.as_expr()) | |
| continue | |
| terms[S.One].append(m) | |
| args = [] | |
| hit = False | |
| uneval = False | |
| for k in ordered(terms): | |
| v = terms[k] | |
| if k is S.One: | |
| args.extend(v) | |
| continue | |
| if len(v) > 1: | |
| v = Add(*v) | |
| hit = True | |
| if recurse and v != expr: | |
| vars.append(v) | |
| else: | |
| v = v[0] | |
| # be careful not to let uneval become True unless | |
| # it must be because it's going to be more expensive | |
| # to rebuild the expression as an unevaluated one | |
| if Numbers and k.is_Number and v.is_Add: | |
| args.append(_keep_coeff(k, v, sign=True)) | |
| uneval = True | |
| else: | |
| args.append(k*v) | |
| if hit: | |
| if uneval: | |
| expr = _unevaluated_Add(*args) | |
| else: | |
| expr = Add(*args) | |
| if not expr.is_Add: | |
| break | |
| return expr | |
| def radsimp(expr, symbolic=True, max_terms=4): | |
| r""" | |
| Rationalize the denominator by removing square roots. | |
| Explanation | |
| =========== | |
| The expression returned from radsimp must be used with caution | |
| since if the denominator contains symbols, it will be possible to make | |
| substitutions that violate the assumptions of the simplification process: | |
| that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If | |
| there are no symbols, this assumption is made valid by collecting terms | |
| of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If | |
| you do not want the simplification to occur for symbolic denominators, set | |
| ``symbolic`` to False. | |
| If there are more than ``max_terms`` radical terms then the expression is | |
| returned unchanged. | |
| Examples | |
| ======== | |
| >>> from sympy import radsimp, sqrt, Symbol, pprint | |
| >>> from sympy import factor_terms, fraction, signsimp | |
| >>> from sympy.simplify.radsimp import collect_sqrt | |
| >>> from sympy.abc import a, b, c | |
| >>> radsimp(1/(2 + sqrt(2))) | |
| (2 - sqrt(2))/2 | |
| >>> x,y = map(Symbol, 'xy') | |
| >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2)) | |
| >>> radsimp(e) | |
| sqrt(2)*(x + y) | |
| No simplification beyond removal of the gcd is done. One might | |
| want to polish the result a little, however, by collecting | |
| square root terms: | |
| >>> r2 = sqrt(2) | |
| >>> r5 = sqrt(5) | |
| >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans) | |
| ___ ___ ___ ___ | |
| \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y | |
| ------------------------------------------ | |
| 2 2 2 2 | |
| 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y | |
| >>> n, d = fraction(ans) | |
| >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True)) | |
| ___ ___ | |
| \/ 5 *(a + b) - \/ 2 *(x + y) | |
| ------------------------------------------ | |
| 2 2 2 2 | |
| 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y | |
| If radicals in the denominator cannot be removed or there is no denominator, | |
| the original expression will be returned. | |
| >>> radsimp(sqrt(2)*x + sqrt(2)) | |
| sqrt(2)*x + sqrt(2) | |
| Results with symbols will not always be valid for all substitutions: | |
| >>> eq = 1/(a + b*sqrt(c)) | |
| >>> eq.subs(a, b*sqrt(c)) | |
| 1/(2*b*sqrt(c)) | |
| >>> radsimp(eq).subs(a, b*sqrt(c)) | |
| nan | |
| If ``symbolic=False``, symbolic denominators will not be transformed (but | |
| numeric denominators will still be processed): | |
| >>> radsimp(eq, symbolic=False) | |
| 1/(a + b*sqrt(c)) | |
| """ | |
| from sympy.simplify.simplify import signsimp | |
| syms = symbols("a:d A:D") | |
| def _num(rterms): | |
| # return the multiplier that will simplify the expression described | |
| # by rterms [(sqrt arg, coeff), ... ] | |
| a, b, c, d, A, B, C, D = syms | |
| if len(rterms) == 2: | |
| reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i]))) | |
| return ( | |
| sqrt(A)*a - sqrt(B)*b).xreplace(reps) | |
| if len(rterms) == 3: | |
| reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i]))) | |
| return ( | |
| (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 - | |
| B*b**2 + C*c**2)).xreplace(reps) | |
| elif len(rterms) == 4: | |
| reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i]))) | |
| return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b | |
| - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 + | |
| D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 - | |
| 2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 - | |
| 2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 + | |
| D**2*d**4)).xreplace(reps) | |
| elif len(rterms) == 1: | |
| return sqrt(rterms[0][0]) | |
| else: | |
| raise NotImplementedError | |
| def ispow2(d, log2=False): | |
| if not d.is_Pow: | |
| return False | |
| e = d.exp | |
| if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2: | |
| return True | |
| if log2: | |
| q = 1 | |
| if e.is_Rational: | |
| q = e.q | |
| elif symbolic: | |
| d = denom(e) | |
| if d.is_Integer: | |
| q = d | |
| if q != 1 and log(q, 2).is_Integer: | |
| return True | |
| return False | |
| def handle(expr): | |
| # handle() first reduces to the case | |
| # expr = 1/d, where d is an Add, or d is base**(p/2). | |
| # We do this by recursively calling handle on each piece. | |
| from sympy.simplify.simplify import nsimplify | |
| n, d = fraction(expr) | |
| if expr.is_Atom or (d.is_Atom and n.is_Atom): | |
| return expr | |
| elif not n.is_Atom: | |
| n = n.func(*[handle(a) for a in n.args]) | |
| return _unevaluated_Mul(n, handle(1/d)) | |
| elif n is not S.One: | |
| return _unevaluated_Mul(n, handle(1/d)) | |
| elif d.is_Mul: | |
| return _unevaluated_Mul(*[handle(1/d) for d in d.args]) | |
| # By this step, expr is 1/d, and d is not a mul. | |
| if not symbolic and d.free_symbols: | |
| return expr | |
| if ispow2(d): | |
| d2 = sqrtdenest(sqrt(d.base))**numer(d.exp) | |
| if d2 != d: | |
| return handle(1/d2) | |
| elif d.is_Pow and (d.exp.is_integer or d.base.is_positive): | |
| # (1/d**i) = (1/d)**i | |
| return handle(1/d.base)**d.exp | |
| if not (d.is_Add or ispow2(d)): | |
| return 1/d.func(*[handle(a) for a in d.args]) | |
| # handle 1/d treating d as an Add (though it may not be) | |
| keep = True # keep changes that are made | |
| # flatten it and collect radicals after checking for special | |
| # conditions | |
| d = _mexpand(d) | |
| # did it change? | |
| if d.is_Atom: | |
| return 1/d | |
| # is it a number that might be handled easily? | |
| if d.is_number: | |
| _d = nsimplify(d) | |
| if _d.is_Number and _d.equals(d): | |
| return 1/_d | |
| while True: | |
| # collect similar terms | |
| collected = defaultdict(list) | |
| for m in Add.make_args(d): # d might have become non-Add | |
| p2 = [] | |
| other = [] | |
| for i in Mul.make_args(m): | |
| if ispow2(i, log2=True): | |
| p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp)) | |
| elif i is S.ImaginaryUnit: | |
| p2.append(S.NegativeOne) | |
| else: | |
| other.append(i) | |
| collected[tuple(ordered(p2))].append(Mul(*other)) | |
| rterms = list(ordered(list(collected.items()))) | |
| rterms = [(Mul(*i), Add(*j)) for i, j in rterms] | |
| nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0) | |
| if nrad < 1: | |
| break | |
| elif nrad > max_terms: | |
| # there may have been invalid operations leading to this point | |
| # so don't keep changes, e.g. this expression is troublesome | |
| # in collecting terms so as not to raise the issue of 2834: | |
| # r = sqrt(sqrt(5) + 5) | |
| # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r) | |
| keep = False | |
| break | |
| if len(rterms) > 4: | |
| # in general, only 4 terms can be removed with repeated squaring | |
| # but other considerations can guide selection of radical terms | |
| # so that radicals are removed | |
| if all(x.is_Integer and (y**2).is_Rational for x, y in rterms): | |
| nd, d = rad_rationalize(S.One, Add._from_args( | |
| [sqrt(x)*y for x, y in rterms])) | |
| n *= nd | |
| else: | |
| # is there anything else that might be attempted? | |
| keep = False | |
| break | |
| from sympy.simplify.powsimp import powsimp, powdenest | |
| num = powsimp(_num(rterms)) | |
| n *= num | |
| d *= num | |
| d = powdenest(_mexpand(d), force=symbolic) | |
| if d.has(S.Zero, nan, zoo): | |
| return expr | |
| if d.is_Atom: | |
| break | |
| if not keep: | |
| return expr | |
| return _unevaluated_Mul(n, 1/d) | |
| coeff, expr = expr.as_coeff_Add() | |
| expr = expr.normal() | |
| old = fraction(expr) | |
| n, d = fraction(handle(expr)) | |
| if old != (n, d): | |
| if not d.is_Atom: | |
| was = (n, d) | |
| n = signsimp(n, evaluate=False) | |
| d = signsimp(d, evaluate=False) | |
| u = Factors(_unevaluated_Mul(n, 1/d)) | |
| u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()]) | |
| n, d = fraction(u) | |
| if old == (n, d): | |
| n, d = was | |
| n = expand_mul(n) | |
| if d.is_Number or d.is_Add: | |
| n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d))) | |
| if d2.is_Number or (d2.count_ops() <= d.count_ops()): | |
| n, d = [signsimp(i) for i in (n2, d2)] | |
| if n.is_Mul and n.args[0].is_Number: | |
| n = n.func(*n.args) | |
| return coeff + _unevaluated_Mul(n, 1/d) | |
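The multipliers returned by ``_num`` are chosen so that multiplying the denominator by them removes the square roots. As a rough numerical check of the three-term formula, the sketch below re-evaluates it with floats instead of SymPy objects. The function name `three_term_multiplier` and the sample values are assumptions made for illustration only.

```python
import math


def three_term_multiplier(A, a, B, b, C, c):
    """Float version of the three-term multiplier used in ``_num``.

    Intended for a denominator a*sqrt(A) + b*sqrt(B) + c*sqrt(C).
    """
    sA, sB, sC = math.sqrt(A), math.sqrt(B), math.sqrt(C)
    return (sA*a + sB*b - sC*c) * (2*sA*sB*a*b - A*a**2 - B*b**2 + C*c**2)


A, a, B, b, C, c = 2, 1, 3, 1, 5, 1
denominator = a*math.sqrt(A) + b*math.sqrt(B) + c*math.sqrt(C)
product = denominator * three_term_multiplier(A, a, B, b, C, c)

# Expanding by hand, the product is radical-free:
# 4*A*B*a**2*b**2 - (A*a**2 + B*b**2 - C*c**2)**2
expected = 4*A*B*a**2*b**2 - (A*a**2 + B*b**2 - C*c**2)**2
print(math.isclose(product, expected))  # True
```

For these sample values the radical-free product is 24, so `1/(sqrt(2) + sqrt(3) + sqrt(5))` equals the multiplier divided by 24.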
| def rad_rationalize(num, den): | |
| """ | |
| Rationalize ``num/den`` by removing square roots in the denominator; | |
| ``num`` and ``den`` are sums of terms whose squares are positive rationals. | |
| Examples | |
| ======== | |
| >>> from sympy import sqrt | |
| >>> from sympy.simplify.radsimp import rad_rationalize | |
| >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3) | |
| (-sqrt(3) + sqrt(6)/3, -7/9) | |
| """ | |
| if not den.is_Add: | |
| return num, den | |
| g, a, b = split_surds(den) | |
| a = a*sqrt(g) | |
| num = _mexpand((a - b)*num) | |
| den = _mexpand(a**2 - b**2) | |
| return rad_rationalize(num, den) | |
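
The recursion above multiplies numerator and denominator by the conjugate `a - b` until the denominator is no longer a sum. A minimal sketch of one such step, reusing the docstring's input (the variable names `new_num` and `new_den` are introduced here for illustration):

```python
from sympy import sqrt, Rational
from sympy.simplify.radsimp import rad_rationalize

num, den = sqrt(3), 1 + sqrt(2)/3
new_num, new_den = rad_rationalize(num, den)

# One conjugate step is enough here:
# (sqrt(2)/3)**2 - 1**2 = 2/9 - 1 = -7/9, which is rational.
print(new_num)
print(new_den)
```

The quotient `new_num/new_den` has the same value as `num/den`; only the form changes.
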
| def fraction(expr, exact=False): | |
| """Returns a pair with expression's numerator and denominator. | |
| If the given expression is not a fraction then this function | |
| will return the tuple (expr, 1). | |
| This function will not make any attempt to simplify nested | |
| fractions or to do any term rewriting at all. | |
| If only one of the numerator/denominator pair is needed, then | |
| use the ``numer(expr)`` or ``denom(expr)`` function, respectively. | |
| >>> from sympy import fraction, Rational, Symbol | |
| >>> from sympy.abc import x, y | |
| >>> fraction(x/y) | |
| (x, y) | |
| >>> fraction(x) | |
| (x, 1) | |
| >>> fraction(1/y**2) | |
| (1, y**2) | |
| >>> fraction(x*y/2) | |
| (x*y, 2) | |
| >>> fraction(Rational(1, 2)) | |
| (1, 2) | |
| This function will also work fine with assumptions: | |
| >>> k = Symbol('k', negative=True) | |
| >>> fraction(x * y**k) | |
| (x, y**(-k)) | |
| If nothing is known about the sign of some exponent and the ``exact`` | |
| flag is unset, then the exponent's structure will | |
| be analyzed and a pretty fraction will be returned: | |
| >>> from sympy import exp, Mul | |
| >>> fraction(2*x**(-y)) | |
| (2, x**y) | |
| >>> fraction(exp(-x)) | |
| (1, exp(x)) | |
| >>> fraction(exp(-x), exact=True) | |
| (exp(-x), 1) | |
| The ``exact`` flag will also keep any unevaluated Muls from | |
| being evaluated: | |
| >>> u = Mul(2, x + 1, evaluate=False) | |
| >>> fraction(u) | |
| (2*x + 2, 1) | |
| >>> fraction(u, exact=True) | |
| (2*(x + 1), 1) | |
| """ | |
| expr = sympify(expr) | |
| numer, denom = [], [] | |
| for term in Mul.make_args(expr): | |
| if term.is_commutative and (term.is_Pow or isinstance(term, exp)): | |
| b, ex = term.as_base_exp() | |
| if ex.is_negative: | |
| if ex is S.NegativeOne: | |
| denom.append(b) | |
| elif exact: | |
| if ex.is_constant(): | |
| denom.append(Pow(b, -ex)) | |
| else: | |
| numer.append(term) | |
| else: | |
| denom.append(Pow(b, -ex)) | |
| elif ex.is_positive: | |
| numer.append(term) | |
| elif not exact and ex.is_Mul: | |
| n, d = term.as_numer_denom() # this will cause evaluation | |
| if n != 1: | |
| numer.append(n) | |
| denom.append(d) | |
| else: | |
| numer.append(term) | |
| elif term.is_Rational and not term.is_Integer: | |
| if term.p != 1: | |
| numer.append(term.p) | |
| denom.append(term.q) | |
| else: | |
| numer.append(term) | |
| return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact) | |
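
The branches above sort each factor by what is known about the sign of its exponent. The following sketch, built from the docstring's own cases, shows the three situations side by side (the names `loose` and `strict` are illustrative only):

```python
from sympy import Symbol, exp, fraction
from sympy.abc import x, y

# Exponent is literally -1: the base goes to the denominator.
plain = fraction(x/y)

# Exponent is negative only by assumption on k.
k = Symbol('k', negative=True)
assumed = fraction(x * y**k)

# Sign of the exponent is unknown: exact=False inspects its structure,
# exact=True leaves the factor in the numerator.
loose = fraction(exp(-x))
strict = fraction(exp(-x), exact=True)

print(plain, assumed, loose, strict)
```
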
| def numer(expr, exact=False): # default matches fraction's default | |
| return fraction(expr, exact=exact)[0] | |
| def denom(expr, exact=False): # default matches fraction's default | |
| return fraction(expr, exact=exact)[1] | |
| def fraction_expand(expr, **hints): | |
| return expr.expand(frac=True, **hints) | |
| def numer_expand(expr, **hints): | |
| # default matches fraction's default | |
| a, b = fraction(expr, exact=hints.get('exact', False)) | |
| return a.expand(numer=True, **hints) / b | |
| def denom_expand(expr, **hints): | |
| # default matches fraction's default | |
| a, b = fraction(expr, exact=hints.get('exact', False)) | |
| return a / b.expand(denom=True, **hints) | |
| expand_numer = numer_expand | |
| expand_denom = denom_expand | |
| expand_fraction = fraction_expand | |
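
A short usage sketch for the helpers above, assuming they are imported from `sympy.simplify.radsimp`. The expression `e` is made up for illustration; the example covers `numer`, `denom`, `numer_expand` and `fraction_expand` only.

```python
from sympy import expand
from sympy.abc import x, y
from sympy.simplify.radsimp import numer, denom, numer_expand, fraction_expand

e = (x + 1)**2 / (y + 1)**2

# Pick out one side of the fraction.
top, bottom = numer(e), denom(e)

# Expand only the numerator, leaving the denominator factored.
only_top = numer_expand(e)

# Expand numerator and denominator separately.
both = fraction_expand(e)

print(only_top)
print(both)
```
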
| def split_surds(expr): | |
| """ | |
| Split an expression with terms whose squares are positive rationals | |
| into a sum of terms whose surds squared have gcd equal to g | |
| and a sum of terms whose surds squared are coprime to g. | |
| Examples | |
| ======== | |
| >>> from sympy import sqrt | |
| >>> from sympy.simplify.radsimp import split_surds | |
| >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15)) | |
| (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10)) | |
| """ | |
| args = sorted(expr.args, key=default_sort_key) | |
| coeff_muls = [x.as_coeff_Mul() for x in args] | |
| surds = [x[1]**2 for x in coeff_muls if x[1].is_Pow] | |
| surds.sort(key=default_sort_key) | |
| g, b1, b2 = _split_gcd(*surds) | |
| g2 = g | |
| if not b2 and len(b1) >= 2: | |
| b1n = [x/g for x in b1] | |
| b1n = [x for x in b1n if x != 1] | |
| # only a common factor has been factored; split again | |
| g1, b1n, b2 = _split_gcd(*b1n) | |
| g2 = g*g1 | |
| a1v, a2v = [], [] | |
| for c, s in coeff_muls: | |
| if s.is_Pow and s.exp == S.Half: | |
| s1 = s.base | |
| if s1 in b1: | |
| a1v.append(c*sqrt(s1/g2)) | |
| else: | |
| a2v.append(c*s) | |
| else: | |
| a2v.append(c*s) | |
| a = Add(*a1v) | |
| b = Add(*a2v) | |
| return g2, a, b | |
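
One way to read the return value above is that the input equals `sqrt(g)*a + b`. The sketch below checks that identity on the docstring's example; the identity check itself is added here as an illustration, not taken from the source.

```python
from sympy import sqrt, expand
from sympy.simplify.radsimp import split_surds

expr = 3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15)
g, a, b = split_surds(expr)

# Terms whose radicands share the factor g are collected in a
# (with sqrt(g) pulled out); everything else stays in b.
rebuilt = expand(sqrt(g)*a + b)

print(g, a, b)
print(rebuilt - expr)
```
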
| def _split_gcd(*a): | |
| """ | |
| Split the list of integers ``a`` into a list of integers ``a1`` having | |
| ``g = gcd(a1)`` and a list ``a2`` whose elements are not divisible by | |
| ``g``. Returns ``g, a1, a2``. | |
| Examples | |
| ======== | |
| >>> from sympy.simplify.radsimp import _split_gcd | |
| >>> _split_gcd(55, 35, 22, 14, 77, 10) | |
| (5, [55, 35, 10], [22, 14, 77]) | |
| """ | |
| g = a[0] | |
| b1 = [g] | |
| b2 = [] | |
| for x in a[1:]: | |
| g1 = gcd(g, x) | |
| if g1 == 1: | |
| b2.append(x) | |
| else: | |
| g = g1 | |
| b1.append(x) | |
| return g, b1, b2 | |
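
The loop above is greedy: it keeps a running gcd, starting from the first element, and sends any element coprime to that running value into the second list. A minimal trace on the docstring's input, with the intermediate steps written as comments:

```python
from sympy.simplify.radsimp import _split_gcd

# Running gcd starts at 55.
# 35 -> gcd(55, 35) = 5, kept;      g becomes 5
# 22 -> gcd(5, 22)  = 1, set aside
# 14 -> gcd(5, 14)  = 1, set aside
# 77 -> gcd(5, 77)  = 1, set aside
# 10 -> gcd(5, 10)  = 5, kept
g, kept, set_aside = _split_gcd(55, 35, 22, 14, 77, 10)

print(g, kept, set_aside)
```

Because the running gcd starts from the first argument, the result depends on the order of the inputs.
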