Using Numba njit with np.array

Question:

I have two Python functions that I am trying to speed up with njit, as they are impacting the performance of my program. Below is an MWE that reproduces the following error when I add the @njit(fastmath=True) decorator to f; without it, the code works. I believe the error occurs because the array A has dtype object. Can I use Numba to decorate f in addition to g? If not, what is the fastest way to map g over the elements of A? Roughly, len(A) = len(B) ~ 5000. These functions are called around 500 million times as part of an HPC workflow.

import numpy as np
from numba import njit

@njit(fastmath=True)
def g(a, B):
    # some function of a and B
    return 19.12 / (len(a) + len(B))

def f(A, B):
    total = 0.0
    for i in range(len(B)):
        total += g(A[i], B)
    return total

A = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B = [1, 1, 1, 1, 1, 1, 1, 1, 1]

A = np.array([np.array(a, dtype=int) for a in A], dtype=object)
B = np.array(B, dtype=int)
    
f(A, B)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
During: typing of argument at /var/folders/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/1681580915.py (8)

File "../../../../var/folders/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/1681580915.py", line 8:
<source missing, REPL/exec in use?>

Asked By: AngusTheMan


Answers:

Can I use Numba to decorate f in addition to g?

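No. You cannot use CPython objects in an @njit-decorated Numba function, and an array with dtype object is only a container of references to such objects (hence the "non-precise type array(pyobject, 1d, C)" in the error). Numba is fast mainly because it works on native types, which lets it generate fast compiled code instead of interpreted, dynamically typed code.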

If not, what is the fastest way to map g to the elements of A?

Jagged arrays are inefficient. In general, a fast solution to this problem is to use two arrays: one containing all the values (the rows concatenated) and one containing the start-end range of values for each row (a bit like sparse matrices, but using ranges). Storing only the length of each segment also works and is more compact, but the start-end ranges then have to be computed with a cumulative sum, which sometimes makes things more complex (e.g. the dependency between consecutive rows prevents a straightforward parallelization).

Answered By: Jérôme Richard

To create a non-jagged array in the spirit of what @Jérôme Richard mentions, we can pad every row to the length of the longest one:

# Imports.
import numpy as np
from numba import njit, prange

# Data.
A_list = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B_list = [1, 1, 1, 1, 1, 1, 1, 1, 1]

A_lengths = np.array([len(sublist) for sublist in A_list])
dim1 = len(A_list)
dim2 = A_lengths.max()
A = np.empty(shape=(dim1, dim2), dtype=int)  # 9x4.
for i, (sublist, length) in enumerate(zip(A_list, A_lengths)):
    A[i, :length] = sublist

B = np.array(B_list, dtype=int)
assert A.shape[0] == B.size

The array A will look something like this:

array([[      2,       5, xxxxxx, xxxxxx],
       [      4,       5,      6,      7],
       [      0,       8, xxxxxx, xxxxxx],
       [      6,       7, xxxxxx, xxxxxx],
       [      1,       8, xxxxxx, xxxxxx],
       [      0,       1, xxxxxx, xxxxxx],
       [      1,       3, xxxxxx, xxxxxx],
       [      1,       3, xxxxxx, xxxxxx],
       [      2,       4, xxxxxx, xxxxxx]])

The xxxxxx entries are arbitrary leftover values, because we created the array with np.empty, which does not initialize memory. This is why we keep A_lengths: for each row, it tells us where the valid data ends, so each row must be sliced to its length before it is passed to g.

Back to the calculations: I added two optimizations to f, the @njit(parallel=True) decorator and numba.prange instead of Python’s range. f also takes A_lengths so that g only sees the valid part of each row.

# Calculations.
@njit(fastmath=True)
def g(a, b):
    return 19.12 / (len(a) + len(b))


@njit(parallel=True)
def f(A, A_lengths, B):
    total = 0.0
    for i in prange(len(B)):
        # Slice row i to its real length so the padding is ignored.
        total += g(A[i, :A_lengths[i]], B)
    return total


# Test.
print(f(A, A_lengths, B))
Answered By: Guimoute