numba.jit can’t compile np.roll

Question:

I’m trying to compile the "foo" function with jit:

import numpy as np
from numba import jit

dy = 5
@jit
def foo(grid):
    return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
                   for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)


ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)

And I get the following error:

Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
 * parameterized
In definition 0:
    TypeError: np_roll() got an unexpected keyword argument 'axis'

The function still runs (numba falls back to object mode), but the nopython compilation fails.

How can I fix this error? Is np.roll compatible with numba, and if not, is there an alternative?

Asked By: rambi


Answers:

If you check the docs, you’ll see that for np.roll only the first two arguments are supported, so numba can only roll the flattened array (you cannot specify an axis).

numpy.roll() (only the 2 first arguments; second argument shift must be an integer)
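A small NumPy-only illustration of why the missing axis argument matters (the array `a` is just an example; numba is not involved): without `axis`, `np.roll` flattens the array, rolls it, and restores the shape, which is not the same as rolling each row.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# No axis: flatten, roll, reshape back -> [[5, 0, 1], [2, 3, 4]]
flat_roll = np.roll(a, 1)

# axis=1: every row is rolled independently -> [[2, 0, 1], [5, 3, 4]]
row_roll = np.roll(a, 1, axis=1)
```

The shift-only form is the one numba supports, so on its own it cannot reproduce `row_roll`.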

Note, however, that it does not really make sense to use numba here: you are performing a single vectorized operation, which already runs very fast. Numba only makes sense if you have to loop over the array to apply some logic.

So the only possible way to roll the rows of your array here using numba would be to loop over them:

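import numpy as np
from numba import njit

@njit
def foo(a, dy):
    # numba's np.roll only accepts (array, int shift), so roll each row separately
    out = np.empty(a.shape, np.int32)
    for i in range(a.shape[0]):
        out[i] = np.roll(a[i], dy)
    return out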

np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
# True

As mentioned, though, this will be much slower than simply calling np.roll with axis=1, which is already vectorized and does all of its looping at the C level:

ex_grid = np.random.rand(5000,5000)>0.5

%timeit foo(ex_grid, 3)
# 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit np.roll(ex_grid, 1, axis=1)
# 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
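If numba is dropped altogether, the question's function can stay in plain NumPy. As a sketch (the name `neighbour_count` is hypothetical), NumPy's own `np.roll` accepts tuples for `shift` and `axis`, so the two nested calls can be merged into one call per offset. This relies on the full NumPy signature, so it works in NumPy but not inside a numba-compiled function.

```python
import numpy as np

def neighbour_count(grid):
    """Number of True neighbours of each cell, with wrap-around edges (plain NumPy)."""
    total = np.zeros(grid.shape, dtype=np.int64)
    for x in (-1, 0, 1):
        for y in (-1, 0, 1):
            if x or y:  # skip the cell itself
                # One call rolls x along axis 0 and y along axis 1
                total += np.roll(grid, (x, y), axis=(0, 1))
    return total

ex_grid = np.random.rand(5, 5) > 0.5
counts = neighbour_count(ex_grid)
```

It returns the same counts as the question's `np.sum` over the eight doubly rolled copies, without building the intermediate list.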
Answered By: yatu

You can write a parallel, nopython-mode numba equivalent of np.roll that is faster than np.roll itself. Calling np.roll inside a loop is not a good choice under numba jit; it is much better to write the np.roll logic out explicitly so that numba can compile the loop. For axis=1:

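import numpy as np
import numba as nb

@nb.njit(parallel=True)
def numba_(a, shf):
    # Equivalent of np.roll(a, shf, axis=1) for a 2D array
    b = np.empty_like(a)
    rows_num = a.shape[0]
    cols_num = a.shape[1]
    shf = shf % cols_num  # handle negative or oversized shifts like np.roll does
    for i in nb.prange(rows_num):  # rows are processed in parallel
        b[i, shf:] = a[i, :cols_num - shf]  # first cols_num - shf items move right
        b[i, :shf] = a[i, cols_num - shf:]  # last shf items wrap to the front
    return b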
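Because the two slice assignments are ordinary NumPy, the scheme can be checked against `np.roll` without numba. The helper below (`roll_rows`, a hypothetical name) repeats the same slicing per row; the `shf % cols_num` line is an assumption added here so that negative or oversized shifts behave like `np.roll`.

```python
import numpy as np

def roll_rows(a, shf):
    # Same slice assignments as numba_, run as plain Python/NumPy for checking
    b = np.empty_like(a)
    cols_num = a.shape[1]
    shf = shf % cols_num  # map any integer shift into [0, cols_num)
    for i in range(a.shape[0]):
        b[i, shf:] = a[i, :cols_num - shf]  # first cols_num - shf items move right
        b[i, :shf] = a[i, cols_num - shf:]  # last shf items wrap to the front
    return b

a = np.random.rand(4, 7) > 0.5
for shf in (-9, -1, 0, 1, 3, 7, 12):
    assert np.array_equal(roll_rows(a, shf), np.roll(a, shf, axis=1))
```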

[Benchmark figure: (5000, 5000) array, different shf values on the horizontal axis]

The code needed in the question can be written in a much faster way using numba, if that is still required. Here I only tried to show, as an example, the performance and capability of numba in this regard; I can write the full code if it is still needed.
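As a hedged sketch of that idea (not the answerer's code, which was not posted), the question's neighbour sum can be computed with explicit loops and modulo indexing instead of rolling whole arrays, which is the kind of loop numba compiles in nopython mode. The function name `neighbour_count_loops` is hypothetical, no speed figures are claimed, and the `try`/`except` fallback is an assumption added only so the sketch also runs (slowly) without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def neighbour_count_loops(grid):
    # Count True neighbours of each cell with periodic (wrap-around) edges,
    # i.e. the same result as summing the 8 rolled copies in the question
    n, m = grid.shape
    out = np.zeros((n, m), dtype=np.int64)
    for i in range(n):
        for j in range(m):
            count = 0
            for dx in range(-1, 2):
                for dy in range(-1, 2):
                    if dx == 0 and dy == 0:
                        continue  # skip the cell itself
                    # adding n (or m) keeps the index non-negative before the modulo
                    if grid[(i + dx + n) % n, (j + dy + m) % m]:
                        count += 1
            out[i, j] = count
    return out
```

Each cell is visited once and no temporary arrays are created, in contrast to the eight full-size copies made by the rolling version.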

Answered By: Ali_Sh