Easy parallelization of numpy.apply_along_axis()?

Question:

How could the application of a function to the elements of a NumPy array through numpy.apply_along_axis() be parallelized so as to take advantage of multiple cores? This seems to be a natural thing to do, in the common case where all the calls to the function being applied are independent.

In my particular case (if this matters), the axis of application is axis 0: np.apply_along_axis(func, axis=0, arr=param_grid) (np being NumPy).

I had a quick look at Numba, but I can’t seem to get this parallelization, with a loop like:

@numba.jit(parallel=True)
def apply_along_axis_0(params):  # (wrapper name is only illustrative)
    result = np.empty(shape=params.shape[1:])
    for index in np.ndindex(*result.shape):  # All the indices of params[0,...]
        result[index] = func(params[(slice(None),) + index])  # Applying func along axis 0
    return result

There is also apparently a compilation option in NumPy for parallelization through OpenMP, but it does not seem to be accessible through MacPorts.

One could also cut the array into a few pieces, use threads (so as to avoid copying the data) and apply the function to each piece in parallel. This is more complex than what I am looking for (and might not work if the Global Interpreter Lock is not released enough).
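For concreteness, a minimal sketch of that thread-based idea might look like the following. The names threaded_apply_along_axis and n_threads are only illustrative, the array is assumed to have at least two dimensions, and any speedup depends entirely on func1d releasing the GIL for most of its run time:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def threaded_apply_along_axis(func1d, arr, n_threads=4):
    """Apply func1d along axis 0 of arr, one task per block of axis 1.

    Illustrative sketch only: it assumes arr.ndim >= 2 and a func1d
    that returns a scalar.
    """
    # Splitting along axis 1 keeps axis 0 (the axis of application)
    # whole. np.array_split() returns views, so no data is copied.
    pieces = [piece for piece in np.array_split(arr, n_threads, axis=1)
              if piece.size]  # np.apply_along_axis() rejects empty pieces

    def apply_to_piece(piece):
        return np.apply_along_axis(func1d, 0, piece)

    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        results = list(executor.map(apply_to_piece, pieces))

    # Axis 0 has been consumed by func1d, so the blocks that were cut
    # along the original axis 1 are now stacked along axis 0.
    return np.concatenate(results, axis=0)
```

With a pure-Python func1d the threads mostly take turns holding the GIL, so this is expected to help only when func1d spends its time in code that releases it.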

It would be very nice to be able to use multiple cores in a simple way for simple parallelizable tasks like applying a function to all the elements of an array (which is essentially what is needed here, with the small complication that function func() takes a 1D array of parameters).

Asked By: Eric O Lebigot


Answers:

Alright, I worked it out: one idea is to use the standard multiprocessing module and split the original array into just a few chunks (so as to limit communication overhead with the workers). This can be done relatively easily as follows:

import multiprocessing

import numpy as np

def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """
    Like numpy.apply_along_axis(), but takes advantage of multiple
    cores.
    """
    # Axis along which each worker will call apply_along_axis(). It must
    # be non-zero, because the array is cut into chunks along axis 0 by
    # `np.array_split()`, and the axis of application must stay whole:
    effective_axis = 1 if axis == 0 else axis
    if effective_axis != axis:
        arr = arr.swapaxes(axis, effective_axis)

    # Chunks for the mapping (only a few chunks; empty ones are skipped,
    # because np.apply_along_axis() raises ValueError on them):
    chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
              for sub_arr in np.array_split(arr, multiprocessing.cpu_count())
              if sub_arr.size]

    pool = multiprocessing.Pool()
    individual_results = pool.map(unpacking_apply_along_axis, chunks)
    # Freeing the workers:
    pool.close()
    pool.join()

    return np.concatenate(individual_results)

where the function unpacking_apply_along_axis() passed to Pool.map() is defined separately at module level, as it must be (so that subprocesses can import it). It is simply a thin wrapper that handles the fact that Pool.map() passes a single argument to the function it calls (Python 2 syntax):

def unpacking_apply_along_axis((func1d, axis, arr, args, kwargs)):
    """
    Like numpy.apply_along_axis(), but with arguments in a tuple
    instead.

    This function is useful with multiprocessing.Pool().map(): (1)
    map() only handles functions that take a single argument, and (2)
    this function can generally be imported from a module, as required
    by map().
    """
    return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)

(in Python 3, the argument unpacking must be done manually:

def unpacking_apply_along_axis(all_args):
    """…"""
    (func1d, axis, arr, args, kwargs) = all_args
    …

because tuple parameter unpacking in function signatures was removed in Python 3, see PEP 3113).
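Putting the pieces together, a complete Python 3 version might look like the sketch below. It follows the same approach as above; the function column_norm and the array shape are made up for the demonstration, and the filter on empty chunks is a precaution for arrays with fewer rows than cores:

```python
import multiprocessing

import numpy as np


def unpacking_apply_along_axis(all_args):
    """Like numpy.apply_along_axis(), but with arguments in a tuple."""
    (func1d, axis, arr, args, kwargs) = all_args
    return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)


def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """Like numpy.apply_along_axis(), but takes advantage of multiple cores."""
    # The chunks are cut along axis 0, so each worker must apply
    # func1d along a non-zero axis:
    effective_axis = 1 if axis == 0 else axis
    if effective_axis != axis:
        arr = arr.swapaxes(axis, effective_axis)

    # Only a few chunks, to limit communication overhead. Empty chunks
    # are dropped because np.apply_along_axis() rejects them.
    chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
              for sub_arr in np.array_split(arr, multiprocessing.cpu_count())
              if sub_arr.size]

    # The context manager shuts the workers down once map() has returned.
    with multiprocessing.Pool() as pool:
        individual_results = pool.map(unpacking_apply_along_axis, chunks)

    return np.concatenate(individual_results)


def column_norm(vector):
    """Hypothetical example function: Euclidean norm of a 1D array."""
    return np.sqrt(np.sum(vector ** 2))


if __name__ == "__main__":
    # Made-up data, only to compare against the serial result.
    param_grid = np.random.rand(3, 200, 100)
    parallel = parallel_apply_along_axis(column_norm, 0, param_grid)
    serial = np.apply_along_axis(column_norm, 0, param_grid)
    assert np.allclose(parallel, serial)
```

The `if __name__ == "__main__":` guard matters on platforms where worker processes are started by re-importing the main module (Windows, and macOS by default), since it keeps the workers from re-running the demonstration code.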

In my particular case, this resulted in a 2x speedup on 2 cores with hyper-threading. A factor closer to 4x would have been nicer, but the speedup is already worthwhile for just a few lines of code, and it should be better on machines with more cores (which are quite common). Maybe there is a way of avoiding data copies by using shared memory (maybe through the multiprocessing module itself)?

Answered By: Eric O Lebigot