Passing a shape to numpy.reshape inside a numba njit function fails: how can I create a suitable iterable for the target shape?

Question:

I have a function that takes in an array, performs an arbitrary calculation, and returns a new shape into which the array can be reshaped.
I would like to use this function inside a numba.njit-compiled function:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

However, numba does not support calling the tuple() constructor on a list, and I get the following error message when generate_target_shape tries to convert its result to a tuple:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

If I try to change the return type of generate_target_shape from tuple to list or np.array, I receive the following error message:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

Is there a way for me to create an iterable object inside a nb.njit function that can be passed to np.reshape?

EDIT: I worked around this problem, as suggested in the accepted solution, by using numba's objmode context manager.

Asked By: Yes


Answers:

It seems that numba does not support calling the standard Python function tuple() on a list. You can easily work around this issue by rewriting your code a little bit, so that the function returns a fixed number of values:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b

The general case, however, is a lot trickier. I am going to backtrack on what I said in the comments: it is neither possible nor advisable to make a numba-compiled function that works with tuples of many different sizes. In numba the length of a tuple is part of its type, so doing so would require your function to be recompiled for every tuple of a unique size. @Jérôme Richard explains the problem very well in this Stack Overflow answer.

What I would recommend is to return the array containing the shape from your numba-compiled function, and then call my_array.reshape(tuple(target_shape)) on your data outside of it. It is not pretty, but it will allow you to continue with your project.

Answered By: Rafnus