Vectorize a function in NumPy

Question:

I have the following function

import numpy as np

def foo(args):
    [a, x, y, z, b1] = args
    vals = np.random.uniform(0, 10, a)
    rr = np.random.uniform(1, 2, a)
    u_1 = vals - x
    u_2 = vals * rr - y
    u_3 = vals / rr - z
    Q = sum(sum(u_1[None, :] > np.maximum(u_2[None, :], u_3[None, :])))
    # print(Q)
    if Q > b1:
        Q = 10
    return Q

args = [10, 2, 40, 1, 2]
print(foo(args))  # this works fine
x_ = [*range(5, 13, 1)]
y_ = [*range(2, 50, 5)]
z_ = [*range(4, 8, 1)]
x, y, z = np.meshgrid(x_, y_, z_, indexing="ij")
args = [10, x, y, z, 2]
print(foo(args))  # this does not work

I get the following error:

ValueError: operands could not be broadcast together with shapes (10,) (8,10,4)

I want to evaluate the function foo(args) at all the points of the meshgrid. How should I modify foo(args)?
I need a vectorized implementation of this operation, as it needs to be very fast.

I already know the answer to "How do I apply some function to a python meshgrid?", but I think there might be a faster way to evaluate the function.

Asked By: Dreamer93

Answers:

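The error comes from NumPy's broadcasting rules: shapes are compared starting from the trailing axis, so vals with shape (10,) cannot be combined with a grid of shape (8, 10, 4). Use np.expand_dims() to append a length-1 axis to x, y, and z, like this: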

    x, y, z = [np.expand_dims(e, axis=-1) for e in (x, y, z)]
    u_1 = vals - x
    u_2 = vals * rr - y
    u_3 = vals / rr - z
    Q = np.sum(u_1 > np.maximum(u_2, u_3), axis=-1)
    Q = np.where(Q > b1, 10, Q)

I also removed the unnecessary indexing such as u_1[None, :] and replaced the nested sum() calls with np.sum().
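
For reference, here is a minimal, self-contained sketch of how the snippet above might fit into the whole function. The name foo_grid, the separate arguments, and the optional rng parameter are assumptions made for illustration; they are not part of the original code.

```python
import numpy as np

def foo_grid(a, x, y, z, b1, rng=None):
    """Evaluate foo at every grid point (hypothetical helper name)."""
    rng = np.random.default_rng() if rng is None else rng
    vals = rng.uniform(0, 10, a)  # shape (a,)
    rr = rng.uniform(1, 2, a)     # shape (a,)
    # Append a length-1 axis: (8, 10, 4) -> (8, 10, 4, 1), which
    # broadcasts against (a,) to give (8, 10, 4, a).
    x, y, z = (np.expand_dims(np.asarray(e), axis=-1) for e in (x, y, z))
    u_1 = vals - x
    u_2 = vals * rr - y
    u_3 = vals / rr - z
    # Count, per grid point, the samples where u_1 exceeds both u_2 and u_3.
    Q = np.sum(u_1 > np.maximum(u_2, u_3), axis=-1)
    return np.where(Q > b1, 10, Q)

x, y, z = np.meshgrid(np.arange(5, 13), np.arange(2, 50, 5),
                      np.arange(4, 8), indexing="ij")
Q = foo_grid(10, x, y, z, 2, rng=np.random.default_rng(0))
print(Q.shape)  # (8, 10, 4)
```

Note that in this sketch every grid point shares the same random draws vals and rr, whereas calling the original foo once per point would draw new samples each time. If independent draws per point matter, this version is not a drop-in replacement.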

Answered By: relent95

You can use np.vectorize to vectorize your foo function, and use its excluded option to exclude arguments from being vectorized. In your case, the arguments a and b1 don't need to be vectorized. First, modify foo so that it takes separate arguments:

def foo(a, x, y, z, b1):
    # [a, x, y, z, b1] = args  -> comment out this line
    ...  # the rest of the body is unchanged

Pass the list to the function with foo(*args) instead of foo(args). Then vectorize foo:

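# a and b1 are passed positionally, so exclude them by position;
# names in excluded only apply to keyword arguments.
vec_func = np.vectorize(foo, excluded=[0, 4])
args = [10, x, y, z, 2]
print(vec_func(*args))  # it works now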
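
Putting the np.vectorize steps together, a runnable sketch could look like the following. The use of np.sum in the body, the positional indices in excluded, and otypes=[int] are assumptions chosen for this illustration rather than part of the answer's code.

```python
import numpy as np

def foo(a, x, y, z, b1):
    # Here x, y, z are scalars: np.vectorize calls foo once per grid point.
    vals = np.random.uniform(0, 10, a)
    rr = np.random.uniform(1, 2, a)
    u_1 = vals - x
    u_2 = vals * rr - y
    u_3 = vals / rr - z
    Q = np.sum(u_1 > np.maximum(u_2, u_3))
    return 10 if Q > b1 else Q

# Exclude a (position 0) and b1 (position 4) from vectorization.
# otypes fixes the output dtype so no extra probing call is needed.
vec_func = np.vectorize(foo, excluded={0, 4}, otypes=[int])

x, y, z = np.meshgrid(np.arange(5, 13), np.arange(2, 50, 5),
                      np.arange(4, 8), indexing="ij")
Q = vec_func(10, x, y, z, 2)
print(Q.shape)  # (8, 10, 4)
```

Keep in mind that np.vectorize is provided mainly for convenience: it is essentially a Python-level loop over the grid points, so it is unlikely to deliver the speed-up that a broadcasting-based version can. On the other hand, each grid point gets its own random draws here.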

Answered By: Elliot Su