How to speed up this big additive for loop in Python with Numba?

Question:

I’m trying to speed up this function, which takes an array D of shape (M, N, O), an array Pi of shape (M, M), integer index arrays x_i, y_i and weight arrays x_pi, y_pi of shape (M, N, O) as inputs, and returns an array Dnew of the same shape.

@njit
def forward_policy_2d_njit(D, Pi, x_i, y_i, x_pi, y_pi):
    nZ, nX, nY = D.shape
    Dnew = np.zeros_like(D)

    for iz_next in range(nZ):
        for iz in range(nZ):
            for ix in range(nX):
                for iy in range(nY):
                    ixp = x_i[iz, ix, iy]
                    iyp = y_i[iz, ix, iy]
                    beta = x_pi[iz, ix, iy]
                    alpha = y_pi[iz, ix, iy]

                    Dnew[iz_next, ixp, iyp] += alpha * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp] += alpha * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp, iyp+1] += (1 - alpha) * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp+1] += (1 - alpha) * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
    return Dnew

I have tried to @guvectorize it, but it does not return the correct result. Do you know how I could speed it up, or make proper use of guvectorize?

Edit: following Jérôme Richard's comment, I have added a running example of the code. The matrix Dnew should sum to approximately 1 in all cases.

from numba import njit, guvectorize
import numpy as np 

nA = 60
nB = 40
nZ = 3

Pi = np.array([
    [0.5,0.5,0],
    [0,0.5,0.5],
    [0.5,0,0.5]
])

D = np.ones((nZ,nB,nA)) / (nZ*nB*nA)

x_pi = np.random.uniform(low=0, high=1, size=(nZ, nB, nA))  # random weights in [0, 1)
y_pi = np.random.uniform(low=0, high=1, size=(nZ, nB, nA))

x_i = np.random.randint(0, nB-1, (nZ, nB, nA))  # indices in [0, nB-2], so ixp+1 stays in bounds
y_i = np.random.randint(0, nA-1, (nZ, nB, nA))  # indices in [0, nA-2], so iyp+1 stays in bounds


@njit
def forward_policy_2d_njit(D, Pi, x_i, y_i, x_pi, y_pi):
    nZ, nX, nY = D.shape
    Dnew = np.zeros_like(D)

    for iz_next in range(nZ):
        for iz in range(nZ):
            for ix in range(nX):
                for iy in range(nY):
                    ixp = x_i[iz, ix, iy]
                    iyp = y_i[iz, ix, iy]
                    beta = x_pi[iz, ix, iy]
                    alpha = y_pi[iz, ix, iy]

                    Dnew[iz_next, ixp, iyp] += alpha * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp] += alpha * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp, iyp+1] += (1 - alpha) * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp+1] += (1 - alpha) * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
    return Dnew

@guvectorize(['void(float64[:,:,:], float64[:,:], int64[:,:,:], int64[:,:,:], float64[:,:,:], float64[:,:,:], float64[:,:,:])'], '(i,j,k),(i,i),(i,j,k),(i,j,k),(i,j,k),(i,j,k)->(i,j,k)')
def forward_policy_2d_vec(D, Pi, x_i, y_i, x_pi, y_pi, Dnew):
    nZ, nX, nY = D.shape
    Dnew = np.zeros_like(D)

    for iz_next in range(nZ):
        for iz in range(nZ):
            for ix in range(nX):
                for iy in range(nY):
                    ixp = x_i[iz, ix, iy]
                    iyp = y_i[iz, ix, iy]
                    beta = x_pi[iz, ix, iy]
                    alpha = y_pi[iz, ix, iy]

                    Dnew[iz_next, ixp, iyp] += alpha * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp] += alpha * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp, iyp+1] += (1 - alpha) * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp+1] += (1 - alpha) * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
                    
Dnew_jit = forward_policy_2d_njit(D, Pi, x_i, y_i, x_pi, y_pi)
Dnew_vec = forward_policy_2d_vec(D, Pi, x_i, y_i, x_pi, y_pi)

print(np.sum(Dnew_jit)) # Check: the sum of Dnew should be (approximately) 1
print(np.sum(Dnew_vec)) # Check: the sum of Dnew should be (approximately) 1
Asked By: Mr. Fafa


Answers:

This code is hard to optimize. Indeed, the arrays are too small for multithreading to be efficient, and the accesses are mostly non-contiguous, so SIMD vectorization is neither easy nor efficient. That being said, we can move the iz_next loop inward so as to factor out operations, that is, avoid recomputing values and re-reading the same arrays multiple times. Moreover, we can tell the compiler that nZ is 3 using an assertion, so it can generate faster code thanks to loop unrolling. If this is not always the case, you can write different functions with different assertions; the main code can be shared in another Numba function without impacting performance, as long as the inner function is inlined (using the flag inline="always"). We can also tell the compiler that the arrays are contiguous in practice, so as to avoid generating additional indexing instructions.

Here is the resulting code:

@njit(['(float64[:,:,::1], float64[:,::1], int32[:,:,::1], int32[:,:,::1], float64[:,:,::1], float64[:,:,::1])',
    '(float64[:,:,::1], float64[:,::1], int64[:,:,::1], int64[:,:,::1], float64[:,:,::1], float64[:,:,::1])'])
def forward_policy_2d_njit_opt(D, Pi, x_i, y_i, x_pi, y_pi):
    nZ, nX, nY = D.shape
    assert nZ == 3 
    Dnew = np.zeros_like(D)

    for iz in range(nZ):
        for ix in range(nX):
            for iy in range(nY):
                ixp = x_i[iz, ix, iy]
                iyp = y_i[iz, ix, iy]
                beta = x_pi[iz, ix, iy]
                alpha = y_pi[iz, ix, iy]
                D_value = D[iz, ix, iy]
                tmp_1 = alpha * beta * D_value
                tmp_2 = alpha * (1 - beta) * D_value
                tmp_3 = (1 - alpha) * beta * D_value
                tmp_4 = (1 - alpha) * (1 - beta) * D_value
                tmp_view_1 = Dnew[:, ixp, iyp]
                tmp_view_2 = Dnew[:, ixp+1, iyp]
                tmp_view_3 = Dnew[:, ixp, iyp+1]
                tmp_view_4 = Dnew[:, ixp+1, iyp+1]

                for iz_next in range(nZ):
                    Pi_value = Pi[iz, iz_next]
                    tmp_view_1[iz_next] += tmp_1 * Pi_value
                    tmp_view_2[iz_next] += tmp_2 * Pi_value
                    tmp_view_3[iz_next] += tmp_3 * Pi_value
                    tmp_view_4[iz_next] += tmp_4 * Pi_value
    return Dnew

This code is twice as fast on my machine: it takes about 65 µs on the provided input arrays.
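
For completeness, here is a minimal sketch of the shared-inner-function pattern mentioned above, assuming nZ can take several known values (the helper names _forward_core, forward_policy_nz3 and forward_policy_nz4 are illustrative, not part of the original code): the hot loop lives in a single function compiled with inline="always", and thin wrappers carry the nZ assertions so the compiler can still specialize and unroll the iz_next loop.

@njit(inline="always")
def _forward_core(D, Pi, x_i, y_i, x_pi, y_pi):
    # Same loop structure as forward_policy_2d_njit_opt above.
    nZ, nX, nY = D.shape
    Dnew = np.zeros_like(D)
    for iz in range(nZ):
        for ix in range(nX):
            for iy in range(nY):
                ixp = x_i[iz, ix, iy]
                iyp = y_i[iz, ix, iy]
                beta = x_pi[iz, ix, iy]
                alpha = y_pi[iz, ix, iy]
                D_value = D[iz, ix, iy]
                for iz_next in range(nZ):
                    Pi_value = Pi[iz, iz_next]
                    Dnew[iz_next, ixp, iyp] += alpha * beta * Pi_value * D_value
                    Dnew[iz_next, ixp+1, iyp] += alpha * (1 - beta) * Pi_value * D_value
                    Dnew[iz_next, ixp, iyp+1] += (1 - alpha) * beta * Pi_value * D_value
                    Dnew[iz_next, ixp+1, iyp+1] += (1 - alpha) * (1 - beta) * Pi_value * D_value
    return Dnew

@njit
def forward_policy_nz3(D, Pi, x_i, y_i, x_pi, y_pi):
    assert D.shape[0] == 3  # specialization: lets the compiler unroll the iz_next loop
    return _forward_core(D, Pi, x_i, y_i, x_pi, y_pi)

@njit
def forward_policy_nz4(D, Pi, x_i, y_i, x_pi, y_pi):
    assert D.shape[0] == 4  # another specialization for a different nZ
    return _forward_core(D, Pi, x_i, y_i, x_pi, y_pi)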

The loop could theoretically be slightly vectorized using SIMD instructions (along the iz_next axis), but Numba fails to do so, even when Dnew is transposed. In fact, this is apparently due to the LLVM JIT, and more specifically to LLVM itself. Using C/C++ code should not help (at least not with Clang, and probably not with other compilers like GCC or MSVC), unless you write C/C++ code using low-level SIMD intrinsics, which is tedious, non-portable (dependent on the processor instruction set) and bug-prone. Moreover, one needs to account for the overhead of transposing Dnew back at the end. I expect this to be a bit faster, but I am not sure it is worth it.
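
As a side note on the guvectorize attempt in the question: with guvectorize, the output argument must be written in place. The line Dnew = np.zeros_like(D) rebinds the local name to a fresh array, so all the updates go to that temporary and the actual output buffer is returned uninitialized. A minimal sketch of a fix, keeping the question's loop body unchanged (only the initialization differs):

@guvectorize(['void(float64[:,:,:], float64[:,:], int64[:,:,:], int64[:,:,:], float64[:,:,:], float64[:,:,:], float64[:,:,:])'], '(i,j,k),(i,i),(i,j,k),(i,j,k),(i,j,k),(i,j,k)->(i,j,k)')
def forward_policy_2d_vec_fixed(D, Pi, x_i, y_i, x_pi, y_pi, Dnew):
    nZ, nX, nY = D.shape
    Dnew[:, :, :] = 0.0  # zero the output in place instead of rebinding the name

    for iz_next in range(nZ):
        for iz in range(nZ):
            for ix in range(nX):
                for iy in range(nY):
                    ixp = x_i[iz, ix, iy]
                    iyp = y_i[iz, ix, iy]
                    beta = x_pi[iz, ix, iy]
                    alpha = y_pi[iz, ix, iy]
                    Dnew[iz_next, ixp, iyp] += alpha * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp] += alpha * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp, iyp+1] += (1 - alpha) * beta * Pi[iz, iz_next] * D[iz, ix, iy]
                    Dnew[iz_next, ixp+1, iyp+1] += (1 - alpha) * (1 - beta) * Pi[iz, iz_next] * D[iz, ix, iy]

Note that this only fixes correctness: since the signature covers the full arrays, there is nothing left for guvectorize to broadcast over, so it should not be faster than the plain @njit version.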

Answered By: Jérôme Richard