Python Dynamic Programming Problem (2-dimensional recursion stuck in infinite loop)

Question:

In the book "A Practical Guide to Quantitative Finance Interviews", there is a question called Dynamic Card Game (Section 5.3, Dynamic Programming).

The solution according to the book is basically the following:

E[f(b,r)] = max(b−r,(b/(b+r))∗E[f(b−1,r)]+(r/(b+r))∗E[f(b,r−1)])

with the following boundary conditions.

f(0,r)=0, f(b,0)=b
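As a quick hand check of the recurrence and boundary conditions (a worked example, not taken from the book), the smallest nontrivial case is f(1,1):

```latex
f(1,1) = \max\!\left(1-1,\ \tfrac{1}{2}\, f(0,1) + \tfrac{1}{2}\, f(1,0)\right)
       = \max\!\left(0,\ \tfrac{1}{2}\cdot 0 + \tfrac{1}{2}\cdot 1\right)
       = \tfrac{1}{2}
```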

I tried implementing it in Python as follows:

def f(b,r):
    if b == 0:
        return 0
    elif r == 0:
        return b
    else:
        var = (b/(b+r)) * f(b-1, r) + (r/(b+r)) * f(b, r-1) 
        return max( b-r,  var )

print("The solution is")
print(f(26,26))

But for some reason, the code above gets stuck in an infinite loop and returns nothing for larger inputs such as f(26,26).

It works fine for smaller numbers. For example, f(5,5) returns 1.11904 immediately.

Can anyone explain what I am doing wrong here in the code?

Asked By: nyan314sn

||

Answers:

Dynamic programming can be used to compute this efficiently. It's just a matter of creating a table of results E[f(b, r)] and filling it in an order such that the entries needed to compute the current result have already been computed.

This code solves the problem exactly (using fractions):

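from fractions import Fraction as F

def S(b, r):
    # E[i][j] holds f(i, j): column 0 is f(i, 0) = i, row 0 is f(0, j) = 0
    E = [[i] + [0] * r for i in range(b+1)]
    # Fill row by row, left to right, so E[i-1][j] and E[i][j-1] exist before E[i][j]
    for i in range(1, b+1):
        for j in range(1, r+1):
            E[i][j] = max(i-j, F(i, i+j) * E[i-1][j] + F(j, i+j) * E[i][j-1])
    return E[b][r]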

print(S(5, 5))
print(S(26, 26))

Output:

47/42
41984711742427/15997372030584

The program takes essentially no time with these inputs (0.027s on my machine).
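Because each entry E[i][j] only needs the entry above it (E[i-1][j]) and the entry to its left (E[i][j-1]), the full table can be reduced to a single row that is overwritten in place. The sketch below is a variant of the table approach, not part of the original answer; the name S_1d is hypothetical. It uses O(r) memory instead of O(b·r):

```python
from fractions import Fraction as F


def S_1d(b, r):
    """Same recurrence as the table version, keeping only one row of the table."""
    # row[j] holds E[i][j] for the current i; start at i = 0, where f(0, j) = 0
    row = [F(0)] * (r + 1)
    for i in range(1, b + 1):
        row[0] = F(i)  # boundary condition f(i, 0) = i
        for j in range(1, r + 1):
            # before this assignment row[j] still holds E[i-1][j],
            # and row[j-1] has already been updated to E[i][j-1]
            row[j] = max(i - j, F(i, i + j) * row[j] + F(j, i + j) * row[j - 1])
    return row[r]


print(S_1d(5, 5))    # 47/42
print(S_1d(26, 26))
```

The trade-off is that only the final row survives, so intermediate values such as E[10][3] are no longer available after the call.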

Answered By: Paul Hankin

The issue with your recursive implementation is that it recomputes f(b, r) over and over for the same b and r, so the number of calls grows exponentially with the input size.

To illustrate, run this snippet, which counts how many times f is called:

n = 0
def f(b,r):
    global n
    n += 1
    if b == 0:
        return 0
    elif r == 0:
        return b
    else:
        var = (b/(b+r)) * f(b-1, r) + (r/(b+r)) * f(b, r-1) 
        return max( b-r,  var )

for i in range(5, 12):
    n = 0
    f(i, i)
    print(f"Number of times function f gets called for f({i},{i}) - {n}")

Output:

Number of times function f gets called for f(5,5) - 503
Number of times function f gets called for f(6,6) - 1847
Number of times function f gets called for f(7,7) - 6863
Number of times function f gets called for f(8,8) - 25739
Number of times function f gets called for f(9,9) - 97239
Number of times function f gets called for f(10,10) - 369511
Number of times function f gets called for f(11,11) - 1410863
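These counts follow a closed form. Every call with b, r > 0 makes two further calls and the base cases make none, so the call count T(b, r) satisfies T(b, r) = 1 + T(b−1, r) + T(b, r−1) with T(0, r) = T(b, 0) = 1, which solves to T(b, r) = 2·C(b+r, b) − 1 (for example 2·C(10, 5) − 1 = 503). A small sketch to evaluate it (naive_calls is a hypothetical helper name; math.comb needs Python 3.8+):

```python
from math import comb


def naive_calls(b, r):
    """Number of calls the un-memoized f(b, r) makes, via T(b, r) = 2*C(b+r, b) - 1."""
    return 2 * comb(b + r, b) - 1


for i in (5, 11, 26):
    print(f"f({i},{i}) -> {naive_calls(i, i):,} calls")
```

For f(26,26) this is roughly 10^15 calls, which is why the program appears to hang: it is not a true infinite loop, it would just take an impractically long time to finish.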

In Python, an easy way to cache the results of a top-down recursive function is the built-in functools.lru_cache decorator (using it bare, as @lru_cache without parentheses, requires Python 3.8 or later).

Updating the code to this:

from functools import lru_cache

@lru_cache
def f(b,r):
    if b == 0:
        return 0
    elif r == 0:
        return b
    else:
        var = (b/(b+r)) * f(b-1, r) + (r/(b+r)) * f(b, r-1) 
        return max( b-r,  var )

fixes the issue.

With the function above, f(26,26) returns 2.6244755489939244 in about 41 ms.
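To show what the decorator does for this function, here is a rough hand-written equivalent: a dictionary keyed on the arguments, checked before computing and filled after. This is a sketch for illustration only; the dictionary name memo is hypothetical, and unlike lru_cache it offers no maxsize limit or cache_clear method.

```python
memo = {}  # hypothetical cache: maps (b, r) -> previously computed f(b, r)


def f(b, r):
    """Hand-memoized version of f; roughly what lru_cache(maxsize=None) provides."""
    if (b, r) in memo:
        return memo[(b, r)]  # already computed: reuse the stored value
    if b == 0:
        result = 0
    elif r == 0:
        result = b
    else:
        var = (b / (b + r)) * f(b - 1, r) + (r / (b + r)) * f(b, r - 1)
        result = max(b - r, var)
    memo[(b, r)] = result  # store it so later calls skip the recursion
    return result


print(f(26, 26))
```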

Repeating the call-count test from the first snippet with lru_cache in place gives:

Number of times function f gets called for f(5,5) - 35
Number of times function f gets called for f(6,6) - 13
Number of times function f gets called for f(7,7) - 15
Number of times function f gets called for f(8,8) - 17
Number of times function f gets called for f(9,9) - 19
Number of times function f gets called for f(10,10) - 21
Number of times function f gets called for f(11,11) - 23

The counts are smaller for the larger inputs because the cache is not cleared between calls: each call only evaluates the (b, r) pairs that earlier calls have not already cached.
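To get the count for each input on a cold cache instead, the cache can be emptied with f.cache_clear(), a method that lru_cache adds to the decorated function. A sketch of that approach (count_calls is a hypothetical helper name): with an empty cache, f(n, n) evaluates each distinct (b, r) state exactly once, which works out to n² + 2n evaluations, e.g. 35 for n = 5.

```python
from functools import lru_cache

calls = 0  # counts evaluations of f's body, i.e. cache misses


@lru_cache(maxsize=None)
def f(b, r):
    global calls
    calls += 1
    if b == 0:
        return 0
    elif r == 0:
        return b
    var = (b / (b + r)) * f(b - 1, r) + (r / (b + r)) * f(b, r - 1)
    return max(b - r, var)


def count_calls(n):
    """Evaluate f(n, n) starting from an empty cache; return how many times the body ran."""
    global calls
    f.cache_clear()  # drop results left over from earlier calls
    calls = 0
    f(n, n)
    return calls


for i in range(5, 12):
    print(f"f({i},{i}) evaluated {count_calls(i)} times")
```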

Answered By: Jay

Based on the comment about memoization, adding the @functools.lru_cache decorator solved the issue.

import functools

@functools.lru_cache(maxsize=None)
def f(b,r):
    if b == 0:
        return 0
    elif r == 0:
        return b
    else:
        var = (b/(b+r)) * f(b-1, r) + (r/(b+r)) * f(b, r-1) 
        return max( b-r,  var )

I also ran %timeit on the recursive version and on the pure dynamic programming solution provided by Paul Hankin; the recursion seems to be a lot faster: 90.1 ns vs 6.1 ms.
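One caveat about those %timeit figures: %timeit runs the statement many times, and lru_cache keeps its results between runs, so after the first run the memoized f(26,26) is a single cache lookup rather than a computation. The table version also uses exact Fraction arithmetic instead of floats. A sketch of a fairer comparison with the standard timeit module, clearing the cache before each timed run (f_cold and runs are hypothetical names; timings depend on the machine, so none are claimed here):

```python
import timeit
from fractions import Fraction as F
from functools import lru_cache


@lru_cache(maxsize=None)
def f(b, r):
    # memoized recursion from the answers above (float arithmetic)
    if b == 0:
        return 0
    elif r == 0:
        return b
    var = (b / (b + r)) * f(b - 1, r) + (r / (b + r)) * f(b, r - 1)
    return max(b - r, var)


def S(b, r):
    # Paul Hankin's table-based solution (exact Fraction arithmetic)
    E = [[i] + [0] * r for i in range(b + 1)]
    for i in range(1, b + 1):
        for j in range(1, r + 1):
            E[i][j] = max(i - j, F(i, i + j) * E[i - 1][j] + F(j, i + j) * E[i][j - 1])
    return E[b][r]


def f_cold(n=26):
    f.cache_clear()  # empty the cache so every timed run recomputes all states
    return f(n, n)


runs = 20
f(26, 26)  # prime the cache for the "warm" measurement
warm = timeit.timeit(lambda: f(26, 26), number=runs) / runs
cold = timeit.timeit(f_cold, number=runs) / runs
table = timeit.timeit(lambda: S(26, 26), number=runs) / runs

print(f"lru_cache, warm cache: {warm * 1e9:.0f} ns per call")
print(f"lru_cache, cold cache: {cold * 1e6:.0f} us per call")
print(f"Fraction table:        {table * 1e6:.0f} us per call")
```

On a cold cache both approaches evaluate each (b, r) state once, so any remaining gap is likely to come mostly from Fraction versus float arithmetic and from per-call overhead.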

Answered By: nyan314sn