SymPy division doesn't cancel what it can when using a symbolic denominator

Question:

I have some code using sympy.solvers.solve() that basically leads to the following:

>>> k, u, p, q = sympy.symbols('k u p q')
>>> solution = (k*u + p*u + q)/(k+p)
>>> solution.simplify()
(k*u + p*u + q)/(k + p)

Now, my problem is that this is not simplified enough, or at least not into the form I need. It should give the following:

q/(k + p) + u

This form follows more obviously from the original equation q = (k + p)*(m - u) when it is solved for m by hand, which is what my students will be doing.
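A quick way to confirm this, assuming the solve() call was for m (the question does not show it): solve the original relation with SymPy and check that its answer differs from the target form only in shape, not in value. The names equation, m_solution and target are illustrative.

```python
from sympy import symbols, solve, simplify

k, u, p, q, m = symbols('k u p q m')

# Original relation from the question: q = (k + p)*(m - u), written as expr = 0
equation = q - (k + p)*(m - u)
solutions = solve(equation, m)   # solve returns a list of solutions for m
m_solution = solutions[0]

target = q/(k + p) + u           # the form the students are expected to enter
print(m_solution)
print(simplify(m_solution - target) == 0)  # True: same value, different form
```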

I have tried many combinations of solution.simplify(), solution.cancel() and solution.collect(u), but none of them gives this form. (I can't really use collect anyway, because I won't know beforehand which symbol has to be collected, unless there is a way to collect on all the symbols in the solution.)

I am working with BookWidgets, which automatically grades the answers students give, so it's important that my output matches what the students will enter.

Asked By: MikeA


Answers:

First things first:

  • there is no "standard" output to a simplification step.
  • if the output of a simplification step doesn’t suit your need, you might want to manipulate the expression with simplify, expand, collect, …
  • different sequences of operations (simplify, expand, collect, …) may or may not lead to the same result, depending on the expression being manipulated.

Let me show you with your example:

from sympy import symbols, fraction, Add

k, u, p, q = symbols('k u p q')
solution = (k*u + p*u + q)/(k+p)
# out1: (k*u + p*u + q)/(k + p)

solution = solution.collect(u)
# out2: (q + u*(k + p))/(k + p)

num, den = fraction(solution)
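# use the linearity of addition: divide each numerator term by den
# (Add.make_args also works when the numerator is a single term)
solution = Add(*[t / den for t in Add.make_args(num)])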
# out3: q/(k + p) + u

In the above code, out1, out2, out3 are mathematically equivalent.
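A sketch of both points, assuming the same symbols: expand and collect produce differently shaped but equal results, and a hypothetical split_fraction helper addresses the case where the symbol to collect is not known in advance. It tries collect on every free symbol, splits the numerator over the denominator, and keeps the candidate with the fewest operations as measured by count_ops. The "fewest operations" criterion is an assumption and may not match what students type for every expression.

```python
from sympy import symbols, expand, collect, simplify, fraction, Add, count_ops

k, u, p, q = symbols('k u p q')
solution = (k*u + p*u + q)/(k + p)

# Different manipulation routes give differently shaped but equal results
expanded = expand(solution)       # one fraction per numerator term
collected = collect(solution, u)  # numerator grouped on u
print(simplify(expanded - collected) == 0)  # True

def split_fraction(expr):
    """Hypothetical helper: for every free symbol, collect on it, split the
    numerator over the denominator, and keep the result with fewest ops."""
    best = expr
    for sym in sorted(expr.free_symbols, key=lambda s: s.name):
        num, den = fraction(collect(expr, sym))
        # Add.make_args also copes with a numerator that is a single term
        candidate = Add(*[t / den for t in Add.make_args(num)])
        if count_ops(candidate) < count_ops(best):
            best = candidate
    return best

print(split_fraction(solution))
```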

Instead of spending time trying to simplify outputs into one particular form, I would test for mathematical equivalence with the equals method. For example:

verified_solution = (k*u + p*u + q)/(k+p)

num, den = fraction(verified_solution)
first_student_solution = Add(*[t / den for t in Add.make_args(num)])
print(verified_solution.equals(first_student_solution))
# True

second_student_solution = q/(k + p) + u
print(verified_solution.equals(second_student_solution))
# True

third_student_solution = q/(k + p) + u + 2
print(verified_solution.equals(third_student_solution))
# False
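For automatic grading, here is a sketch of a hypothetical is_equivalent helper that parses a typed answer with sympify and compares it with equals. Because equals can return None when it cannot decide, the sketch falls back to simplifying the difference. How BookWidgets hands over the student's answer is not covered here, and sympify evaluates strings with eval, so untrusted input deserves care.

```python
from sympy import symbols, simplify, sympify

k, u, p, q = symbols('k u p q')
reference = (k*u + p*u + q)/(k + p)

def is_equivalent(reference, answer_text):
    """Hypothetical grading helper: parse a typed answer and compare it
    with the reference solution by value, not by form."""
    answer = sympify(answer_text)  # note: sympify uses eval on strings
    result = reference.equals(answer)
    if result is None:
        # equals() could not decide; fall back to simplifying the difference
        result = simplify(reference - answer) == 0
    return bool(result)

print(is_equivalent(reference, "q/(k + p) + u"))      # True
print(is_equivalent(reference, "u + q/(p + k)"))      # True
print(is_equivalent(reference, "q/(k + p) + u + 2"))  # False
```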

Answered By: Davide_sd

It looks like you want the expression in quotient/remainder form:

>>> n, d = solution.as_numer_denom()
>>> div(n, d)
(u, q)
>>> _[0] + _[1]/d
q/(k + p) + u
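The session above relies on _, which only exists at the interactive prompt. A script-friendly sketch with a hypothetical quotient_remainder_form helper follows; its optional gens are passed straight to div. The last two lines show, on a different expression, that the quotient and remainder depend on the generator order.

```python
from sympy import symbols, div

def quotient_remainder_form(expr, *gens):
    """Hypothetical helper: rewrite expr as Q + R/d, where Q and R are the
    quotient and remainder of dividing its numerator by its denominator."""
    n, d = expr.as_numer_denom()
    Q, R = div(n, d, *gens)  # gens, if given, fix the variable order
    return Q + R/d

k, u, p, q = symbols('k u p q')
solution = (k*u + p*u + q)/(k + p)
print(quotient_remainder_form(solution))  # q/(k + p) + u

# Same division, different generator order, different result
x, y = symbols('x y')
print(div(x + y, x - y, x, y))  # (1, 2*y)
print(div(x + y, x - y, y, x))  # (-1, 2*x)
```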

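But that SymPy function may give unexpected results when the symbol names are changed (as described here), because the division depends on the order of the generators, which by default is derived from the symbol names. Here is an alternative (for which I did not find an existing function in SymPy) that attempts a result closer to synthetic division: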

from sympy import S, fraction, signsimp


def sdiv(p, q):
    """Return (w, r) such that p = w*q + r, or (0, p) if no such split is found.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> sdiv(x, x)
    (1, 0)
    >>> sdiv(x, y)
    (0, x)
    >>> sdiv(2*x + 3, x)
    (2, 3)
    >>> a, b = x + 2*y + z, x + y
    >>> sdiv(a, b)
    (1, y + z)
    >>> sdiv(a, -b)
    (-1, y + z)
    >>> sdiv(-a, -b)
    (1, -y - z)
    >>> sdiv(-a, b)
    (-1, -y - z)
    """
    from sympy.core.function import _mexpand
    P, Q = map(lambda i: _mexpand(i, recursive=True), (p, q))
    r, wq = P.as_independent(*Q.free_symbols, as_Add=True)
    # quick exit if no full division possible
    if Q.is_Add and not wq.is_Add:
        return S.Zero, P
    # check multiplicative cancellation
    w, bot = fraction((wq/Q).cancel())
    if bot != 1 and wq.is_Add and Q.is_Add:
        # try maximal additive extraction
        s1 = s2 = 1
        if signsimp(Q, evaluate=False).is_Mul:
            wq = -wq
            r = -r
            Q = -Q
            s1 = -1
        if signsimp(wq, evaluate=False).is_Mul:
            wq = -wq
            s2 = -1
        xa = wq.extract_additively(Q)
        if xa:
            was = wq.as_coefficients_dict()
            now = xa.as_coefficients_dict()
            dif = {k: was[k] - now.get(k, 0) for k in was}
            # skip terms the extraction left untouched (avoids division by zero)
            n = min(was[k]//dif[k] for k in dif if dif[k])
            dr = wq - n*Q
            w = s2*n
            r = s1*(r + s2*dr)
            assert _mexpand(p - (w*q + r)) == 0
            bot = 1
    return (w, r) if bot == 1 else (S.Zero, p)
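To see the core idea in isolation, here is a stripped-down, hypothetical simple_sdiv that keeps only the multiplicative-cancellation step of sdiv and uses expand instead of the private _mexpand. It handles the question's expression, but unlike sdiv it gives up on cases such as x + 2*y + z divided by x + y, which need the additive extraction.

```python
from sympy import S, expand, fraction, symbols

def simple_sdiv(num, den):
    """Hypothetical, stripped-down sdiv: keeps only the multiplicative-
    cancellation step and skips the additive extraction."""
    P, Q = expand(num), expand(den)
    # r: terms of P free of Q's symbols; wq: the terms that may divide by Q
    r, wq = P.as_independent(*Q.free_symbols, as_Add=True)
    w, bot = fraction((wq / Q).cancel())
    return (w, r) if bot == 1 else (S.Zero, num)

k, u, p, q = symbols('k u p q')
n, d = ((k*u + p*u + q)/(k + p)).as_numer_denom()
w, r = simple_sdiv(n, d)
print(w, r)     # u q
print(w + r/d)  # quotient/remainder form of the original expression
```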

The more general suggestion from Davide_sd about using equals is good if you are only testing the equality of two expressions in different forms.

Answered By: smichr